aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/stanza/passes.stanza63
-rw-r--r--test/chisel3/Mul.fir61
2 files changed, 64 insertions, 60 deletions
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index e67ec21f..429a1406 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1349,6 +1349,11 @@ defmethod equal? (w1:Width,w2:Width) -> True|False :
if not contains?(args(w2),w) : ret(false)
ret(true)
(w1:IntWidth,w2:IntWidth) : width(w1) == width(w2)
+ (w1:PlusWidth,w2:PlusWidth) :
+ (arg1(w1) == arg1(w2) and arg2(w1) == arg2(w2)) or (arg1(w1) == arg2(w2) and arg2(w1) == arg1(w2))
+ (w1:MinusWidth,w2:MinusWidth) :
+ (arg1(w1) == arg1(w2) and arg2(w1) == arg2(w2)) or (arg1(w1) == arg2(w2) and arg2(w1) == arg1(w2))
+ (w1:ExpWidth,w2:ExpWidth) : arg1(w1) == arg1(w2)
(w1:UnknownWidth,w2:UnknownWidth) : true
(w1,w2) : false
defn apply (a:Int|False,b:Int|False, f: (Int,Int) -> Int) -> Int|False :
@@ -1380,6 +1385,18 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
for x in args(w*) do : add(v,x)
(w*) : add(v,w*)
MaxWidth(unique(v))
+ (w:PlusWidth) :
+ match(arg1(w),arg2(w)) :
+ (w1:IntWidth,w2:IntWidth) : IntWidth(width(w1) + width(w2))
+ (w1,w2) : w
+ (w:MinusWidth) :
+ match(arg1(w),arg2(w)) :
+ (w1:IntWidth,w2:IntWidth) : IntWidth(width(w1) - width(w2))
+ (w1,w2) : w
+ (w:ExpWidth) :
+ match(arg1(w)) :
+ (w1:IntWidth) : IntWidth(pow(2,width(w1)))
+ (w1) : w
(w) : w
defn substitute (w:Width,h:HashTable<Symbol,Width>) -> Width :
;println-all-debug(["Substituting for [" w "]"])
@@ -1458,25 +1475,25 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
; 2) Remove Cycles
; 3) Move to solved if not self-recursive
val u = make-unique(l)
- ;println-debug("======== UNIQUE CONSTRAINTS ========")
- ;for x in u do : println-debug(x)
- ;println-debug("====================================")
+ println-debug("======== UNIQUE CONSTRAINTS ========")
+ for x in u do : println-debug(x)
+ println-debug("====================================")
val f = HashTable<Symbol,Width>(symbol-hash)
val o = Vector<Symbol>()
for x in u do :
- ;println-debug("==== SOLUTIONS TABLE ====")
- ;for x in f do : println-debug(x)
- ;println-debug("=========================")
+ println-debug("==== SOLUTIONS TABLE ====")
+ for x in f do : println-debug(x)
+ println-debug("=========================")
val [n e] = [key(x) value(x)]
val e-sub = substitute(e,f)
- ;println-debug(["Solving " n " => " e])
- ;println-debug(["After Substitute: " n " => " e-sub])
- ;println-debug("==== SOLUTIONS TABLE (Post Substitute) ====")
- ;for x in f do : println-debug(x)
- ;println-debug("=========================")
+ println-debug(["Solving " n " => " e])
+ println-debug(["After Substitute: " n " => " e-sub])
+ println-debug("==== SOLUTIONS TABLE (Post Substitute) ====")
+ for x in f do : println-debug(x)
+ println-debug("=========================")
val e* = remove-cycle{n,_} $ e-sub
;println-debug(["After Remove Cycle: " n " => " e*])
if not self-rec?(n,e*) :
@@ -1485,28 +1502,28 @@ defn solve-constraints (l:List<WGeq>) -> HashTable<Symbol,Int> :
add(o,n)
f[n] = e*
- ;println-debug("Forward Solved Constraints")
- ;for x in f do : println-debug(x)
+ println-debug("Forward Solved Constraints")
+ for x in f do : println-debug(x)
; Backwards Solve
val b = HashTable<Symbol,Width>(symbol-hash)
for i in (length(o) - 1) through 0 by -1 do :
val n = o[i]
- ;println-all-debug(["SOLVE BACK: [" n " => " f[n] "]"])
- ;println-debug("==== SOLUTIONS TABLE ====")
- ;for x in b do : println-debug(x)
- ;println-debug("=========================")
+ println-all-debug(["SOLVE BACK: [" n " => " f[n] "]"])
+ println-debug("==== SOLUTIONS TABLE ====")
+ for x in b do : println-debug(x)
+ println-debug("=========================")
val e* = simplify(b-sub(f[n],b))
- ;println-all-debug(["BACK RETURN: [" n " => " e* "]"])
+ println-all-debug(["BACK RETURN: [" n " => " e* "]"])
b[n] = e*
- ;println-debug("==== SOLUTIONS TABLE (Post backsolve) ====")
- ;for x in b do : println-debug(x)
- ;println-debug("=========================")
+ println-debug("==== SOLUTIONS TABLE (Post backsolve) ====")
+ for x in b do : println-debug(x)
+ println-debug("=========================")
; Evaluate
val e = evaluate(b)
- ;println-debug("Evaluated Constraints")
- ;for x in e do : println-debug(x)
+ println-debug("Evaluated Constraints")
+ for x in e do : println-debug(x)
e
public defn width! (t:Type) -> Width :
diff --git a/test/chisel3/Mul.fir b/test/chisel3/Mul.fir
index 1ce6f797..ec991197 100644
--- a/test/chisel3/Mul.fir
+++ b/test/chisel3/Mul.fir
@@ -1,45 +1,32 @@
; RUN: firrtl -i %s -o %s.flo -x X -p c | tee %s.out | FileCheck %s
; CHECK: Done!
+
circuit Mul :
module Mul :
- input y : UInt<2>
input x : UInt<2>
output z : UInt<4>
+ output a : UInt<4>
+ input y : UInt<2>
- node T_43 = UInt<4>(0)
- node T_44 = UInt<4>(0)
- node T_45 = UInt<4>(0)
- node T_46 = UInt<4>(0)
- node T_47 = UInt<4>(0)
- node T_48 = UInt<4>(1)
- node T_49 = UInt<4>(2)
- node T_50 = UInt<4>(3)
- node T_51 = UInt<4>(0)
- node T_52 = UInt<4>(2)
- node T_53 = UInt<4>(4)
- node T_54 = UInt<4>(6)
- node T_55 = UInt<4>(0)
- node T_56 = UInt<4>(3)
- node T_57 = UInt<4>(6)
- node T_58 = UInt<4>(9)
wire tbl : UInt<4>[16]
- tbl[0] := T_43
- tbl[1] := T_44
- tbl[2] := T_45
- tbl[3] := T_46
- tbl[4] := T_47
- tbl[5] := T_48
- tbl[6] := T_49
- tbl[7] := T_50
- tbl[8] := T_51
- tbl[9] := T_52
- tbl[10] := T_53
- tbl[11] := T_54
- tbl[12] := T_55
- tbl[13] := T_56
- tbl[14] := T_57
- tbl[15] := T_58
- node T_60 = shl(x, 2)
- node T_61 = bit-or(T_60, y)
- accessor T_62 = tbl[T_61]
- z := T_62
+ tbl[0] := UInt<4>(0)
+ tbl[1] := UInt<4>(0)
+ tbl[2] := UInt<4>(0)
+ tbl[3] := UInt<4>(0)
+ tbl[4] := UInt<4>(0)
+ tbl[5] := UInt<4>(1)
+ tbl[6] := UInt<4>(2)
+ tbl[7] := UInt<4>(3)
+ tbl[8] := UInt<4>(0)
+ tbl[9] := UInt<4>(2)
+ tbl[10] := UInt<4>(4)
+ tbl[11] := UInt<4>(6)
+ tbl[12] := UInt<4>(0)
+ tbl[13] := UInt<4>(3)
+ tbl[14] := UInt<4>(6)
+ tbl[15] := UInt<4>(9)
+ node T_43 = shl(x, 2)
+ node ad = bit-or(Pad(T_43,?), Pad(y,?))
+ a := Pad(ad,?)
+ accessor T_44 = tbl[ad]
+ z := Pad(T_44,?)