aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/main/stanza/compilers.stanza8
-rw-r--r--src/main/stanza/passes.stanza50
-rw-r--r--test/passes/lower-to-ground/bundle-vecs.fir34
-rw-r--r--test/passes/split-exp/split-in-when.fir18
4 files changed, 74 insertions, 36 deletions
diff --git a/src/main/stanza/compilers.stanza b/src/main/stanza/compilers.stanza
index 3ca4f8da..0d0191bf 100644
--- a/src/main/stanza/compilers.stanza
+++ b/src/main/stanza/compilers.stanza
@@ -78,8 +78,6 @@ public defmethod passes (c:StandardVerilog) -> List<Pass> :
;===============
ConstProp()
;===============
- SplitExp()
- ;===============
ResolveKinds()
InferTypes()
CheckTypes()
@@ -98,6 +96,8 @@ public defmethod passes (c:StandardVerilog) -> List<Pass> :
InferWidths()
CheckWidths()
;===============
+ VerilogWrap()
+ SplitExp()
VerilogRename()
Verilog(with-output(c))
;===============
@@ -152,8 +152,6 @@ public defmethod passes (c:StandardLoFIRRTL) -> List<Pass> :
;===============
ConstProp()
;===============
- SplitExp()
- ;===============
ResolveKinds()
InferTypes()
CheckTypes()
@@ -172,6 +170,8 @@ public defmethod passes (c:StandardLoFIRRTL) -> List<Pass> :
InferWidths()
CheckWidths()
;===============
+ SplitExp()
+ ;===============
FIRRTL(with-output(c))
]
diff --git a/src/main/stanza/passes.stanza b/src/main/stanza/passes.stanza
index 2ad3d596..f6fd1533 100644
--- a/src/main/stanza/passes.stanza
+++ b/src/main/stanza/passes.stanza
@@ -1767,7 +1767,53 @@ defn resolve (c:Circuit) -> Circuit :
; val top = (for m in modules(c) find : name(m) == main(c)) as InModule
; Circuit(info(c),list(InModule(info(top),name(top),ports(top),inline-inst(body(top)))),main(c))
+;;================= Verilog Wrap ========================
+
+; --------- Utils --------------
+
+;---------- Pass ---------------
+;; Intended to only work on low firrtl
+public defstruct VerilogWrap <: Pass
+public defmethod pass (b:VerilogWrap) -> (Circuit -> Circuit) : v-wrap
+public defmethod name (b:VerilogWrap) -> String : "Verilog Wrap"
+public defmethod short-name (b:VerilogWrap) -> String : "v-wrap"
+
+public definterface WPrimOp <: PrimOp
+val ADDW-OP = new WPrimOp
+val SUBW-OP = new WPrimOp
+
+defmethod print (o:OutputStream,op:WPrimOp) :
+ print{o, _} $ switch {op == _} :
+ ADDW-OP : "addw"
+ SUBW-OP : "subw"
+
+defn v-wrap-e (e:Expression) -> Expression :
+ match(map(v-wrap-e,e)) :
+ (e:DoPrim) :
+ if op(e) == TAIL-OP :
+ match(args(e)[0]) :
+ (e0:DoPrim) :
+ if op(e0) == ADD-OP :
+ DoPrim(ADDW-OP,args(e0),list(),type(e))
+ else if op(e0) == SUB-OP :
+ DoPrim(SUBW-OP,args(e0),list(),type(e))
+ else : e
+ (e0) : e
+ else : e
+ (e) : e
+defn v-wrap-s (s:Stmt) -> Stmt :
+ map{v-wrap-e,_} $ map(v-wrap-s,s)
+defn v-wrap (c:Circuit) -> Circuit :
+ val modules* = for m in modules(c) map :
+ match(m) :
+ (m:InModule) :
+ mname = name(m)
+ InModule(info(m),name(m),ports(m),v-wrap-s(body(m)))
+ (m:ExModule) : m
+ Circuit(info(c),modules*,main(c))
+
;;================= Split Expressions ========================
+
;; Intended to only work on low firrtl
public defstruct SplitExp <: Pass
public defmethod pass (b:SplitExp) -> (Circuit -> Circuit) : split-exp
@@ -1789,7 +1835,7 @@ defn split-exp (m:InModule) -> InModule :
add(v,DefNode(info(s),n,e))
WRef(n,type(e),kind(e),gender(e))
defn split-exp-e (e:Expression,i:Int) -> Expression :
- match(map(split-exp-e{_,i + 1},e)) :
+ match(map(split-exp-e{_,i + 1},e)) :
(e:DoPrim) :
if i > 0 : split(e)
else : e
@@ -2485,7 +2531,9 @@ defn op-stream (doprim:DoPrim) -> Streamable :
switch {_ == op(doprim)} :
ADD-OP : [cast-if(a0()) " + " cast-if(a1())]
+ ADDW-OP : [cast-if(a0()) " + " cast-if(a1())]
SUB-OP : [cast-if(a0()) " - " cast-if(a1())]
+ SUBW-OP : [cast-if(a0()) " - " cast-if(a1())]
MUL-OP : [cast-if(a0()) " * " cast-if(a1()) ]
DIV-OP : [cast-if(a0()) " / " cast-if(a1()) ]
REM-OP : [cast-if(a0()) " % " cast-if(a1()) ]
diff --git a/test/passes/lower-to-ground/bundle-vecs.fir b/test/passes/lower-to-ground/bundle-vecs.fir
index c42766ad..cf581ab7 100644
--- a/test/passes/lower-to-ground/bundle-vecs.fir
+++ b/test/passes/lower-to-ground/bundle-vecs.fir
@@ -19,26 +19,20 @@ circuit top :
j <= a[i]
a[i] <= j
-;CHECK: wire GEN_0 : UInt<32>
-;CHECK: wire GEN_1 : UInt<32>
-;CHECK: wire GEN_2 : UInt<32>
-;CHECK: wire GEN_3 : UInt<32>
-;CHECK: j_x <= GEN_0
-;CHECK: j_y <= GEN_3
-;CHECK: node GEN_4 = eq(UInt("h0"), i)
-;CHECK: a_0_x <= mux(GEN_4, GEN_2, UInt("h0"))
-;CHECK: node GEN_5 = eq(UInt("h0"), i)
-;CHECK: a_0_y <= mux(GEN_5, GEN_1, UInt("h0"))
-;CHECK: node GEN_6 = eq(UInt("h1"), i)
-;CHECK: a_1_x <= mux(GEN_6, GEN_2, UInt("h0"))
-;CHECK: node GEN_7 = eq(UInt("h1"), i)
-;CHECK: a_1_y <= mux(GEN_7, GEN_1, UInt("h0"))
-;CHECK: node GEN_8 = eq(UInt("h1"), i)
-;CHECK: GEN_0 <= mux(GEN_8, a_1_x, a_0_x)
-;CHECK: GEN_1 <= j_y
-;CHECK: GEN_2 <= j_x
-;CHECK: node GEN_9 = eq(UInt("h1"), i)
-;CHECK: GEN_3 <= mux(GEN_9, a_1_y, a_0_y)
+; CHECK: wire GEN_0 : UInt<32>
+; CHECK: wire GEN_1 : UInt<32>
+; CHECK: wire GEN_2 : UInt<32>
+; CHECK: wire GEN_3 : UInt<32>
+; CHECK: j_x <= GEN_0
+; CHECK: j_y <= GEN_3
+; CHECK: a_0_x <= mux(eq(UInt("h0"), i), GEN_2, UInt("h0"))
+; CHECK: a_0_y <= mux(eq(UInt("h0"), i), GEN_1, UInt("h0"))
+; CHECK: a_1_x <= mux(eq(UInt("h1"), i), GEN_2, UInt("h0"))
+; CHECK: a_1_y <= mux(eq(UInt("h1"), i), GEN_1, UInt("h0"))
+; CHECK: GEN_0 <= mux(eq(UInt("h1"), i), a_1_x, a_0_x)
+; CHECK: GEN_1 <= j_y
+; CHECK: GEN_2 <= j_x
+; CHECK: GEN_3 <= mux(eq(UInt("h1"), i), a_1_y, a_0_y)
; CHECK: Finished Lower Types
diff --git a/test/passes/split-exp/split-in-when.fir b/test/passes/split-exp/split-in-when.fir
index 06d1463d..207ad757 100644
--- a/test/passes/split-exp/split-in-when.fir
+++ b/test/passes/split-exp/split-in-when.fir
@@ -14,16 +14,12 @@ circuit Top :
when bits(tail(sub(a,c),1),3,3) : out <= mux(eq(bits(UInt(32),4,0),UInt(13)),tail(add(a,tail(add(b,c),1)),1),tail(sub(c,b),1))
-;CHECK: node GEN_0 = sub(a, c)
-;CHECK: node GEN_1 = tail(GEN_0, 1)
-;CHECK: node GEN_2 = bits(GEN_1, 3, 3)
-;CHECK: node GEN_3 = eq(UInt("h0"), UInt("hd"))
-;CHECK: node GEN_4 = add(b, c)
-;CHECK: node GEN_5 = tail(GEN_4, 1)
-;CHECK: node GEN_6 = add(a, GEN_5)
-;CHECK: node GEN_7 = tail(GEN_6, 1)
-;CHECK: node GEN_8 = sub(c, b)
-;CHECK: node GEN_9 = tail(GEN_8, 1)
-;CHECK: out <= mux(GEN_2, mux(GEN_3, GEN_7, GEN_9), out)
+;CHECK: node GEN_0 = subw(a, c)
+;CHECK: node GEN_1 = bits(GEN_0, 3, 3)
+;CHECK: node GEN_2 = eq(UInt("h0"), UInt("hd"))
+;CHECK: node GEN_3 = addw(b, c)
+;CHECK: node GEN_4 = addw(a, GEN_3)
+;CHECK: node GEN_5 = subw(c, b)
+;CHECK: out <= mux(GEN_1, mux(GEN_2, GEN_4, GEN_5), out)
;CHECK: Finished Split Expressions