aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAndrew Waterman2016-04-06 21:55:53 -0700
committerjackkoenig2016-04-07 13:50:54 -0700
commitecc5c3d0934b11a9b727390853f84996c13dbb42 (patch)
tree1f4c7c8366cb88c6497070497ad92a51d8a43c72 /src
parentc99be4e7fa4359e9298a3ff57ff73b35db1684b7 (diff)
Make ConstProp pass more concise
I was going to augment it, but thought it best to clean it up first.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala105
1 files changed, 45 insertions, 60 deletions
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 0bedcbeb..72b96fb5 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -1164,68 +1164,53 @@ object Legalize extends Pass {
}
object ConstProp extends Pass {
- def name = "Constant Propogation"
- var mname = ""
+ def name = "Constant Propogation"
- def const_prop_e (e:Expression) : Expression = {
- e map (const_prop_e) match {
- case (e:DoPrim) => {
- e.op match {
- case SHIFT_RIGHT_OP => {
- val amount = e.consts(0).toInt
- e.args(0) match {
- case x: UIntValue => {
- val v = x.value >> amount
- val w = (x.width - IntWidth(amount)) max IntWidth(1)
- UIntValue(v, w)
- }
- case x: SIntValue => { // take sign bit if shift amount is larger than arg width
- val v = x.value >> amount
- val w = (x.width - IntWidth(amount)) max IntWidth(1)
- SIntValue(v, w)
- }
- case _ => e
- }
- }
- case BITS_SELECT_OP => {
- e.args(0) match {
- case (x:UIntValue) => {
- val hi = e.consts(0).toInt
- val lo = e.consts(1).toInt
- require(hi >= lo)
- val b = (x.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1)
- UIntValue(b,tpe(e).as[UIntType].get.width)
- }
- case (x) => {
- if (long_BANG(tpe(e)) == long_BANG(tpe(x))) {
- tpe(x) match {
- case (t:UIntType) => x
- case _ => DoPrim(AS_UINT_OP,Seq(x),Seq(),tpe(e))
- }
- }
- else e
- }
- }
- }
- case (_) => e
- }
- }
- case (e) => e
+ private def constPropPrim(e: DoPrim): Expression = e.op match {
+ case SHIFT_RIGHT_OP => {
+ val amount = e.consts(0).toInt
+ def shiftWidth(w: Width) = (w - IntWidth(amount)) max IntWidth(1)
+ e.args(0) match {
+ // TODO when amount >= x.width, return a zero-width wire
+ case UIntValue(v, w) => UIntValue(v >> amount, shiftWidth(w))
+ // take sign bit if shift amount is larger than arg width
+ case SIntValue(v, w) => SIntValue(v >> amount, shiftWidth(w))
+ case _ => e
}
- }
- def const_prop_s (s:Stmt) : Stmt = s map (const_prop_s) map (const_prop_e)
- def run (c:Circuit): Circuit = {
- val modulesx = c.modules.map{ m => {
- m match {
- case (m:ExModule) => m
- case (m:InModule) => {
- mname = m.name
- InModule(m.info,m.name,m.ports,const_prop_s(m.body))
- }
- }
- }}
- Circuit(c.info,modulesx,c.main)
- }
+ }
+ case BITS_SELECT_OP => e.args(0) match {
+ case UIntValue(v, w) => {
+ val hi = e.consts(0).toInt
+ val lo = e.consts(1).toInt
+ require(hi >= lo)
+ UIntValue((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), w)
+ }
+ case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match {
+ case t: UIntType => x
+ case _ => DoPrim(AS_UINT_OP, Seq(x), Seq(), tpe(e))
+ }
+ case _ => e
+ }
+ case _ => e
+ }
+
+ private def constPropExpression(e: Expression): Expression = {
+ e map constPropExpression match {
+ case p: DoPrim => constPropPrim(p)
+ case x => x
+ }
+ }
+
+ private def constPropStmt(s: Stmt): Stmt =
+ s map constPropStmt map constPropExpression
+
+ def run(c: Circuit): Circuit = {
+ val modulesx = c.modules.map {
+ case m: ExModule => m
+ case m: InModule => InModule(m.info, m.name, m.ports, constPropStmt(m.body))
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
}
object LoToVerilog extends Pass with StanzaPass {