aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorazidar2015-05-01 13:44:53 -0700
committerazidar2015-05-01 13:44:53 -0700
commit723c48b1ed0c341a10d1eba5a226787c33398505 (patch)
tree4cad751567699478358013536501e1a8c8bfe633
parent0a00a6aaa846b695a7a750cf40079d56a9bb94d6 (diff)
Fixed performance bug where PlusWidth, MinusWidth, and ExpWidth could be simplified earlier, and also now have equal? defined so mMaxWidth doesn't blow up during width inference
-rw-r--r--src/main/stanza/passes.stanza63
-rw-r--r--test/chisel3/Mul.fir61
2 files changed, 64 insertions, 60 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index e67ec21f..429a1406 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1349,6 +1349,11 @@ defmethod equal? (w1:Width,w2:Width) -> True|False :
if not contains?(args(w2),w) : ret(false)
ret(true)
(w1:IntWidth,w2:IntWidth) : width(w1) == width(w2)
+ (w1:PlusWidth,w2:PlusWidth) :
+ (arg1(w1) == arg1(w2) and arg2(w1) == arg2(w2)) or (arg1(w1) == arg2(w2) and arg2(w1) == arg1(w2))
+ (w1:MinusWidth,w2:MinusWidth) :
+ (arg1(w1) == arg1(w2) and arg2(w1) == arg2(w2)) or (arg1(w1) == arg2(w2) and arg2(w1) == arg1(w2))
+ (w1:ExpWidth,w2:ExpWidth) : arg1(w1) == arg1(w2)
(w1:UnknownWidth,w2:UnknownWidth) : true
(w1,w2) : false
defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
@@ -1380,6 +1385,18 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
for x in args(w*) do : add(v,x)
(w*) : add(v,w*)
MaxWidth(unique(v))
+ (w:PlusWidth) :
+ match(arg1(w),arg2(w)) :
+ (w1:IntWidth,w2:IntWidth) : IntWidth(width(w1) + width(w2))
+ (w1,w2) : w
+ (w:MinusWidth) :
+ match(arg1(w),arg2(w)) :
+ (w1:IntWidth,w2:IntWidth) : IntWidth(width(w1) - width(w2))
+ (w1,w2) : w
+ (w:ExpWidth) :
+ match(arg1(w)) :
+ (w1:IntWidth) : IntWidth(pow(2,width(w1)))
+ (w1) : w
(w) : w
defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width :
;println-all-debug(["Substituting for [" w "]"])
@@ -1458,25 +1475,25 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
; 2) Remove Cycles
; 3) Move to solved if not self-recursive
val u = make-unique(l)
- ;println-debug("======== UNIQUE CONSTRAINTS ========")
- ;for x in u do : println-debug(x)
- ;println-debug("====================================")
+ println-debug("======== UNIQUE CONSTRAINTS ========")
+ for x in u do : println-debug(x)
+ println-debug("====================================")
val f = HashTable<Symbol,Width>(symbol-hash)
val o = Vector<Symbol>()
for x in u do :
- ;println-debug("==== SOLUTIONS TABLE ====")
- ;for x in f do : println-debug(x)
- ;println-debug("=========================")
+ println-debug("==== SOLUTIONS TABLE ====")
+ for x in f do : println-debug(x)
+ println-debug("=========================")
val [n e] = [key(x) value(x)]
val e-sub = substitute(e,f)
- ;println-debug(["Solving " n " => " e])
- ;println-debug(["After Substitute: " n " => " e-sub])
- ;println-debug("==== SOLUTIONS TABLE (Post Substitute) ====")
- ;for x in f do : println-debug(x)
- ;println-debug("=========================")
+ println-debug(["Solving " n " => " e])
+ println-debug(["After Substitute: " n " => " e-sub])
+ println-debug("==== SOLUTIONS TABLE (Post Substitute) ====")
+ for x in f do : println-debug(x)
+ println-debug("=========================")
val e* = remove-cycle{n,_} $ e-sub
;println-debug(["After Remove Cycle: " n " => " e*])
if not self-rec?(n,e*) :
@@ -1485,28 +1502,28 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
add(o,n)
f[n] = e*
- ;println-debug("Forward Solved Constraints")
- ;for x in f do : println-debug(x)
+ println-debug("Forward Solved Constraints")
+ for x in f do : println-debug(x)
; Backwards Solve
val b = HashTable<Symbol,Width>(symbol-hash)
for i in (length(o) - 1) through 0 by -1 do :
val n = o[i]
- ;println-all-debug(["SOLVE BACK: [" n " => " f[n] "]"])
- ;println-debug("==== SOLUTIONS TABLE ====")
- ;for x in b do : println-debug(x)
- ;println-debug("=========================")
+ println-all-debug(["SOLVE BACK: [" n " => " f[n] "]"])
+ println-debug("==== SOLUTIONS TABLE ====")
+ for x in b do : println-debug(x)
+ println-debug("=========================")
val e* = simplify(b-sub(f[n],b))
- ;println-all-debug(["BACK RETURN: [" n " => " e* "]"])
+ println-all-debug(["BACK RETURN: [" n " => " e* "]"])
b[n] = e*
- ;println-debug("==== SOLUTIONS TABLE (Post backsolve) ====")
- ;for x in b do : println-debug(x)
- ;println-debug("=========================")
+ println-debug("==== SOLUTIONS TABLE (Post backsolve) ====")
+ for x in b do : println-debug(x)
+ println-debug("=========================")
; Evaluate
val e = evaluate(b)
- ;println-debug("Evaluated Constraints")
- ;for x in e do : println-debug(x)
+ println-debug("Evaluated Constraints")
+ for x in e do : println-debug(x)
e
public defn width! (t:Type) -> Width :
diff --git a/test/chisel3/Mul.fir b/test/chisel3/Mul.fir
index 1ce6f797..ec991197 100644
--- a/test/chisel3/Mul.fir
+++ b/test/chisel3/Mul.fir
@@ -1,45 +1,32 @@
; RUN: firrtl -i %s -o %s.flo -x X -p c | tee %s.out | FileCheck %s
; CHECK: Done!
+
circuit Mul :
module Mul :
- input y : UInt<2>
input x : UInt<2>
output z : UInt<4>
+ output a : UInt<4>
+ input y : UInt<2>
- node T_43 = UInt<4>(0)
- node T_44 = UInt<4>(0)
- node T_45 = UInt<4>(0)
- node T_46 = UInt<4>(0)
- node T_47 = UInt<4>(0)
- node T_48 = UInt<4>(1)
- node T_49 = UInt<4>(2)
- node T_50 = UInt<4>(3)
- node T_51 = UInt<4>(0)
- node T_52 = UInt<4>(2)
- node T_53 = UInt<4>(4)
- node T_54 = UInt<4>(6)
- node T_55 = UInt<4>(0)
- node T_56 = UInt<4>(3)
- node T_57 = UInt<4>(6)
- node T_58 = UInt<4>(9)
wire tbl : UInt<4>[16]
- tbl[0] := T_43
- tbl[1] := T_44
- tbl[2] := T_45
- tbl[3] := T_46
- tbl[4] := T_47
- tbl[5] := T_48
- tbl[6] := T_49
- tbl[7] := T_50
- tbl[8] := T_51
- tbl[9] := T_52
- tbl[10] := T_53
- tbl[11] := T_54
- tbl[12] := T_55
- tbl[13] := T_56
- tbl[14] := T_57
- tbl[15] := T_58
- node T_60 = shl(x, 2)
- node T_61 = bit-or(T_60, y)
- accessor T_62 = tbl[T_61]
- z := T_62
+ tbl[0] := UInt<4>(0)
+ tbl[1] := UInt<4>(0)
+ tbl[2] := UInt<4>(0)
+ tbl[3] := UInt<4>(0)
+ tbl[4] := UInt<4>(0)
+ tbl[5] := UInt<4>(1)
+ tbl[6] := UInt<4>(2)
+ tbl[7] := UInt<4>(3)
+ tbl[8] := UInt<4>(0)
+ tbl[9] := UInt<4>(2)
+ tbl[10] := UInt<4>(4)
+ tbl[11] := UInt<4>(6)
+ tbl[12] := UInt<4>(0)
+ tbl[13] := UInt<4>(3)
+ tbl[14] := UInt<4>(6)
+ tbl[15] := UInt<4>(9)
+ node T_43 = shl(x, 2)
+ node ad = bit-or(Pad(T_43,?), Pad(y,?))
+ a := Pad(ad,?)
+ accessor T_44 = tbl[ad]
+ z := Pad(T_44,?)