diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 91 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 38 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 1225 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 71 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 10 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 82 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 72 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 14 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 11 |
14 files changed, 617 insertions, 1018 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index a4f5c14d..6c658257 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -68,9 +68,9 @@ class VerilogEmitter extends Emitter { var mname = "" def wref (n:String,t:Type) = WRef(n,t,ExpKind(),UNKNOWNGENDER) def remove_root (ex:Expression) : Expression = { - (ex.as[WSubField].get.exp) match { + (ex.asInstanceOf[WSubField].exp) match { case (e:WSubField) => remove_root(e) - case (e:WRef) => WRef(ex.as[WSubField].get.name,tpe(ex),InstanceKind(),UNKNOWNGENDER) + case (e:WRef) => WRef(ex.asInstanceOf[WSubField].name,ex.tpe,InstanceKind(),UNKNOWNGENDER) } } def not_empty (s:ArrayBuffer[_]) : Boolean = if (s.size == 0) false else true @@ -120,7 +120,7 @@ class VerilogEmitter extends Emitter { case (i:Long) => w.get.write(i.toString) case (t:VIndent) => w.get.write(" ") case (s:Seq[Any]) => { - s.foreach((x:Any) => emit2(x.as[Any].get, top + 1)) + s.foreach((x:Any) => emit2(x, top + 1)) if (top == 0) w.get.write("\n") } } @@ -142,9 +142,12 @@ class VerilogEmitter extends Emitter { } def op_stream (doprim:DoPrim) : Seq[Any] = { def cast_if (e:Expression) : Any = { - val signed = doprim.args.find(x => tpe(x).typeof[SIntType]) + val signed = doprim.args.find(x => x.tpe match { + case _: SIntType => true + case _ => false + }) if (signed == None) e - else tpe(e) match { + else e.tpe match { case (t:SIntType) => Seq("$signed(",e,")") case (t:UIntType) => Seq("$signed({1'b0,",e,"})") } @@ -156,7 +159,7 @@ class VerilogEmitter extends Emitter { } } def cast_as (e:Expression) : Any = { - (tpe(e)) match { + (e.tpe) match { case (t:UIntType) => e case (t:SIntType) => Seq("$signed(",e,")") } @@ -192,7 +195,7 @@ class VerilogEmitter extends Emitter { case Eq => Seq(cast_if(a0())," == ", cast_if(a1())) case Neq => Seq(cast_if(a0())," != ", cast_if(a1())) case Pad => { - val w = long_BANG(tpe(a0())) + val w = long_BANG(a0().tpe) val diff = (c0() - w) if (w == 0) Seq(a0()) else doprim.tpe match { @@ -219,13 +222,13 @@ class VerilogEmitter extends Emitter { case Shlw => Seq(cast(a0())," << ", c0()) case Shl => Seq(cast(a0())," << ",c0()) case Shr => { - if (c0 >= long_BANG(tpe(a0))) + if (c0 >= long_BANG(a0.tpe)) error("Verilog emitter does not support SHIFT_RIGHT >= arg width") - Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]") + Seq(a0(),"[", long_BANG(a0().tpe) - 1,":",c0(),"]") } case Neg => Seq("-{",cast(a0()),"}") case Cvt => { - tpe(a0()) match { + a0().tpe match { case (t:UIntType) => Seq("{1'b0,",cast(a0()),"}") case (t:SIntType) => Seq(cast(a0())) } @@ -258,18 +261,18 @@ class VerilogEmitter extends Emitter { case Cat => Seq("{",cast(a0()),",",cast(a1()),"}") case Bits => { // If selecting zeroth bit and single-bit wire, just emit the wire - if (c0() == 0 && c1() == 0 && long_BANG(tpe(a0())) == 1) Seq(a0()) + if (c0() == 0 && c1() == 0 && long_BANG(a0().tpe) == 1) Seq(a0()) else if (c0() == c1()) Seq(a0(),"[",c0(),"]") else Seq(a0(),"[",c0(),":",c1(),"]") } case Head => { - val w = long_BANG(tpe(a0())) + val w = long_BANG(a0().tpe) val high = w - 1 val low = w - c0() Seq(a0(),"[",high,":",low,"]") } case Tail => { - val w = long_BANG(tpe(a0())) + val w = long_BANG(a0().tpe) val low = w - c0() - 1 Seq(a0(),"[",low,":",0,"]") } @@ -286,7 +289,7 @@ class VerilogEmitter extends Emitter { case (s:Connect) => netlist(s.loc) = s.expr case (s:IsInvalid) => { val n = namespace.newTemp - val e = wref(n,tpe(s.expr)) + val e = wref(n,s.expr.tpe) netlist(s.expr) = e } case (s:Conditionally) => simlist += s @@ -319,12 +322,12 @@ class VerilogEmitter extends Emitter { assigns += Seq("`ifndef RANDOMIZE_GARBAGE_ASSIGN") assigns += Seq("assign ", e, " = ", syn, ";") assigns += Seq("`else") - assigns += Seq("assign ", e, " = ", garbageCond, " ? ", rand_string(tpe(syn)), " : ", syn, ";") + assigns += Seq("assign ", e, " = ", garbageCond, " ? ", rand_string(syn.tpe), " : ", syn, ";") assigns += Seq("`endif") } def invalidAssign(e: Expression) = { assigns += Seq("`ifdef RANDOMIZE_INVALID_ASSIGN") - assigns += Seq("assign ", e, " = ", rand_string(tpe(e)), ";") + assigns += Seq("assign ", e, " = ", rand_string(e.tpe), ";") assigns += Seq("`endif") } def update_and_reset(r: Expression, clk: Expression, reset: Expression, init: Expression) = { @@ -387,7 +390,7 @@ class VerilogEmitter extends Emitter { } def initialize(e: Expression) = { initials += Seq("`ifdef RANDOMIZE_REG_INIT") - initials += Seq(e, " = ", rand_string(tpe(e)), ";") + initials += Seq(e, " = ", rand_string(e.tpe), ";") initials += Seq("`endif") } def initialize_mem(s: DefMemory) = { @@ -407,8 +410,8 @@ class VerilogEmitter extends Emitter { }} instdeclares += Seq(");") for (e <- es) { - declare("wire",LowerTypes.loweredName(e),tpe(e)) - val ex = WRef(LowerTypes.loweredName(e),tpe(e),kind(e),gender(e)) + declare("wire",LowerTypes.loweredName(e),e.tpe) + val ex = WRef(LowerTypes.loweredName(e),e.tpe,kind(e),gender(e)) if (gender(e) == FEMALE) { assign(ex,netlist(e)) } @@ -444,8 +447,8 @@ class VerilogEmitter extends Emitter { def delay (e:Expression, n:Int, clk:Expression) : Expression = { ((0 until n) foldLeft e){(ex, i) => val name = namespace.newTemp - declare("reg",name,tpe(e)) - val exx = WRef(name,tpe(e),ExpKind(),UNKNOWNGENDER) + declare("reg",name,e.tpe) + val exx = WRef(name,e.tpe,ExpKind(),UNKNOWNGENDER) initialize(exx) update(exx,ex,clk,one) exx @@ -478,13 +481,13 @@ class VerilogEmitter extends Emitter { initialize(e) } case (s:IsInvalid) => { - val wref = netlist(s.expr).as[WRef].get - declare("wire",wref.name,tpe(s.expr)) + val wref = netlist(s.expr).asInstanceOf[WRef] + declare("wire",wref.name,s.expr.tpe) invalidAssign(wref) } case (s:DefNode) => { - declare("wire",s.name,tpe(s.value)) - assign(WRef(s.name,tpe(s.value),NodeKind(),MALE),s.value) + declare("wire",s.name,s.value.tpe) + assign(WRef(s.name,s.value.tpe,NodeKind(),MALE),s.value) } case (s:Stop) => { val errorString = StringLit(s"${s.ret}\n".getBytes) @@ -513,9 +516,9 @@ class VerilogEmitter extends Emitter { //Ports should share an always@posedge, so can't have intermediary wire val clk = netlist(mem_exp(r,"clk")) - declare("wire",LowerTypes.loweredName(data),tpe(data)) - declare("wire",LowerTypes.loweredName(addr),tpe(addr)) - declare("wire",LowerTypes.loweredName(en),tpe(en)) + declare("wire",LowerTypes.loweredName(data),data.tpe) + declare("wire",LowerTypes.loweredName(addr),addr.tpe) + declare("wire",LowerTypes.loweredName(en),en.tpe) //; Read port assign(addr,netlist(addr)) //;Connects value to m.r.addr @@ -524,8 +527,8 @@ class VerilogEmitter extends Emitter { val en_pipe = if (weq(en,one)) one else delay(en,s.readLatency-1,clk) val addrx = if (s.readLatency > 0) { val name = namespace.newTemp - val ref = WRef(name,tpe(addr),ExpKind(),UNKNOWNGENDER) - declare("reg",name,tpe(addr)) + val ref = WRef(name,addr.tpe,ExpKind(),UNKNOWNGENDER) + declare("reg",name,addr.tpe) initialize(ref) update(ref,addr_pipe,clk,en_pipe) ref @@ -548,10 +551,10 @@ class VerilogEmitter extends Emitter { //Ports should share an always@posedge, so can't have intermediary wire val clk = netlist(mem_exp(w,"clk")) - declare("wire",LowerTypes.loweredName(data),tpe(data)) - declare("wire",LowerTypes.loweredName(addr),tpe(addr)) - declare("wire",LowerTypes.loweredName(mask),tpe(mask)) - declare("wire",LowerTypes.loweredName(en),tpe(en)) + declare("wire",LowerTypes.loweredName(data),data.tpe) + declare("wire",LowerTypes.loweredName(addr),addr.tpe) + declare("wire",LowerTypes.loweredName(mask),mask.tpe) + declare("wire",LowerTypes.loweredName(en),en.tpe) //; Write port assign(data,netlist(data)) @@ -577,12 +580,12 @@ class VerilogEmitter extends Emitter { //Ports should share an always@posedge, so can't have intermediary wire val clk = netlist(mem_exp(rw,"clk")) - declare("wire",LowerTypes.loweredName(wmode),tpe(wmode)) - declare("wire",LowerTypes.loweredName(rdata),tpe(rdata)) - declare("wire",LowerTypes.loweredName(wdata),tpe(wdata)) - declare("wire",LowerTypes.loweredName(wmask),tpe(wmask)) - declare("wire",LowerTypes.loweredName(addr),tpe(addr)) - declare("wire",LowerTypes.loweredName(en),tpe(en)) + declare("wire",LowerTypes.loweredName(wmode),wmode.tpe) + declare("wire",LowerTypes.loweredName(rdata),rdata.tpe) + declare("wire",LowerTypes.loweredName(wdata),wdata.tpe) + declare("wire",LowerTypes.loweredName(wmask),wmask.tpe) + declare("wire",LowerTypes.loweredName(addr),addr.tpe) + declare("wire",LowerTypes.loweredName(en),en.tpe) //; Assigned to lowered wires of each assign(addr,netlist(addr)) @@ -602,8 +605,8 @@ class VerilogEmitter extends Emitter { val raddrxx = if (s.readLatency > 0) { val name = namespace.newTemp - val ref = WRef(name,tpe(raddrx),ExpKind(),UNKNOWNGENDER) - declare("reg",name,tpe(raddrx)) + val ref = WRef(name,raddrx.tpe,ExpKind(),UNKNOWNGENDER) + declare("reg",name,raddrx.tpe) initialize(ref) ref } else addr @@ -613,8 +616,8 @@ class VerilogEmitter extends Emitter { def declare_and_assign(exp: Expression) = { val name = namespace.newTemp - val ref = wref(name, tpe(exp)) - declare("wire", name, tpe(exp)) + val ref = wref(name, exp.tpe) + declare("wire", name, exp.tpe) assign(ref, exp) ref } diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 1bf8947a..8b705b29 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -146,20 +146,20 @@ object PrimOps extends LazyLogging { o match { case Add => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } case Sub => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1))) case (t1, t2) => UnknownType } DoPrim(o,a,c,t) @@ -177,9 +177,9 @@ object PrimOps extends LazyLogging { case Div => { val t = (t1(),t2()) match { case (t1:UIntType, t2:UIntType) => UIntType(w1()) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),IntWidth(1))) case (t1:SIntType, t2:UIntType) => SIntType(w1()) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),IntWidth(1))) case (t1, t2) => UnknownType } DoPrim(o,a,c,t) @@ -188,7 +188,7 @@ object PrimOps extends LazyLogging { val t = (t1(),t2()) match { case (t1:UIntType, t2:UIntType) => UIntType(MIN(w1(),w2())) case (t1:UIntType, t2:SIntType) => UIntType(MIN(w1(),w2())) - case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),Utils.ONE))) + case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),IntWidth(1)))) case (t1:SIntType, t2:SIntType) => SIntType(MIN(w1(),w2())) case (t1, t2) => UnknownType } @@ -266,7 +266,7 @@ object PrimOps extends LazyLogging { val t = (t1()) match { case (t1:UIntType) => UIntType(w1()) case (t1:SIntType) => UIntType(w1()) - case ClockType => UIntType(Utils.ONE) + case ClockType => UIntType(IntWidth(1)) case (t1) => UnknownType } DoPrim(o,a,c,t) @@ -275,7 +275,7 @@ object PrimOps extends LazyLogging { val t = (t1()) match { case (t1:UIntType) => SIntType(w1()) case (t1:SIntType) => SIntType(w1()) - case ClockType => SIntType(Utils.ONE) + case ClockType => SIntType(IntWidth(1)) case (t1) => UnknownType } DoPrim(o,a,c,t) @@ -299,8 +299,8 @@ object PrimOps extends LazyLogging { } case Shr => { val t = (t1()) match { - case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) - case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) + case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),IntWidth(1))) + case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),IntWidth(1))) case (t1) => UnknownType } DoPrim(o,a,c,t) @@ -323,7 +323,7 @@ object PrimOps extends LazyLogging { } case Cvt => { val t = (t1()) match { - case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:UIntType) => SIntType(PLUS(w1(),IntWidth(1))) case (t1:SIntType) => SIntType(w1()) case (t1) => UnknownType } @@ -331,8 +331,8 @@ object PrimOps extends LazyLogging { } case Neg => { val t = (t1()) match { - case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) - case (t1:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:UIntType) => SIntType(PLUS(w1(),IntWidth(1))) + case (t1:SIntType) => SIntType(PLUS(w1(),IntWidth(1))) case (t1) => UnknownType } DoPrim(o,a,c,t) @@ -396,7 +396,7 @@ object PrimOps extends LazyLogging { } case Bits => { val t = (t1()) match { - case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),Utils.ONE)) + case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),IntWidth(1))) case (t1) => UnknownType } DoPrim(o,a,c,t) diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 1db8ce78..9404e5e2 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -36,16 +36,14 @@ MODIFICATIONS. package firrtl -import scala.collection.mutable.StringBuilder +import firrtl.ir._ +import firrtl.PrimOps._ +import firrtl.Mappers._ +import firrtl.WrappedExpression._ +import firrtl.WrappedType._ +import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, HashMap, HashSet} import java.io.PrintWriter import com.typesafe.scalalogging.LazyLogging -import WrappedExpression._ -import firrtl.WrappedType._ -import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.ir._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.LinkedHashMap //import scala.reflect.runtime.universe._ class FIRRTLException(str: String) extends Exception(str) @@ -82,535 +80,361 @@ object Utils extends LazyLogging { if (bi < BigInt(0)) "\"h" + bi.toString(16).substring(1) + "\"" else "\"h" + bi.toString(16) + "\"" - implicit class WithAs[T](x: T) { - import scala.reflect._ - def as[O: ClassTag]: Option[O] = x match { - case o: O => Some(o) - case _ => None } - def typeof[O: ClassTag]: Boolean = x match { - case o: O => true - case _ => false } - } - implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) - def ceil_log2(x: BigInt): BigInt = (x-1).bitLength - def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt - def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b - def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a - def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 - val gen_names = Map[String,Int]() - val delin = "_" - val BoolType = UIntType(IntWidth(1)) - val one = UIntLiteral(BigInt(1),IntWidth(1)) - val zero = UIntLiteral(BigInt(0),IntWidth(1)) - def uint (i:Int) : UIntLiteral = { - val num_bits = req_num_bits(i) - val w = IntWidth(scala.math.max(1,num_bits - 1)) - UIntLiteral(BigInt(i),w) - } - def req_num_bits (i: Int) : Int = { - val ix = if (i < 0) ((-1 * i) - 1) else i - ceil_log2(ix + 1) + 1 - } - def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = { - if (e1 == e2) e1.e1 - else if ((e1 == we(zero)) | (e2 == we(zero))) zero - else if (e1 == we(one)) e2.e1 - else if (e2 == we(one)) e1.e1 - else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) - } - - def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { - if (e1 == e2) e1.e1 - else if ((e1 == we(one)) | (e2 == we(one))) one - else if (e1 == we(zero)) e2.e1 - else if (e2 == we(zero)) e1.e1 - else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) - } - def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(Eq,Seq(e1,e2),Seq(),tpe(e1)) } - def NOT (e1:WrappedExpression) : Expression = { - if (e1 == we(one)) zero - else if (e1 == we(zero)) one - else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) - } + implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) + def ceil_log2(x: BigInt): BigInt = (x-1).bitLength + def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt + def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b + def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a + def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 + val BoolType = UIntType(IntWidth(1)) + val one = UIntLiteral(BigInt(1),IntWidth(1)) + val zero = UIntLiteral(BigInt(0),IntWidth(1)) + def uint (i:Int) : UIntLiteral = { + val num_bits = req_num_bits(i) + val w = IntWidth(scala.math.max(1,num_bits - 1)) + UIntLiteral(BigInt(i),w) + } + def req_num_bits (i: Int) : Int = { + val ix = if (i < 0) ((-1 * i) - 1) else i + ceil_log2(ix + 1) + 1 + } + def EQV (e1:Expression,e2:Expression) : Expression = + DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe) + // TODO: these should be fixed + def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(zero)) | (e2 == we(zero))) zero + else if (e1 == we(one)) e2.e1 + else if (e2 == we(one)) e1.e1 + else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) + } + def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(one)) | (e2 == we(one))) one + else if (e1 == we(zero)) e2.e1 + else if (e2 == we(zero)) e1.e1 + else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) + } + def NOT (e1:WrappedExpression) : Expression = { + if (e1 == we(one)) zero + else if (e1 == we(zero)) one + else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) + } - - //def MUX (p:Expression,e1:Expression,e2:Expression) : Expression = { - // Mux(p,e1,e2,mux_type(tpe(e1),tpe(e2))) - //} + def create_mask(dt: Type): Type = dt match { + case t: VectorType => VectorType(create_mask(t.tpe),t.size) + case t: BundleType => BundleType(t.fields.map (f => f.copy(tpe=create_mask(f.tpe)))) + case t: UIntType => BoolType + case t: SIntType => BoolType + } - def create_mask (dt:Type) : Type = { - dt match { - case t:VectorType => VectorType(create_mask(t.tpe),t.size) - case t:BundleType => { - val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } - BundleType(fieldss) - } - case t:UIntType => BoolType - case t:SIntType => BoolType - } - } - def create_exps (n:String, t:Type) : Seq[Expression] = - create_exps(WRef(n,t,ExpKind(),UNKNOWNGENDER)) - def create_exps (e:Expression) : Seq[Expression] = e match { - case (e:Mux) => - val e1s = create_exps(e.tval) - val e2s = create_exps(e.fval) - (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2))) - case (e:ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1))) - case (e) => tpe(e) match { - case (_:GroundType) => Seq(e) - case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip)))) - case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(WSubIndex(e,i,t.tpe,gender(e)))) + def create_exps(n: String, t: Type): Seq[Expression] = + create_exps(WRef(n, t, ExpKind(), UNKNOWNGENDER)) + def create_exps(e: Expression): Seq[Expression] = e match { + case (e: Mux) => + val e1s = create_exps(e.tval) + val e2s = create_exps(e.fval) + e1s zip e2s map {case (e1, e2) => + Mux(e.cond, e1, e2, mux_type_and_widths(e1,e2)) } - } - def get_flip (t:Type, i:Int, f:Orientation) : Orientation = { - if (i >= get_size(t)) error("Shouldn't be here") - val x = t match { - case (t:UIntType) => f - case (t:SIntType) => f - case ClockType => f - case (t:BundleType) => { - var n = i - var ret:Option[Orientation] = None - t.fields.foreach { x => { - if (n < get_size(x.tpe)) { - ret match { - case None => ret = Some(get_flip(x.tpe,n,times(x.flip,f))) - case ret => {} - } - } else { n = n - get_size(x.tpe) } - }} - ret.asInstanceOf[Some[Orientation]].get - } - case (t:VectorType) => { - var n = i - var ret:Option[Orientation] = None - for (j <- 0 until t.size) { - if (n < get_size(t.tpe)) { - ret = Some(get_flip(t.tpe,n,f)) - } else { - n = n - get_size(t.tpe) - } + case (e: ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe)) + case (e) => e.tpe match { + case (_: GroundType) => Seq(e) + case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(WSubField(e, f.name, f.tpe,times(gender(e), f.flip)))) + case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(WSubIndex(e, i, t.tpe,gender(e)))) + } + } + def get_flip(t: Type, i: Int, f: Orientation): Orientation = { + if (i >= get_size(t)) error("Shouldn't be here") + t match { + case (_: GroundType) => f + case (t: BundleType) => + val (_, flip) = ((t.fields foldLeft (i, None: Option[Orientation])){ + case ((n, ret), x) if n < get_size(x.tpe) => ret match { + case None => (n, Some(get_flip(x.tpe,n,times(x.flip,f)))) + case Some(_) => (n, ret) } - ret.asInstanceOf[Some[Orientation]].get - } - } - x - } - - def get_point (e:Expression) : Int = { - e match { - case (e:WRef) => 0 - case (e:WSubField) => { - var i = 0 - tpe(e.exp).asInstanceOf[BundleType].fields.find { f => { - val b = f.name == e.name - if (!b) { i = i + get_size(f.tpe)} - b - }} - i - } - case (e:WSubIndex) => e.value * get_size(e.tpe) - case (e:WSubAccess) => get_point(e.exp) - } + case ((n, ret), x) => (n - get_size(x.tpe), ret) + }) + flip.get + case (t: VectorType) => + val (_, flip) = (((0 until t.size) foldLeft (i, None: Option[Orientation])){ + case ((n, ret), x) if n < get_size(t.tpe) => ret match { + case None => (n, Some(get_flip(t.tpe,n,f))) + case Some(_) => (n, ret) + } + case ((n, ret), x) => (n - get_size(t.tpe), ret) + }) + flip.get + } } + + def get_point (e:Expression) : Int = e match { + case (e: WRef) => 0 + case (e: WSubField) => e.exp.tpe match {case b: BundleType => + (b.fields takeWhile (_.name != e.name) foldLeft 0)( + (point, f) => point + get_size(f.tpe)) + } + case (e: WSubIndex) => e.value * get_size(e.tpe) + case (e: WSubAccess) => get_point(e.exp) + } - /** Returns true if t, or any subtype, contains a flipped field - * @param t [[firrtl.ir.Type]] - * @return if t contains [[firrtl.ir.Flip]] - */ - def hasFlip(t: Type): Boolean = { - var has = false - def findFlip(t: Type): Type = t map (findFlip) match { - case t: BundleType => - for (f <- t.fields) { if (f.flip == Flip) has = true } - t - case t: Type => t - } - findFlip(t) - has - } + /** Returns true if t, or any subtype, contains a flipped field + * @param t [[firrtl.ir.Type]] + * @return if t contains [[firrtl.ir.Flip]] + */ + def hasFlip(t: Type): Boolean = t match { + case t: BundleType => + (t.fields exists (_.flip == Flip)) || + (t.fields exists (f => hasFlip(f.tpe))) + case t: VectorType => hasFlip(t.tpe) + case _ => false + } //============== TYPES ================ - def mux_type (e1:Expression,e2:Expression) : Type = mux_type(tpe(e1),tpe(e2)) - def mux_type (t1:Type,t2:Type) : Type = { - if (wt(t1) == wt(t2)) { - (t1,t2) match { - case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth) - case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth) - case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size) - case (t1:BundleType,t2:BundleType) => - BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => { - Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe)) - })) - } - } else UnknownType - } - def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2)) - def mux_type_and_widths (t1:Type,t2:Type) : Type = { - def wmax (w1:Width,w2:Width) : Width = { - (w1,w2) match { - case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width.max(w2.width)) - case (w1,w2) => MaxWidth(Seq(w1,w2)) - } - } - val wt1 = new WrappedType(t1) - val wt2 = new WrappedType(t2) - if (wt1 == wt2) { - (t1,t2) match { - case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width)) - case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width)) - case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size) - case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))}) - } - } else UnknownType - } - def module_type (m:DefModule) : Type = { - BundleType(m.ports.map(p => p.toField)) - } - def sub_type (v:Type) : Type = { - v match { - case v:VectorType => v.tpe - case v => UnknownType - } - } - def field_type (v:Type,s:String) : Type = { - v match { - case v:BundleType => { - val ft = v.fields.find(p => p.name == s) - ft match { - case ft:Some[Field] => ft.get.tpe - case ft => UnknownType - } - } - case v => UnknownType - } - } + def mux_type (e1:Expression, e2:Expression) : Type = mux_type(e1.tpe, e2.tpe) + def mux_type (t1:Type, t2:Type) : Type = (t1,t2) match { + case (t1:UIntType, t2:UIntType) => UIntType(UnknownWidth) + case (t1:SIntType, t2:SIntType) => SIntType(UnknownWidth) + case (t1:VectorType, t2:VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) + case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map { + case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) + }) + case _ => UnknownType + } + def mux_type_and_widths (e1:Expression,e2:Expression) : Type = + mux_type_and_widths(e1.tpe, e2.tpe) + def mux_type_and_widths (t1:Type, t2:Type) : Type = { + def wmax (w1:Width, w2:Width) : Width = (w1,w2) match { + case (w1:IntWidth, w2:IntWidth) => IntWidth(w1.width max w2.width) + case (w1, w2) => MaxWidth(Seq(w1, w2)) + } + (t1, t2) match { + case (t1:UIntType, t2:UIntType) => UIntType(wmax(t1.width, t2.width)) + case (t1:SIntType, t2:SIntType) => SIntType(wmax(t1.width, t2.width)) + case (t1:VectorType, t2:VectorType) => VectorType( + mux_type_and_widths(t1.tpe, t2.tpe), t1.size) + case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map { + case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe, f2.tpe)) + }) + case _ => UnknownType + } + } + + def module_type(m: DefModule): Type = BundleType(m.ports map { + case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) + }) + def sub_type(v: Type): Type = v match { + case v: VectorType => v.tpe + case v => UnknownType + } + def field_type(v:Type, s: String) : Type = v match { + case v: BundleType => v.fields find (_.name == s) match { + case Some(f) => f.tpe + case None => UnknownType + } + case v => UnknownType + } ////===================================== - def widthBANG (t:Type) : Width = { - t match { - case g: GroundType => g.width - case t => error("No width!") - } - } - def long_BANG (t:Type) : Long = { - (t) match { - case g: GroundType => - g.width match { - case IntWidth(x) => x.toLong - case _ => throw new FIRRTLException(s"Expecting IntWidth, got: ${g.width}") - } - case (t:BundleType) => { - var w = 0 - for (f <- t.fields) { w = w + long_BANG(f.tpe).toInt } - w - } - case (t:VectorType) => t.size * long_BANG(t.tpe) - } - } -// ================================= - def error(str:String) = throw new FIRRTLException(str) + def widthBANG (t:Type) : Width = t match { + case g: GroundType => g.width + case t => error("No width!") + } + def long_BANG(t: Type): Long = t match { + case (g: GroundType) => g.width match { + case IntWidth(x) => x.toLong + case _ => error(s"Expecting IntWidth, got: ${g.width}") + } + case (t: BundleType) => (t.fields foldLeft 0)((w, f) => + w + long_BANG(f.tpe).toInt) + case (t: VectorType) => t.size * long_BANG(t.tpe) + } - implicit class FirrtlNodeUtils(node: FirrtlNode) { - def getType(): Type = - node match { - case e: Expression => e.getType - case s: Statement => s.getType - //case f: Field => f.getType - case t: Type => t.getType - case p: Port => p.getType - case _ => UnknownType - } - } +// ================================= + def error(str: String) = throw new FIRRTLException(str) //// =============== EXPANSION FUNCTIONS ================ - def get_size (t:Type) : Int = { - t match { - case (t:BundleType) => { - var sum = 0 - for (f <- t.fields) { - sum = sum + get_size(f.tpe) - } - sum - } - case (t:VectorType) => t.size * get_size(t.tpe) - case (t) => 1 - } - } - def get_valid_points (t1:Type, t2:Type, flip1:Orientation, flip2:Orientation) : Seq[(Int,Int)] = { - //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) - (t1,t2) match { - case (t1:UIntType,t2:UIntType) => if (flip1 == flip2) Seq((0, 0)) else Seq() - case (t1:SIntType,t2:SIntType) => if (flip1 == flip2) Seq((0, 0)) else Seq() - case (t1:BundleType,t2:BundleType) => { - val points = ArrayBuffer[(Int,Int)]() - var ilen = 0 - var jlen = 0 - for (i <- 0 until t1.fields.size) { - for (j <- 0 until t2.fields.size) { - val f1 = t1.fields(i) - val f2 = t2.fields(j) - if (f1.name == f2.name) { - val ls = get_valid_points(f1.tpe,f2.tpe,times(flip1, f1.flip),times(flip2, f2.flip)) - for (x <- ls) { - points += ((x._1 + ilen, x._2 + jlen)) - } - } - jlen = jlen + get_size(t2.fields(j).tpe) - } - ilen = ilen + get_size(t1.fields(i).tpe) - jlen = 0 - } - points - } - case (t1:VectorType,t2:VectorType) => { - val points = ArrayBuffer[(Int,Int)]() - var ilen = 0 - var jlen = 0 - for (i <- 0 until scala.math.min(t1.size,t2.size)) { - val ls = get_valid_points(t1.tpe,t2.tpe,flip1,flip2) - for (x <- ls) { - val y = ((x._1 + ilen), (x._2 + jlen)) - points += y - } - ilen = ilen + get_size(t1.tpe) - jlen = jlen + get_size(t2.tpe) - } - points - } - case (ClockType,ClockType) => if (flip1 == flip2) Seq((0, 0)) else Seq() - } - } + def get_size(t: Type): Int = t match { + case (t: BundleType) => (t.fields foldLeft 0)( + (sum, f) => sum + get_size(f.tpe)) + case (t: VectorType) => t.size * get_size(t.tpe) + case (t) => 1 + } + + def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = { + //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) + (t1, t2) match { + case (t1: UIntType, t2: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case (t1: SIntType, t2: SIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case (t1: BundleType, t2: BundleType) => + def emptyMap = Map[String, (Type, Orientation, Int)]() + val t1_fields = ((t1.fields foldLeft (emptyMap, 0)){case ((map, ilen), f1) => + (map + (f1.name -> (f1.tpe, f1.flip, ilen)), ilen + get_size(f1.tpe))})._1 + ((t2.fields foldLeft (Seq[(Int, Int)](), 0)){case ((points, jlen), f2) => + t1_fields get f2.name match { + case None => (points, jlen + get_size(f2.tpe)) + case Some((f1_tpe, f1_flip, ilen))=> + val f1_times = times(flip1, f1_flip) + val f2_times = times(flip2, f2.flip) + val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times) + (points ++ (ls map {case (x, y) => (x + ilen, y + jlen)}), jlen + get_size(f2.tpe)) + } + })._1 + case (t1: VectorType, t2: VectorType) => + val size = math.min(t1.size, t2.size) + (((0 until size) foldLeft (Seq[(Int, Int)](), 0, 0)){case ((points, ilen, jlen), _) => + val ls = get_valid_points(t1.tpe, t2.tpe, flip1, flip2) + (points ++ (ls map {case (x, y) => ((x + ilen), (y + jlen))}), + ilen + get_size(t1.tpe), jlen + get_size(t2.tpe)) + })._1 + case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case _ => error("shouldn't be here") + } + } + // =========== GENDER/FLIP UTILS ============ - def swap (g:Gender) : Gender = { - g match { - case UNKNOWNGENDER => UNKNOWNGENDER - case MALE => FEMALE - case FEMALE => MALE - case BIGENDER => BIGENDER - } - } - def swap (d:Direction) : Direction = { - d match { - case Output => Input - case Input => Output - } - } - def swap (f:Orientation) : Orientation = { - f match { - case Default => Flip - case Flip => Default - } - } - def to_dir (g:Gender) : Direction = { - g match { - case MALE => Input - case FEMALE => Output - } - } - def to_gender (d:Direction) : Gender = { - d match { - case Input => MALE - case Output => FEMALE - } - } - def toGender(f: Orientation): Gender = f match { - case Default => FEMALE - case Flip => MALE + def swap(g: Gender) : Gender = g match { + case UNKNOWNGENDER => UNKNOWNGENDER + case MALE => FEMALE + case FEMALE => MALE + case BIGENDER => BIGENDER + } + def swap(d: Direction) : Direction = d match { + case Output => Input + case Input => Output + } + def swap(f: Orientation) : Orientation = f match { + case Default => Flip + case Flip => Default + } + def to_dir(g: Gender): Direction = g match { + case MALE => Input + case FEMALE => Output } - def toFlip(g: Gender): Orientation = g match { + def to_gender(d: Direction): Gender = d match { + case Input => MALE + case Output => FEMALE + } + def to_flip(d: Direction): Orientation = d match { + case Input => Flip + case Output => Default + } + def to_flip(g: Gender): Orientation = g match { case MALE => Flip case FEMALE => Default } - def field_flip (v:Type,s:String) : Orientation = { - v match { - case v:BundleType => { - val ft = v.fields.find {p => p.name == s} - ft match { - case ft:Some[Field] => ft.get.flip - case ft => Default - } - } - case v => Default - } - } - def get_field (v:Type,s:String) : Field = { - v match { - case v:BundleType => { - val ft = v.fields.find {p => p.name == s} - ft match { - case ft:Some[Field] => ft.get - case ft => error("Shouldn't be here"); Field("blah",Default,UnknownType) - } - } - case v => error("Shouldn't be here"); Field("blah",Default,UnknownType) - } - } - def times (flip:Orientation, d:Direction) : Direction = times(flip, d) - def times (d:Direction,flip:Orientation) : Direction = { - flip match { - case Default => d - case Flip => swap(d) - } - } - def times (g: Gender, d: Direction): Direction = times(d, g) - def times (d: Direction, g: Gender): Direction = g match { - case FEMALE => d - case MALE => swap(d) // MALE == INPUT == REVERSE - } + def field_flip(v:Type, s:String) : Orientation = v match { + case (v:BundleType) => v.fields find (_.name == s) match { + case Some(ft) => ft.flip + case None => Default + } + case v => Default + } + def get_field(v:Type, s:String) : Field = v match { + case (v:BundleType) => v.fields find (_.name == s) match { + case Some(ft) => ft + case None => error("Shouldn't be here") + } + case v => error("Shouldn't be here") + } - def times (g:Gender,flip:Orientation) : Gender = times(flip, g) - def times (flip:Orientation, g:Gender) : Gender = { - flip match { - case Default => g - case Flip => swap(g) - } - } - def times (f1:Orientation, f2:Orientation) : Orientation = { - f2 match { - case Default => f1 - case Flip => swap(f1) - } - } + def times(flip: Orientation, d: Direction): Direction = times(flip, d) + def times(d: Direction,flip: Orientation): Direction = flip match { + case Default => d + case Flip => swap(d) + } + def times(g: Gender, d: Direction): Direction = times(d, g) + def times(d: Direction, g: Gender): Direction = g match { + case FEMALE => d + case MALE => swap(d) // MALE == INPUT == REVERSE + } + def times(g: Gender,flip: Orientation): Gender = times(flip, g) + def times(flip: Orientation, g: Gender): Gender = flip match { + case Default => g + case Flip => swap(g) + } + def times(f1: Orientation, f2: Orientation): Orientation = f2 match { + case Default => f1 + case Flip => swap(f1) + } // =========== ACCESSORS ========= - def info (s:Statement) : Info = { - s match { - case s:DefWire => s.info - case s:DefRegister => s.info - case s:DefInstance => s.info - case s:WDefInstance => s.info - case s:DefMemory => s.info - case s:DefNode => s.info - case s:Conditionally => s.info - case s:PartialConnect => s.info - case s:Connect => s.info - case s:IsInvalid => s.info - case s:Stop => s.info - case s:Print => s.info - case s:Block => NoInfo - case EmptyStmt => NoInfo - } - } - def gender (e:Expression) : Gender = { - e match { - case e:WRef => e.gender - case e:WSubField => e.gender - case e:WSubIndex => e.gender - case e:WSubAccess => e.gender - case e:DoPrim => MALE - case e:UIntLiteral => MALE - case e:SIntLiteral => MALE - case e:Mux => MALE - case e:ValidIf => MALE - case e:WInvalid => MALE - case e => println(e); error("Shouldn't be here") - }} - def get_gender (s:Statement) : Gender = - s match { - case s:DefWire => BIGENDER - case s:DefRegister => BIGENDER - case s:WDefInstance => MALE - case s:DefNode => MALE - case s:DefInstance => MALE - case s:DefMemory => MALE - case s:Block => UNKNOWNGENDER - case s:Connect => UNKNOWNGENDER - case s:PartialConnect => UNKNOWNGENDER - case s:Stop => UNKNOWNGENDER - case s:Print => UNKNOWNGENDER - case EmptyStmt => UNKNOWNGENDER - case s:IsInvalid => UNKNOWNGENDER - } - def get_gender (p:Port) : Gender = - if (p.direction == Input) MALE else FEMALE - def kind (e:Expression) : Kind = - e match { - case e:WRef => e.kind - case e:WSubField => kind(e.exp) - case e:WSubIndex => kind(e.exp) - case e => ExpKind() - } - def tpe (e:Expression) : Type = - e match { - case e:Reference => e.tpe - case e:SubField => e.tpe - case e:SubIndex => e.tpe - case e:SubAccess => e.tpe - case e:WRef => e.tpe - case e:WSubField => e.tpe - case e:WSubIndex => e.tpe - case e:WSubAccess => e.tpe - case e:DoPrim => e.tpe - case e:Mux => e.tpe - case e:ValidIf => e.tpe - case e:UIntLiteral => UIntType(e.width) - case e:SIntLiteral => SIntType(e.width) - case e:WVoid => UnknownType - case e:WInvalid => UnknownType - } - def get_type (s:Statement) : Type = { - s match { - case s:DefWire => s.tpe - case s:DefRegister => s.tpe - case s:DefNode => tpe(s.value) - case s:DefMemory => { - val depth = s.depth - val addr = Field("addr",Default,UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) - val en = Field("en",Default,BoolType) - val clk = Field("clk",Default,ClockType) - val def_data = Field("data",Default,s.dataType) - val rev_data = Field("data",Flip,s.dataType) - val mask = Field("mask",Default,create_mask(s.dataType)) - val wmode = Field("wmode",Default,UIntType(IntWidth(1))) - val rdata = Field("rdata",Flip,s.dataType) - val wdata = Field("wdata",Default,s.dataType) - val wmask = Field("wmask",Default,create_mask(s.dataType)) - val read_type = BundleType(Seq(rev_data,addr,en,clk)) - val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) - val readwrite_type = BundleType(Seq(wmode,rdata,wdata,wmask,addr,en,clk)) - - val mem_fields = ArrayBuffer[Field]() - s.readers.foreach {x => mem_fields += Field(x,Flip,read_type)} - s.writers.foreach {x => mem_fields += Field(x,Flip,write_type)} - s.readwriters.foreach {x => mem_fields += Field(x,Flip,readwrite_type)} - BundleType(mem_fields) - } - case s:DefInstance => UnknownType - case s:WDefInstance => s.tpe - case _ => UnknownType - }} - def get_name (s:Statement) : String = { - s match { - case s:DefWire => s.name - case s:DefRegister => s.name - case s:DefNode => s.name - case s:DefMemory => s.name - case s:DefInstance => s.name - case s:WDefInstance => s.name - case _ => error("Shouldn't be here"); "blah" - }} - def get_info (s:Statement) : Info = { - s match { - case s:DefWire => s.info - case s:DefRegister => s.info - case s:DefInstance => s.info - case s:WDefInstance => s.info - case s:DefMemory => s.info - case s:DefNode => s.info - case s:Conditionally => s.info - case s:PartialConnect => s.info - case s:Connect => s.info - case s:IsInvalid => s.info - case s:Stop => s.info - case s:Print => s.info - case _ => NoInfo - }} + def kind(e: Expression): Kind = e match { + case e: WRef => e.kind + case e: WSubField => kind(e.exp) + case e: WSubIndex => kind(e.exp) + case e: WSubAccess => kind(e.exp) + case e => ExpKind() + } + def gender (e: Expression): Gender = e match { + case e: WRef => e.gender + case e: WSubField => e.gender + case e: WSubIndex => e.gender + case e: WSubAccess => e.gender + case e: DoPrim => MALE + case e: UIntLiteral => MALE + case e: SIntLiteral => MALE + case e: Mux => MALE + case e: ValidIf => MALE + case e: WInvalid => MALE + case e => println(e); error("Shouldn't be here") + } + def get_gender(s:Statement): Gender = s match { + case s: DefWire => BIGENDER + case s: DefRegister => BIGENDER + case s: WDefInstance => MALE + case s: DefNode => MALE + case s: DefInstance => MALE + case s: DefMemory => MALE + case s: Block => UNKNOWNGENDER + case s: Connect => UNKNOWNGENDER + case s: PartialConnect => UNKNOWNGENDER + case s: Stop => UNKNOWNGENDER + case s: Print => UNKNOWNGENDER + case s: IsInvalid => UNKNOWNGENDER + case EmptyStmt => UNKNOWNGENDER + } + def get_gender(p: Port): Gender = if (p.direction == Input) MALE else FEMALE + def get_type(s: Statement): Type = s match { + case s: DefWire => s.tpe + case s: DefRegister => s.tpe + case s: DefNode => s.value.tpe + case s: DefMemory => + val depth = s.depth + val addr = Field("addr", Default, UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) + val en = Field("en", Default, BoolType) + val clk = Field("clk", Default, ClockType) + val def_data = Field("data", Default, s.dataType) + val rev_data = Field("data", Flip, s.dataType) + val mask = Field("mask", Default, create_mask(s.dataType)) + val wmode = Field("wmode", Default, UIntType(IntWidth(1))) + val rdata = Field("rdata", Flip, s.dataType) + val wdata = Field("wdata", Default, s.dataType) + val wmask = Field("wmask", Default, create_mask(s.dataType)) + val read_type = BundleType(Seq(rev_data, addr, en, clk)) + val write_type = BundleType(Seq(def_data, mask, addr, en, clk)) + val readwrite_type = BundleType(Seq(wmode, rdata, wdata, wmask, addr, en, clk)) + BundleType( + (s.readers map (Field(_, Flip, read_type))) ++ + (s.writers map (Field(_, Flip, write_type))) ++ + (s.readwriters map (Field(_, Flip, readwrite_type))) + ) + case s: WDefInstance => s.tpe + case _ => UnknownType + } + def get_name(s: Statement): String = s match { + case s: HasName => s.name + case _ => error("Shouldn't be here") + } + def get_info(s: Statement): Info = s match { + case s: HasInfo => s.info + case _ => NoInfo + } /** Splits an Expression into root Ref and tail * @@ -680,11 +504,12 @@ object Utils extends LazyLogging { case None => getRootDecl(root.name)(m.body) match { case Some(decl) => decl - case None => throw new DeclarationNotFoundException(s"[module ${m.name}] Reference ${expr.serialize} not declared!") + case None => throw new DeclarationNotFoundException( + s"[module ${m.name}] Reference ${expr.serialize} not declared!") } } rootDecl - case e => throw new FIRRTLException(s"getDeclaration does not support Expressions of type ${e.getClass}") + case e => error(s"getDeclaration does not support Expressions of type ${e.getClass}") } } @@ -699,7 +524,6 @@ object Utils extends LazyLogging { def apply_s (s:Statement) : Statement = s map (apply_s) map (apply_e) map (apply_t) apply_s(s) } - val ONE = IntWidth(1) //def digits (s:String) : Boolean { // val digits = "0123456789" // var yes:Boolean = true @@ -750,322 +574,79 @@ object Utils extends LazyLogging { // to-stmt(body(m)) // map(to-port,ports(m)) // sym-hash - implicit class StmtUtils(stmt: Statement) { - def getType(): Type = - stmt match { - case s: DefWire => s.tpe - case s: DefRegister => s.tpe - case s: DefMemory => s.dataType - case _ => UnknownType - } + val v_keywords = Set( + "alias", "always", "always_comb", "always_ff", "always_latch", + "and", "assert", "assign", "assume", "attribute", "automatic", - def getInfo: Info = - stmt match { - case s: DefWire => s.info - case s: DefRegister => s.info - case s: DefInstance => s.info - case s: DefMemory => s.info - case s: DefNode => s.info - case s: Conditionally => s.info - case s: PartialConnect => s.info - case s: Connect => s.info - case s: IsInvalid => s.info - case s: Stop => s.info - case s: Print => s.info - case _ => NoInfo - } - } + "before", "begin", "bind", "bins", "binsof", "bit", "break", + "buf", "bufif0", "bufif1", "byte", - implicit class FlipUtils(f: Orientation) { - def flip(): Orientation = { - f match { - case Flip => Default - case Default => Flip - } - } - - def toDirection(): Direction = { - f match { - case Default => Output - case Flip => Input - } - } - } + "case", "casex", "casez", "cell", "chandle", "class", "clocking", + "cmos", "config", "const", "constraint", "context", "continue", + "cover", "covergroup", "coverpoint", "cross", - implicit class FieldUtils(field: Field) { - def flip(): Field = Field(field.name, field.flip.flip, field.tpe) + "deassign", "default", "defparam", "design", "disable", "dist", "do", - def getType(): Type = field.tpe - def toPort(info: Info = NoInfo): Port = - Port(info, field.name, field.flip.toDirection, field.tpe) - } + "edge", "else", "end", "endattribute", "endcase", "endclass", + "endclocking", "endconfig", "endfunction", "endgenerate", + "endgroup", "endinterface", "endmodule", "endpackage", + "endprimitive", "endprogram", "endproperty", "endspecify", + "endsequence", "endtable", "endtask", + "enum", "event", "expect", "export", "extends", "extern", - implicit class TypeUtils(t: Type) { - def isGround: Boolean = t match { - case (_: UIntType | _: SIntType | ClockType) => true - case (_: BundleType | _: VectorType) => false - } - def isAggregate: Boolean = !t.isGround + "final", "first_match", "for", "force", "foreach", "forever", + "fork", "forkjoin", "function", - def getType(): Type = - t match { - case v: VectorType => v.tpe - case tpe: Type => UnknownType - } + "generate", "genvar", - def wipeWidth(): Type = - t match { - case t: UIntType => UIntType(UnknownWidth) - case t: SIntType => SIntType(UnknownWidth) - case _ => t - } - } + "highz0", "highz1", - implicit class DirectionUtils(d: Direction) { - def toFlip(): Orientation = { - d match { - case Input => Flip - case Output => Default - } - } - } - - implicit class PortUtils(p: Port) { - def getType(): Type = p.tpe - def toField(): Field = Field(p.name, p.direction.toFlip, p.tpe) - } + "if", "iff", "ifnone", "ignore_bins", "illegal_bins", "import", + "incdir", "include", "initial", "initvar", "inout", "input", + "inside", "instance", "int", "integer", "interconnect", + "interface", "intersect", + + "join", "join_any", "join_none", "large", "liblist", "library", + "local", "localparam", "logic", "longint", + + "macromodule", "matches", "medium", "modport", "module", + + "nand", "negedge", "new", "nmos", "nor", "noshowcancelled", + "not", "notif0", "notif1", "null", + + "or", "output", + + "package", "packed", "parameter", "pmos", "posedge", + "primitive", "priority", "program", "property", "protected", + "pull0", "pull1", "pulldown", "pullup", + "pulsestyle_onevent", "pulsestyle_ondetect", "pure", + + "rand", "randc", "randcase", "randsequence", "rcmos", + "real", "realtime", "ref", "reg", "release", "repeat", + "return", "rnmos", "rpmos", "rtran", "rtranif0", "rtranif1", + + "scalared", "sequence", "shortint", "shortreal", "showcancelled", + "signed", "small", "solve", "specify", "specparam", "static", + "strength", "string", "strong0", "strong1", "struct", "super", + "supply0", "supply1", + + "table", "tagged", "task", "this", "throughout", "time", "timeprecision", + "timeunit", "tran", "tranif0", "tranif1", "tri", "tri0", "tri1", "triand", + "trior", "trireg", "type","typedef", + + "union", "unique", "unsigned", "use", + + "var", "vectored", "virtual", "void", + + "wait", "wait_order", "wand", "weak0", "weak1", "while", + "wildcard", "wire", "with", "within", "wor", + "xnor", "xor", - val v_keywords = Map[String,Boolean]() + - ("alias" -> true) + - ("always" -> true) + - ("always_comb" -> true) + - ("always_ff" -> true) + - ("always_latch" -> true) + - ("and" -> true) + - ("assert" -> true) + - ("assign" -> true) + - ("assume" -> true) + - ("attribute" -> true) + - ("automatic" -> true) + - ("before" -> true) + - ("begin" -> true) + - ("bind" -> true) + - ("bins" -> true) + - ("binsof" -> true) + - ("bit" -> true) + - ("break" -> true) + - ("buf" -> true) + - ("bufif0" -> true) + - ("bufif1" -> true) + - ("byte" -> true) + - ("case" -> true) + - ("casex" -> true) + - ("casez" -> true) + - ("cell" -> true) + - ("chandle" -> true) + - ("class" -> true) + - ("clocking" -> true) + - ("cmos" -> true) + - ("config" -> true) + - ("const" -> true) + - ("constraint" -> true) + - ("context" -> true) + - ("continue" -> true) + - ("cover" -> true) + - ("covergroup" -> true) + - ("coverpoint" -> true) + - ("cross" -> true) + - ("deassign" -> true) + - ("default" -> true) + - ("defparam" -> true) + - ("design" -> true) + - ("disable" -> true) + - ("dist" -> true) + - ("do" -> true) + - ("edge" -> true) + - ("else" -> true) + - ("end" -> true) + - ("endattribute" -> true) + - ("endcase" -> true) + - ("endclass" -> true) + - ("endclocking" -> true) + - ("endconfig" -> true) + - ("endfunction" -> true) + - ("endgenerate" -> true) + - ("endgroup" -> true) + - ("endinterface" -> true) + - ("endmodule" -> true) + - ("endpackage" -> true) + - ("endprimitive" -> true) + - ("endprogram" -> true) + - ("endproperty" -> true) + - ("endspecify" -> true) + - ("endsequence" -> true) + - ("endtable" -> true) + - ("endtask" -> true) + - ("enum" -> true) + - ("event" -> true) + - ("expect" -> true) + - ("export" -> true) + - ("extends" -> true) + - ("extern" -> true) + - ("final" -> true) + - ("first_match" -> true) + - ("for" -> true) + - ("force" -> true) + - ("foreach" -> true) + - ("forever" -> true) + - ("fork" -> true) + - ("forkjoin" -> true) + - ("function" -> true) + - ("generate" -> true) + - ("genvar" -> true) + - ("highz0" -> true) + - ("highz1" -> true) + - ("if" -> true) + - ("iff" -> true) + - ("ifnone" -> true) + - ("ignore_bins" -> true) + - ("illegal_bins" -> true) + - ("import" -> true) + - ("incdir" -> true) + - ("include" -> true) + - ("initial" -> true) + - ("initvar" -> true) + - ("inout" -> true) + - ("input" -> true) + - ("inside" -> true) + - ("instance" -> true) + - ("int" -> true) + - ("integer" -> true) + - ("interconnect" -> true) + - ("interface" -> true) + - ("intersect" -> true) + - ("join" -> true) + - ("join_any" -> true) + - ("join_none" -> true) + - ("large" -> true) + - ("liblist" -> true) + - ("library" -> true) + - ("local" -> true) + - ("localparam" -> true) + - ("logic" -> true) + - ("longint" -> true) + - ("macromodule" -> true) + - ("matches" -> true) + - ("medium" -> true) + - ("modport" -> true) + - ("module" -> true) + - ("nand" -> true) + - ("negedge" -> true) + - ("new" -> true) + - ("nmos" -> true) + - ("nor" -> true) + - ("noshowcancelled" -> true) + - ("not" -> true) + - ("notif0" -> true) + - ("notif1" -> true) + - ("null" -> true) + - ("or" -> true) + - ("output" -> true) + - ("package" -> true) + - ("packed" -> true) + - ("parameter" -> true) + - ("pmos" -> true) + - ("posedge" -> true) + - ("primitive" -> true) + - ("priority" -> true) + - ("program" -> true) + - ("property" -> true) + - ("protected" -> true) + - ("pull0" -> true) + - ("pull1" -> true) + - ("pulldown" -> true) + - ("pullup" -> true) + - ("pulsestyle_onevent" -> true) + - ("pulsestyle_ondetect" -> true) + - ("pure" -> true) + - ("rand" -> true) + - ("randc" -> true) + - ("randcase" -> true) + - ("randsequence" -> true) + - ("rcmos" -> true) + - ("real" -> true) + - ("realtime" -> true) + - ("ref" -> true) + - ("reg" -> true) + - ("release" -> true) + - ("repeat" -> true) + - ("return" -> true) + - ("rnmos" -> true) + - ("rpmos" -> true) + - ("rtran" -> true) + - ("rtranif0" -> true) + - ("rtranif1" -> true) + - ("scalared" -> true) + - ("sequence" -> true) + - ("shortint" -> true) + - ("shortreal" -> true) + - ("showcancelled" -> true) + - ("signed" -> true) + - ("small" -> true) + - ("solve" -> true) + - ("specify" -> true) + - ("specparam" -> true) + - ("static" -> true) + - ("strength" -> true) + - ("string" -> true) + - ("strong0" -> true) + - ("strong1" -> true) + - ("struct" -> true) + - ("super" -> true) + - ("supply0" -> true) + - ("supply1" -> true) + - ("table" -> true) + - ("tagged" -> true) + - ("task" -> true) + - ("this" -> true) + - ("throughout" -> true) + - ("time" -> true) + - ("timeprecision" -> true) + - ("timeunit" -> true) + - ("tran" -> true) + - ("tranif0" -> true) + - ("tranif1" -> true) + - ("tri" -> true) + - ("tri0" -> true) + - ("tri1" -> true) + - ("triand" -> true) + - ("trior" -> true) + - ("trireg" -> true) + - ("type" -> true) + - ("typedef" -> true) + - ("union" -> true) + - ("unique" -> true) + - ("unsigned" -> true) + - ("use" -> true) + - ("var" -> true) + - ("vectored" -> true) + - ("virtual" -> true) + - ("void" -> true) + - ("wait" -> true) + - ("wait_order" -> true) + - ("wand" -> true) + - ("weak0" -> true) + - ("weak1" -> true) + - ("while" -> true) + - ("wildcard" -> true) + - ("wire" -> true) + - ("with" -> true) + - ("within" -> true) + - ("wor" -> true) + - ("xnor" -> true) + - ("xor" -> true) + - ("SYNTHESIS" -> true) + - ("PRINTF_COND" -> true) + - ("VCS" -> true) + "SYNTHESIS", + "PRINTF_COND", + "VCS") } object MemoizedHash { diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 60a49bac..e0e7c57a 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -105,7 +105,7 @@ object CheckChirrtl extends Pass with LazyLogging { e } def checkChirrtlS(s: Statement): Statement = { - sinfo = s.getInfo + sinfo = get_info(s) def checkName(name: String): String = { if (names.contains(name)) errors.append(new NotUniqueException(name)) else names(name) = true @@ -138,7 +138,7 @@ object CheckChirrtl extends Pass with LazyLogging { for (p <- m.ports) { sinfo = p.info names(p.name) = true - val tpe = p.getType + val tpe = p.tpe tpe map (checkChirrtlT) tpe map (checkChirrtlW) } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 9ee20c0a..6e49ce93 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -241,7 +241,7 @@ object CheckHighForm extends Pass with LazyLogging { else names(name) = true name } - sinfo = s.getInfo + sinfo = get_info(s) s map (checkName) s map (checkHighFormT) @@ -276,7 +276,7 @@ object CheckHighForm extends Pass with LazyLogging { for (p <- m.ports) { // FIXME should we set sinfo here? names(p.name) = true - val tpe = p.getType + val tpe = p.tpe tpe map (checkHighFormT) tpe map (checkHighFormW) } @@ -336,27 +336,36 @@ object CheckTypes extends Pass with LazyLogging { def all_same_type (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls) { - if (wt(tpe(ls.head)) != wt(tpe(x))) error = true + if (wt(ls.head.tpe) != wt(x.tpe)) error = true } if (error) errors.append(new OpNotAllSameType(info,e.op.serialize)) } def all_ground (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls ) { - if (!(tpe(x).typeof[UIntType] || tpe(x).typeof[SIntType])) error = true + x.tpe match { + case _: UIntType | _: SIntType => + case _ => error = true + } } if (error) errors.append(new OpNotGround(info,e.op.serialize)) } def all_uint (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls ) { - if (!(tpe(x).typeof[UIntType])) error = true + x.tpe match { + case _: UIntType => + case _ => error = true + } } if (error) errors.append(new OpNotAllUInt(info,e.op.serialize)) } def is_uint (x:Expression) : Unit = { var error = false - if (!(tpe(x).typeof[UIntType])) error = true + x.tpe match { + case _: UIntType => + case _ => error = true + } if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize)) } @@ -417,7 +426,7 @@ object CheckTypes extends Pass with LazyLogging { (e map (check_types_e(info))) match { case (e:WRef) => e case (e:WSubField) => { - (tpe(e.exp)) match { + (e.exp.tpe) match { case (t:BundleType) => { val ft = t.fields.find(p => p.name == e.name) if (ft == None) errors.append(new SubfieldNotInBundle(info,e.name)) @@ -426,7 +435,7 @@ object CheckTypes extends Pass with LazyLogging { } } case (e:WSubIndex) => { - (tpe(e.exp)) match { + (e.exp.tpe) match { case (t:VectorType) => { if (e.value >= t.size) errors.append(new IndexTooLarge(info,e.value)) } @@ -434,24 +443,30 @@ object CheckTypes extends Pass with LazyLogging { } } case (e:WSubAccess) => { - (tpe(e.exp)) match { + (e.exp.tpe) match { case (t:VectorType) => false case (t) => errors.append(new IndexOnNonVector(info)) } - (tpe(e.index)) match { + (e.index.tpe) match { case (t:UIntType) => false case (t) => errors.append(new AccessIndexNotUInt(info)) } } case (e:DoPrim) => check_types_primop(e,errors,info) case (e:Mux) => { - if (wt(tpe(e.tval)) != wt(tpe(e.fval))) errors.append(new MuxSameType(info)) - if (!passive(tpe(e))) errors.append(new MuxPassiveTypes(info)) - if (!(tpe(e.cond).typeof[UIntType])) errors.append(new MuxCondUInt(info)) + if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info)) + if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info)) + e.cond.tpe match { + case _: UIntType => + case _ => errors.append(new MuxCondUInt(info)) + } } case (e:ValidIf) => { - if (!passive(tpe(e))) errors.append(new ValidIfPassiveTypes(info)) - if (!(tpe(e.cond).typeof[UIntType])) errors.append(new ValidIfCondUInt(info)) + if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info)) + e.cond.tpe match { + case _: UIntType => + case _ => errors.append(new ValidIfCondUInt(info)) + } } case (_:UIntLiteral | _:SIntLiteral) => false } @@ -484,22 +499,22 @@ object CheckTypes extends Pass with LazyLogging { def check_types_s (s:Statement) : Statement = { s map (check_types_e(get_info(s))) match { - case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.expr))) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) - case (s:DefRegister) => if (wt(s.tpe) != wt(tpe(s.init))) errors.append(new InvalidRegInit(s.info)) - case (s:PartialConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.expr),Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) + case (s:Connect) => if (wt(s.loc.tpe) != wt(s.expr.tpe)) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) + case (s:DefRegister) => if (wt(s.tpe) != wt(s.init.tpe)) errors.append(new InvalidRegInit(s.info)) + case (s:PartialConnect) => if (!bulk_equals(s.loc.tpe,s.expr.tpe,Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) case (s:Stop) => { - if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info)) - if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) + if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info)) + if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) } case (s:Print)=> { for (x <- s.args ) { - if (wt(tpe(x)) != wt(ut()) && wt(tpe(x)) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info)) + if (wt(x.tpe) != wt(ut()) && wt(x.tpe) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info)) } - if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info)) - if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) + if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info)) + if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) } - case (s:Conditionally) => if (wt(tpe(s.pred)) != wt(ut()) ) errors.append(new PredNotUInt(s.info)) - case (s:DefNode) => if (!passive(tpe(s.value)) ) errors.append(new NodePassiveType(s.info)) + case (s:Conditionally) => if (wt(s.pred.tpe) != wt(ut()) ) errors.append(new PredNotUInt(s.info)) + case (s:DefNode) => if (!passive(s.value.tpe) ) errors.append(new NodePassiveType(s.info)) case (s) => false } s map (check_types_s) @@ -571,7 +586,7 @@ object CheckGenders extends Pass { fQ } - val has_flipQ = flipQ(tpe(e)) + val has_flipQ = flipQ(e.tpe) //println(e) //println(gender) //println(desired) @@ -597,7 +612,7 @@ object CheckGenders extends Pass { (e) match { case (e:WRef) => genders(e.name) case (e:WSubField) => - val f = tpe(e.exp).as[BundleType].get.fields.find(f => f.name == e.name).get + val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get times(get_gender(e.exp,genders),f.flip) case (e:WSubIndex) => get_gender(e.exp,genders) case (e:WSubAccess) => get_gender(e.exp,genders) @@ -735,7 +750,7 @@ object CheckWidths extends Pass { } def check_width_s (s:Statement) : Statement = { s map (check_width_s) map (check_width_e(get_info(s))) - def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t) + def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t) s map (tm) } diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 57782a3c..2e8b53f3 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -129,7 +129,7 @@ object ConstProp extends Pass { private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { - def isUInt(e: Expression): Boolean = tpe(e) match { + def isUInt(e: Expression): Boolean = e.tpe match { case UIntType(_) => true case _ => false } @@ -163,7 +163,7 @@ object ConstProp extends Pass { def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) - case _ => tpe(e) match { + case _ => e.tpe match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), max = BigInt(2).pow(width.toInt - 1) - BigInt(1) @@ -226,7 +226,7 @@ object ConstProp extends Pass { case Pad => e.args(0) match { case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0))) case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0))) - case _ if long_BANG(tpe(e.args(0))) == e.consts(0) => e.args(0) + case _ if long_BANG(e.args(0).tpe) == e.consts(0) => e.args(0) case _ => e } case Bits => e.args(0) match { @@ -234,9 +234,9 @@ object ConstProp extends Pass { val hi = e.consts(0).toInt val lo = e.consts(1).toInt require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(e.tpe)) } - case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match { + case x if long_BANG(e.tpe) == long_BANG(x.tpe) => x.tpe match { case t: UIntType => x case _ => asUInt(x, e.tpe) } diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 921693c7..3d26298a 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -131,8 +131,8 @@ object ExpandWhens extends Pass { val falseValue = altNetlist.getOrElse(lvalue, defaultValue) (trueValue, falseValue) match { case (WInvalid(), WInvalid()) => WInvalid() - case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, tpe(fv)) - case (tv, WInvalid()) => ValidIf(s.pred, tv, tpe(tv)) + case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe) + case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe) case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) } case None => diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 7793c85c..a8fda1bf 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -5,7 +5,6 @@ package passes import scala.collection.mutable import firrtl.Mappers.{ExpMap,StmtMap} -import firrtl.Utils.WithAs import firrtl.ir._ import firrtl.passes.{PassException,PassExceptions} import Annotations.{Loose, Unstable, Annotation, TransID, Named, ModuleName, ComponentName, CircuitName, AnnotationMap} diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 585598a8..a4c584ed 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -105,15 +105,15 @@ object LowerTypes extends Pass { require(tail.isEmpty) // there can't be a tail for these val memType = memDataTypeMap(mem.name) - if (memType.isGround) { - Seq(e) - } else { - val exps = create_exps(mem.name, memType) - exps map { e => - val loMemName = loweredName(e) - val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - mergeRef(loMem, mergeRef(port, field)) - } + memType match { + case _: GroundType => Seq(e) + case _ => + val exps = create_exps(mem.name, memType) + exps map { e => + val loMemName = loweredName(e) + val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + mergeRef(loMem, mergeRef(port, field)) + } } // Fields that need not be replicated for each // eg. mem.reader.data[0].a @@ -138,7 +138,7 @@ object LowerTypes extends Pass { case k: InstanceKind => val (root, tail) = splitRef(e) val name = loweredName(tail) - WSubField(root, name, tpe(e), gender(e)) + WSubField(root, name, e.tpe, gender(e)) case k: MemKind => val exps = lowerTypesMemExp(e) if (exps.length > 1) @@ -146,7 +146,7 @@ object LowerTypes extends Pass { " to be expanded!") exps(0) case k => - WRef(loweredName(e), tpe(e), kind(e), gender(e)) + WRef(loweredName(e), e.tpe, kind(e), gender(e)) } case e: Mux => e map (lowerTypesExp) case e: ValidIf => e map (lowerTypesExp) @@ -158,26 +158,26 @@ object LowerTypes extends Pass { s map lowerTypesStmt match { case s: DefWire => sinfo = s.info - if (s.tpe.isGround) { - s - } else { - val exps = create_exps(s.name, s.tpe) - val stmts = exps map (e => DefWire(s.info, loweredName(e), tpe(e))) - Block(stmts) + s.tpe match { + case _: GroundType => s + case _ => + val exps = create_exps(s.name, s.tpe) + val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe)) + Block(stmts) } case s: DefRegister => sinfo = s.info - if (s.tpe.isGround) { - s map lowerTypesExp - } else { - val es = create_exps(s.name, s.tpe) - val inits = create_exps(s.init) map (lowerTypesExp) - val clock = lowerTypesExp(s.clock) - val reset = lowerTypesExp(s.reset) - val stmts = es zip inits map { case (e, i) => - DefRegister(s.info, loweredName(e), tpe(e), clock, reset, i) - } - Block(stmts) + s.tpe match { + case _: GroundType => s map lowerTypesExp + case _ => + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map (lowerTypesExp) + val clock = lowerTypesExp(s.clock) + val reset = lowerTypesExp(s.reset) + val stmts = es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) + } + Block(stmts) } // Could instead just save the type of each Module as it gets processed case s: WDefInstance => @@ -188,7 +188,7 @@ object LowerTypes extends Pass { val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) exps map ( e => // Flip because inst genders are reversed from Module type - Field(loweredName(e), toFlip(gender(e)).flip, tpe(e)) + Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) ) } WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) @@ -197,16 +197,16 @@ object LowerTypes extends Pass { case s: DefMemory => sinfo = s.info memDataTypeMap += (s.name -> s.dataType) - if (s.dataType.isGround) { - s - } else { - val exps = create_exps(s.name, s.dataType) - val stmts = exps map { e => - DefMemory(s.info, loweredName(e), tpe(e), s.depth, - s.writeLatency, s.readLatency, s.readers, s.writers, - s.readwriters) - } - Block(stmts) + s.dataType match { + case _: GroundType => s + case _ => + val exps = create_exps(s.name, s.dataType) + val stmts = exps map { e => + DefMemory(s.info, loweredName(e), e.tpe, s.depth, + s.writeLatency, s.readLatency, s.readers, s.writers, + s.readwriters) + } + Block(stmts) } // wire foo : { a , b } // node x = foo @@ -217,7 +217,7 @@ object LowerTypes extends Pass { // node y = x_a case s: DefNode => sinfo = s.info - val names = create_exps(s.name, tpe(s.value)) map (lowerTypesExp) + val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp) val exps = create_exps(s.value) map (lowerTypesExp) val stmts = names zip exps map { case (n, e) => DefNode(s.info, loweredName(n), e) @@ -249,7 +249,7 @@ object LowerTypes extends Pass { // Lower Ports val portsx = m.ports flatMap { p => val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) - exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), tpe(e)) ) + exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) ) } m match { case m: ExtModule => m.copy(ports = portsx) diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 0cabc293..f2117761 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -2,7 +2,7 @@ package firrtl package passes import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{tpe, long_BANG} +import firrtl.Utils.long_BANG import firrtl.PrimOps._ import firrtl.ir._ @@ -10,10 +10,10 @@ import firrtl.ir._ object PadWidths extends Pass { def name = "Pad Widths" private def width(t: Type): Int = long_BANG(t).toInt - private def width(e: Expression): Int = width(tpe(e)) + private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width private def fixup(i: Int)(e: Expression) = { - def tx = tpe(e) match { + def tx = e.tpe match { case t: UIntType => UIntType(IntWidth(i)) case t: SIntType => SIntType(IntWidth(i)) // default case should never be reached diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 7b4f9aa2..6b6dc811 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -103,7 +103,7 @@ object ResolveKinds extends Pass { def resolve (body:Statement) = { def resolve_expr (e:Expression):Expression = { e match { - case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) + case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender) case e => e map (resolve_expr) } } @@ -170,11 +170,11 @@ object InferTypes extends Pass { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { e map (infer_types_e) match { - case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) + case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe) case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) - case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender) - case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(tpe(e.exp)),e.gender) - case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender) + case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender) + case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender) + case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender) case e:DoPrim => set_primop_type(e) case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) case e:UIntLiteral => e @@ -246,7 +246,7 @@ object ResolveGenders extends Pass { case e:WRef => WRef(e.name,e.tpe,e.kind,g) case e:WSubField => { val expx = - field_flip(tpe(e.exp),e.name) match { + field_flip(e.exp.tpe,e.name) match { case Default => resolve_e(g)(e.exp) case Flip => resolve_e(swap(g))(e.exp) } @@ -474,7 +474,7 @@ object InferWidths extends Pass { case (t:SIntType) => t.width case ClockType => IntWidth(1) case (t) => error("No width!"); IntWidth(-1) } } - def width_BANG (e:Expression) : Width = width_BANG(tpe(e)) + def width_BANG (e:Expression) : Width = width_BANG(e.tpe) def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = { def evaluate(w: Width): Width = { @@ -549,40 +549,40 @@ object InferWidths extends Pass { def get_constraints_e (e:Expression) : Expression = { (e map (get_constraints_e)) match { case (e:Mux) => { - constrain(width_BANG(e.cond),ONE) - constrain(ONE,width_BANG(e.cond)) + constrain(width_BANG(e.cond),IntWidth(1)) + constrain(IntWidth(1),width_BANG(e.cond)) e } case (e) => e }} def get_constraints (s:Statement) : Statement = { (s map (get_constraints_e)) match { case (s:Connect) => { - val n = get_size(tpe(s.loc)) + val n = get_size(s.loc.tpe) val ce_loc = create_exps(s.loc) val ce_exp = create_exps(s.expr) for (i <- 0 until n) { val locx = ce_loc(i) val expx = ce_exp(i) - get_flip(tpe(s.loc),i,Default) match { + get_flip(s.loc.tpe,i,Default) match { case Default => constrain(width_BANG(locx),width_BANG(expx)) case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} s } case (s:PartialConnect) => { - val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) for (x <- ls) { val locx = create_exps(s.loc)(x._1) val expx = create_exps(s.expr)(x._2) - get_flip(tpe(s.loc),x._1,Default) match { + get_flip(s.loc.tpe,x._1,Default) match { case Default => constrain(width_BANG(locx),width_BANG(expx)) case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} s } case (s:DefRegister) => { - constrain(width_BANG(s.reset),ONE) - constrain(ONE,width_BANG(s.reset)) - get_constraints_t(s.tpe,tpe(s.init),Default) + constrain(width_BANG(s.reset),IntWidth(1)) + constrain(IntWidth(1),width_BANG(s.reset)) + get_constraints_t(s.tpe,s.init.tpe,Default) s } case (s:Conditionally) => { - v += WGeq(width_BANG(s.pred),ONE) - v += WGeq(ONE,width_BANG(s.pred)) + v += WGeq(width_BANG(s.pred),IntWidth(1)) + v += WGeq(IntWidth(1),width_BANG(s.pred)) s map (get_constraints) } case (s) => s map (get_constraints) }} @@ -661,7 +661,7 @@ object ExpandConnects extends Pass { e map (set_gender) match { case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name)) case (e:WSubField) => { - val f = get_field(tpe(e.exp),e.name) + val f = get_field(e.exp.tpe,e.name) val genderx = times(gender(e.exp),f.flip) WSubField(e.exp,e.name,e.tpe,genderx) } @@ -677,7 +677,7 @@ object ExpandConnects extends Pass { case (s:DefMemory) => { genders(s.name) = MALE; s } case (s:DefNode) => { genders(s.name) = MALE; s } case (s:IsInvalid) => { - val n = get_size(tpe(s.expr)) + val n = get_size(s.expr.tpe) val invalids = ArrayBuffer[Statement]() val exps = create_exps(s.expr) for (i <- 0 until n) { @@ -696,14 +696,14 @@ object ExpandConnects extends Pass { } else Block(invalids) } case (s:Connect) => { - val n = get_size(tpe(s.loc)) + val n = get_size(s.loc.tpe) val connects = ArrayBuffer[Statement]() val locs = create_exps(s.loc) val exps = create_exps(s.expr) for (i <- 0 until n) { val locx = locs(i) val expx = exps(i) - val sx = get_flip(tpe(s.loc),i,Default) match { + val sx = get_flip(s.loc.tpe,i,Default) match { case Default => Connect(s.info,locx,expx) case Flip => Connect(s.info,expx,locx) } @@ -712,14 +712,14 @@ object ExpandConnects extends Pass { Block(connects) } case (s:PartialConnect) => { - val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) val connects = ArrayBuffer[Statement]() val locs = create_exps(s.loc) val exps = create_exps(s.expr) ls.foreach { x => { val locx = locs(x._1) val expx = exps(x._2) - val sx = get_flip(tpe(s.loc),x._1,Default) match { + val sx = get_flip(s.loc.tpe,x._1,Default) match { case Default => Connect(s.info,locx,expx) case Flip => Connect(s.info,expx,locx) } @@ -755,7 +755,7 @@ object Legalize extends Pass { def legalizeShiftRight (e: DoPrim): Expression = e.op match { case Shr => { val amount = e.consts(0).toInt - val width = long_BANG(tpe(e.args(0))) + val width = long_BANG(e.args(0).tpe) lazy val msb = width - 1 if (amount >= width) { e.tpe match { @@ -771,9 +771,9 @@ object Legalize extends Pass { case _ => e } def legalizeConnect(c: Connect): Statement = { - val t = tpe(c.loc) + val t = c.loc.tpe val w = long_BANG(t) - if (w >= long_BANG(tpe(c.expr))) c + if (w >= long_BANG(c.expr.tpe)) c else { val newType = t match { case _: UIntType => UIntType(IntWidth(w)) @@ -811,8 +811,8 @@ object VerilogWrap extends Pass { if (e.op == Tail) { (a0()) match { case (e0:DoPrim) => { - if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),tpe(e)) - else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),tpe(e)) + if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),e.tpe) + else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),e.tpe) else e } case (e0) => e @@ -913,12 +913,12 @@ object CInferTypes extends Pass { def infer_types_e (e:Expression) : Expression = { e map infer_types_e match { case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType)) - case (e:SubField) => SubField(e.expr,e.name,field_type(tpe(e.expr),e.name)) - case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(tpe(e.expr))) - case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(tpe(e.expr))) + case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name)) + case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe)) + case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe)) case (e:DoPrim) => set_primop_type(e) case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) - case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value)) + case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe) case (_:UIntLiteral | _:SIntLiteral) => e } } @@ -1067,8 +1067,8 @@ object RemoveCHIRRTL extends Pass { val e2s = create_exps(e.fval) (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) case (e:ValidIf) => - create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1))) - case (e) => (tpe(e)) match { + create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe)) + case (e) => (e.tpe) match { case (_:GroundType) => Seq(e) case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => exps ++ create_exps(SubField(e,f.name,f.tpe))) @@ -1276,7 +1276,7 @@ object RemoveCHIRRTL extends Pass { case Some(en) => stmts += Connect(s.info,en,one) } if (has_write_mport) { - val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) val locs = create_exps(get_mask(s.loc)) for (x <- ls ) { val locx = locs(x._1) diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index a3ce49f7..880d6b1c 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -76,7 +76,7 @@ object RemoveAccesses extends Pass { def onStmt(s: Statement): Statement = { def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) + (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } /** Replaces a subaccess in a given male expression @@ -94,9 +94,9 @@ object RemoveAccesses extends Pass { stmts += wire rs.zipWithIndex foreach { case (x, i) if i < temps.size => - stmts += Connect(info(s),getTemp(i),x.base) + stmts += Connect(get_info(s),getTemp(i),x.base) case (x, i) => - stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) + stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) } temp } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 1c9674e1..3b6021ed 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -2,7 +2,7 @@ package firrtl package passes import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{tpe, kind, gender, info} +import firrtl.Utils.{kind, gender, get_info} import firrtl.ir._ import scala.collection.mutable @@ -20,18 +20,18 @@ object SplitExpressions extends Pass { def split(e: Expression): Expression = e match { case e: DoPrim => { val name = namespace.newTemp - v += DefNode(info(s), name, e) - WRef(name, tpe(e), kind(e), gender(e)) + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) } case e: Mux => { val name = namespace.newTemp - v += DefNode(info(s), name, e) - WRef(name, tpe(e), kind(e), gender(e)) + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) } case e: ValidIf => { val name = namespace.newTemp - v += DefNode(info(s), name, e) - WRef(name, tpe(e), kind(e), gender(e)) + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) } case e => e } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index b1a20fdd..d034719a 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -109,8 +109,9 @@ object Uniquify extends Pass { val newName = findValidPrefix(f.name, Seq(""), namespace) namespace += newName Field(newName, f.flip, f.tpe) - } map { f => - if (f.tpe.isAggregate) { + } map { f => f.tpe match { + case _: GroundType => f + case _ => val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet()) val elts = enumerateNames(tpe) // Need leading _ for findValidPrefix, it doesn't add _ for checks @@ -123,8 +124,6 @@ object Uniquify extends Pass { } namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e))) Field(prefix, f.flip, tpe) - } else { - f } } BundleType(newFields) @@ -349,7 +348,9 @@ object Uniquify extends Pass { def uniquifyPorts(m: DefModule): DefModule = { def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { - val portsType = BundleType(ports map (_.toField)) + val portsType = BundleType(ports map { + case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) + }) val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet()) val localMap = createNameMapping(portsType, uniquePortsType) portNameMap += (m.name -> localMap) |
