aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-09-07 13:14:22 -0700
committerGitHub2016-09-07 13:14:22 -0700
commit8647a25fec8c5e18d766ff3e3602d3345cd8549c (patch)
tree429f7acf1f95b0c1e3e9b9b1f2d528c49761356b /src
parent0c6db9ef0669e3fb92fcc0bda2085f934d065f0b (diff)
parentb1b977407d12878fb5d8ea92950888002beb258b (diff)
Merge pull request #271 from ucb-bar/cleanup_utils
Clean up Utils
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Emitter.scala91
-rw-r--r--src/main/scala/firrtl/PrimOps.scala38
-rw-r--r--src/main/scala/firrtl/Utils.scala1225
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala4
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala71
-rw-r--r--src/main/scala/firrtl/passes/ConstProp.scala10
-rw-r--r--src/main/scala/firrtl/passes/ExpandWhens.scala4
-rw-r--r--src/main/scala/firrtl/passes/Inline.scala1
-rw-r--r--src/main/scala/firrtl/passes/LowerTypes.scala82
-rw-r--r--src/main/scala/firrtl/passes/PadWidths.scala6
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala72
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala6
-rw-r--r--src/main/scala/firrtl/passes/SplitExpressions.scala14
-rw-r--r--src/main/scala/firrtl/passes/Uniquify.scala11
14 files changed, 617 insertions, 1018 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala
index a4f5c14d..6c658257 100644
--- a/src/main/scala/firrtl/Emitter.scala
+++ b/src/main/scala/firrtl/Emitter.scala
@@ -68,9 +68,9 @@ class VerilogEmitter extends Emitter {
var mname = ""
def wref (n:String,t:Type) = WRef(n,t,ExpKind(),UNKNOWNGENDER)
def remove_root (ex:Expression) : Expression = {
- (ex.as[WSubField].get.exp) match {
+ (ex.asInstanceOf[WSubField].exp) match {
case (e:WSubField) => remove_root(e)
- case (e:WRef) => WRef(ex.as[WSubField].get.name,tpe(ex),InstanceKind(),UNKNOWNGENDER)
+ case (e:WRef) => WRef(ex.asInstanceOf[WSubField].name,ex.tpe,InstanceKind(),UNKNOWNGENDER)
}
}
def not_empty (s:ArrayBuffer[_]) : Boolean = if (s.size == 0) false else true
@@ -120,7 +120,7 @@ class VerilogEmitter extends Emitter {
case (i:Long) => w.get.write(i.toString)
case (t:VIndent) => w.get.write(" ")
case (s:Seq[Any]) => {
- s.foreach((x:Any) => emit2(x.as[Any].get, top + 1))
+ s.foreach((x:Any) => emit2(x, top + 1))
if (top == 0) w.get.write("\n")
}
}
@@ -142,9 +142,12 @@ class VerilogEmitter extends Emitter {
}
def op_stream (doprim:DoPrim) : Seq[Any] = {
def cast_if (e:Expression) : Any = {
- val signed = doprim.args.find(x => tpe(x).typeof[SIntType])
+ val signed = doprim.args.find(x => x.tpe match {
+ case _: SIntType => true
+ case _ => false
+ })
if (signed == None) e
- else tpe(e) match {
+ else e.tpe match {
case (t:SIntType) => Seq("$signed(",e,")")
case (t:UIntType) => Seq("$signed({1'b0,",e,"})")
}
@@ -156,7 +159,7 @@ class VerilogEmitter extends Emitter {
}
}
def cast_as (e:Expression) : Any = {
- (tpe(e)) match {
+ (e.tpe) match {
case (t:UIntType) => e
case (t:SIntType) => Seq("$signed(",e,")")
}
@@ -192,7 +195,7 @@ class VerilogEmitter extends Emitter {
case Eq => Seq(cast_if(a0())," == ", cast_if(a1()))
case Neq => Seq(cast_if(a0())," != ", cast_if(a1()))
case Pad => {
- val w = long_BANG(tpe(a0()))
+ val w = long_BANG(a0().tpe)
val diff = (c0() - w)
if (w == 0) Seq(a0())
else doprim.tpe match {
@@ -219,13 +222,13 @@ class VerilogEmitter extends Emitter {
case Shlw => Seq(cast(a0())," << ", c0())
case Shl => Seq(cast(a0())," << ",c0())
case Shr => {
- if (c0 >= long_BANG(tpe(a0)))
+ if (c0 >= long_BANG(a0.tpe))
error("Verilog emitter does not support SHIFT_RIGHT >= arg width")
- Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]")
+ Seq(a0(),"[", long_BANG(a0().tpe) - 1,":",c0(),"]")
}
case Neg => Seq("-{",cast(a0()),"}")
case Cvt => {
- tpe(a0()) match {
+ a0().tpe match {
case (t:UIntType) => Seq("{1'b0,",cast(a0()),"}")
case (t:SIntType) => Seq(cast(a0()))
}
@@ -258,18 +261,18 @@ class VerilogEmitter extends Emitter {
case Cat => Seq("{",cast(a0()),",",cast(a1()),"}")
case Bits => {
// If selecting zeroth bit and single-bit wire, just emit the wire
- if (c0() == 0 && c1() == 0 && long_BANG(tpe(a0())) == 1) Seq(a0())
+ if (c0() == 0 && c1() == 0 && long_BANG(a0().tpe) == 1) Seq(a0())
else if (c0() == c1()) Seq(a0(),"[",c0(),"]")
else Seq(a0(),"[",c0(),":",c1(),"]")
}
case Head => {
- val w = long_BANG(tpe(a0()))
+ val w = long_BANG(a0().tpe)
val high = w - 1
val low = w - c0()
Seq(a0(),"[",high,":",low,"]")
}
case Tail => {
- val w = long_BANG(tpe(a0()))
+ val w = long_BANG(a0().tpe)
val low = w - c0() - 1
Seq(a0(),"[",low,":",0,"]")
}
@@ -286,7 +289,7 @@ class VerilogEmitter extends Emitter {
case (s:Connect) => netlist(s.loc) = s.expr
case (s:IsInvalid) => {
val n = namespace.newTemp
- val e = wref(n,tpe(s.expr))
+ val e = wref(n,s.expr.tpe)
netlist(s.expr) = e
}
case (s:Conditionally) => simlist += s
@@ -319,12 +322,12 @@ class VerilogEmitter extends Emitter {
assigns += Seq("`ifndef RANDOMIZE_GARBAGE_ASSIGN")
assigns += Seq("assign ", e, " = ", syn, ";")
assigns += Seq("`else")
- assigns += Seq("assign ", e, " = ", garbageCond, " ? ", rand_string(tpe(syn)), " : ", syn, ";")
+ assigns += Seq("assign ", e, " = ", garbageCond, " ? ", rand_string(syn.tpe), " : ", syn, ";")
assigns += Seq("`endif")
}
def invalidAssign(e: Expression) = {
assigns += Seq("`ifdef RANDOMIZE_INVALID_ASSIGN")
- assigns += Seq("assign ", e, " = ", rand_string(tpe(e)), ";")
+ assigns += Seq("assign ", e, " = ", rand_string(e.tpe), ";")
assigns += Seq("`endif")
}
def update_and_reset(r: Expression, clk: Expression, reset: Expression, init: Expression) = {
@@ -387,7 +390,7 @@ class VerilogEmitter extends Emitter {
}
def initialize(e: Expression) = {
initials += Seq("`ifdef RANDOMIZE_REG_INIT")
- initials += Seq(e, " = ", rand_string(tpe(e)), ";")
+ initials += Seq(e, " = ", rand_string(e.tpe), ";")
initials += Seq("`endif")
}
def initialize_mem(s: DefMemory) = {
@@ -407,8 +410,8 @@ class VerilogEmitter extends Emitter {
}}
instdeclares += Seq(");")
for (e <- es) {
- declare("wire",LowerTypes.loweredName(e),tpe(e))
- val ex = WRef(LowerTypes.loweredName(e),tpe(e),kind(e),gender(e))
+ declare("wire",LowerTypes.loweredName(e),e.tpe)
+ val ex = WRef(LowerTypes.loweredName(e),e.tpe,kind(e),gender(e))
if (gender(e) == FEMALE) {
assign(ex,netlist(e))
}
@@ -444,8 +447,8 @@ class VerilogEmitter extends Emitter {
def delay (e:Expression, n:Int, clk:Expression) : Expression = {
((0 until n) foldLeft e){(ex, i) =>
val name = namespace.newTemp
- declare("reg",name,tpe(e))
- val exx = WRef(name,tpe(e),ExpKind(),UNKNOWNGENDER)
+ declare("reg",name,e.tpe)
+ val exx = WRef(name,e.tpe,ExpKind(),UNKNOWNGENDER)
initialize(exx)
update(exx,ex,clk,one)
exx
@@ -478,13 +481,13 @@ class VerilogEmitter extends Emitter {
initialize(e)
}
case (s:IsInvalid) => {
- val wref = netlist(s.expr).as[WRef].get
- declare("wire",wref.name,tpe(s.expr))
+ val wref = netlist(s.expr).asInstanceOf[WRef]
+ declare("wire",wref.name,s.expr.tpe)
invalidAssign(wref)
}
case (s:DefNode) => {
- declare("wire",s.name,tpe(s.value))
- assign(WRef(s.name,tpe(s.value),NodeKind(),MALE),s.value)
+ declare("wire",s.name,s.value.tpe)
+ assign(WRef(s.name,s.value.tpe,NodeKind(),MALE),s.value)
}
case (s:Stop) => {
val errorString = StringLit(s"${s.ret}\n".getBytes)
@@ -513,9 +516,9 @@ class VerilogEmitter extends Emitter {
//Ports should share an always@posedge, so can't have intermediary wire
val clk = netlist(mem_exp(r,"clk"))
- declare("wire",LowerTypes.loweredName(data),tpe(data))
- declare("wire",LowerTypes.loweredName(addr),tpe(addr))
- declare("wire",LowerTypes.loweredName(en),tpe(en))
+ declare("wire",LowerTypes.loweredName(data),data.tpe)
+ declare("wire",LowerTypes.loweredName(addr),addr.tpe)
+ declare("wire",LowerTypes.loweredName(en),en.tpe)
//; Read port
assign(addr,netlist(addr)) //;Connects value to m.r.addr
@@ -524,8 +527,8 @@ class VerilogEmitter extends Emitter {
val en_pipe = if (weq(en,one)) one else delay(en,s.readLatency-1,clk)
val addrx = if (s.readLatency > 0) {
val name = namespace.newTemp
- val ref = WRef(name,tpe(addr),ExpKind(),UNKNOWNGENDER)
- declare("reg",name,tpe(addr))
+ val ref = WRef(name,addr.tpe,ExpKind(),UNKNOWNGENDER)
+ declare("reg",name,addr.tpe)
initialize(ref)
update(ref,addr_pipe,clk,en_pipe)
ref
@@ -548,10 +551,10 @@ class VerilogEmitter extends Emitter {
//Ports should share an always@posedge, so can't have intermediary wire
val clk = netlist(mem_exp(w,"clk"))
- declare("wire",LowerTypes.loweredName(data),tpe(data))
- declare("wire",LowerTypes.loweredName(addr),tpe(addr))
- declare("wire",LowerTypes.loweredName(mask),tpe(mask))
- declare("wire",LowerTypes.loweredName(en),tpe(en))
+ declare("wire",LowerTypes.loweredName(data),data.tpe)
+ declare("wire",LowerTypes.loweredName(addr),addr.tpe)
+ declare("wire",LowerTypes.loweredName(mask),mask.tpe)
+ declare("wire",LowerTypes.loweredName(en),en.tpe)
//; Write port
assign(data,netlist(data))
@@ -577,12 +580,12 @@ class VerilogEmitter extends Emitter {
//Ports should share an always@posedge, so can't have intermediary wire
val clk = netlist(mem_exp(rw,"clk"))
- declare("wire",LowerTypes.loweredName(wmode),tpe(wmode))
- declare("wire",LowerTypes.loweredName(rdata),tpe(rdata))
- declare("wire",LowerTypes.loweredName(wdata),tpe(wdata))
- declare("wire",LowerTypes.loweredName(wmask),tpe(wmask))
- declare("wire",LowerTypes.loweredName(addr),tpe(addr))
- declare("wire",LowerTypes.loweredName(en),tpe(en))
+ declare("wire",LowerTypes.loweredName(wmode),wmode.tpe)
+ declare("wire",LowerTypes.loweredName(rdata),rdata.tpe)
+ declare("wire",LowerTypes.loweredName(wdata),wdata.tpe)
+ declare("wire",LowerTypes.loweredName(wmask),wmask.tpe)
+ declare("wire",LowerTypes.loweredName(addr),addr.tpe)
+ declare("wire",LowerTypes.loweredName(en),en.tpe)
//; Assigned to lowered wires of each
assign(addr,netlist(addr))
@@ -602,8 +605,8 @@ class VerilogEmitter extends Emitter {
val raddrxx = if (s.readLatency > 0) {
val name = namespace.newTemp
- val ref = WRef(name,tpe(raddrx),ExpKind(),UNKNOWNGENDER)
- declare("reg",name,tpe(raddrx))
+ val ref = WRef(name,raddrx.tpe,ExpKind(),UNKNOWNGENDER)
+ declare("reg",name,raddrx.tpe)
initialize(ref)
ref
} else addr
@@ -613,8 +616,8 @@ class VerilogEmitter extends Emitter {
def declare_and_assign(exp: Expression) = {
val name = namespace.newTemp
- val ref = wref(name, tpe(exp))
- declare("wire", name, tpe(exp))
+ val ref = wref(name, exp.tpe)
+ declare("wire", name, exp.tpe)
assign(ref, exp)
ref
}
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index 1bf8947a..8b705b29 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -146,20 +146,20 @@ object PrimOps extends LazyLogging {
o match {
case Add => {
val t = (t1(),t2()) match {
- case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
- case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
- case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
- case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
+ case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
+ case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
+ case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
+ case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
case (t1, t2) => UnknownType
}
DoPrim(o,a,c,t)
}
case Sub => {
val t = (t1(),t2()) match {
- case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
- case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
- case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
- case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE))
+ case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
+ case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
+ case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
+ case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),IntWidth(1)))
case (t1, t2) => UnknownType
}
DoPrim(o,a,c,t)
@@ -177,9 +177,9 @@ object PrimOps extends LazyLogging {
case Div => {
val t = (t1(),t2()) match {
case (t1:UIntType, t2:UIntType) => UIntType(w1())
- case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE))
+ case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),IntWidth(1)))
case (t1:SIntType, t2:UIntType) => SIntType(w1())
- case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE))
+ case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),IntWidth(1)))
case (t1, t2) => UnknownType
}
DoPrim(o,a,c,t)
@@ -188,7 +188,7 @@ object PrimOps extends LazyLogging {
val t = (t1(),t2()) match {
case (t1:UIntType, t2:UIntType) => UIntType(MIN(w1(),w2()))
case (t1:UIntType, t2:SIntType) => UIntType(MIN(w1(),w2()))
- case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),Utils.ONE)))
+ case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),IntWidth(1))))
case (t1:SIntType, t2:SIntType) => SIntType(MIN(w1(),w2()))
case (t1, t2) => UnknownType
}
@@ -266,7 +266,7 @@ object PrimOps extends LazyLogging {
val t = (t1()) match {
case (t1:UIntType) => UIntType(w1())
case (t1:SIntType) => UIntType(w1())
- case ClockType => UIntType(Utils.ONE)
+ case ClockType => UIntType(IntWidth(1))
case (t1) => UnknownType
}
DoPrim(o,a,c,t)
@@ -275,7 +275,7 @@ object PrimOps extends LazyLogging {
val t = (t1()) match {
case (t1:UIntType) => SIntType(w1())
case (t1:SIntType) => SIntType(w1())
- case ClockType => SIntType(Utils.ONE)
+ case ClockType => SIntType(IntWidth(1))
case (t1) => UnknownType
}
DoPrim(o,a,c,t)
@@ -299,8 +299,8 @@ object PrimOps extends LazyLogging {
}
case Shr => {
val t = (t1()) match {
- case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),Utils.ONE))
- case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),Utils.ONE))
+ case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),IntWidth(1)))
+ case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),IntWidth(1)))
case (t1) => UnknownType
}
DoPrim(o,a,c,t)
@@ -323,7 +323,7 @@ object PrimOps extends LazyLogging {
}
case Cvt => {
val t = (t1()) match {
- case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE))
+ case (t1:UIntType) => SIntType(PLUS(w1(),IntWidth(1)))
case (t1:SIntType) => SIntType(w1())
case (t1) => UnknownType
}
@@ -331,8 +331,8 @@ object PrimOps extends LazyLogging {
}
case Neg => {
val t = (t1()) match {
- case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE))
- case (t1:SIntType) => SIntType(PLUS(w1(),Utils.ONE))
+ case (t1:UIntType) => SIntType(PLUS(w1(),IntWidth(1)))
+ case (t1:SIntType) => SIntType(PLUS(w1(),IntWidth(1)))
case (t1) => UnknownType
}
DoPrim(o,a,c,t)
@@ -396,7 +396,7 @@ object PrimOps extends LazyLogging {
}
case Bits => {
val t = (t1()) match {
- case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),Utils.ONE))
+ case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),IntWidth(1)))
case (t1) => UnknownType
}
DoPrim(o,a,c,t)
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 1db8ce78..9404e5e2 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -36,16 +36,14 @@ MODIFICATIONS.
package firrtl
-import scala.collection.mutable.StringBuilder
+import firrtl.ir._
+import firrtl.PrimOps._
+import firrtl.Mappers._
+import firrtl.WrappedExpression._
+import firrtl.WrappedType._
+import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, HashMap, HashSet}
import java.io.PrintWriter
import com.typesafe.scalalogging.LazyLogging
-import WrappedExpression._
-import firrtl.WrappedType._
-import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.ir._
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.LinkedHashMap
//import scala.reflect.runtime.universe._
class FIRRTLException(str: String) extends Exception(str)
@@ -82,535 +80,361 @@ object Utils extends LazyLogging {
if (bi < BigInt(0)) "\"h" + bi.toString(16).substring(1) + "\""
else "\"h" + bi.toString(16) + "\""
- implicit class WithAs[T](x: T) {
- import scala.reflect._
- def as[O: ClassTag]: Option[O] = x match {
- case o: O => Some(o)
- case _ => None }
- def typeof[O: ClassTag]: Boolean = x match {
- case o: O => true
- case _ => false }
- }
- implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x)
- def ceil_log2(x: BigInt): BigInt = (x-1).bitLength
- def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt
- def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b
- def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
- def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
- val gen_names = Map[String,Int]()
- val delin = "_"
- val BoolType = UIntType(IntWidth(1))
- val one = UIntLiteral(BigInt(1),IntWidth(1))
- val zero = UIntLiteral(BigInt(0),IntWidth(1))
- def uint (i:Int) : UIntLiteral = {
- val num_bits = req_num_bits(i)
- val w = IntWidth(scala.math.max(1,num_bits - 1))
- UIntLiteral(BigInt(i),w)
- }
- def req_num_bits (i: Int) : Int = {
- val ix = if (i < 0) ((-1 * i) - 1) else i
- ceil_log2(ix + 1) + 1
- }
- def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = {
- if (e1 == e2) e1.e1
- else if ((e1 == we(zero)) | (e2 == we(zero))) zero
- else if (e1 == we(one)) e2.e1
- else if (e2 == we(one)) e1.e1
- else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1)))
- }
-
- def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = {
- if (e1 == e2) e1.e1
- else if ((e1 == we(one)) | (e2 == we(one))) one
- else if (e1 == we(zero)) e2.e1
- else if (e2 == we(zero)) e1.e1
- else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1)))
- }
- def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(Eq,Seq(e1,e2),Seq(),tpe(e1)) }
- def NOT (e1:WrappedExpression) : Expression = {
- if (e1 == we(one)) zero
- else if (e1 == we(zero)) one
- else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1)))
- }
+ implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x)
+ def ceil_log2(x: BigInt): BigInt = (x-1).bitLength
+ def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt
+ def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b
+ def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
+ def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
+ val BoolType = UIntType(IntWidth(1))
+ val one = UIntLiteral(BigInt(1),IntWidth(1))
+ val zero = UIntLiteral(BigInt(0),IntWidth(1))
+ def uint (i:Int) : UIntLiteral = {
+ val num_bits = req_num_bits(i)
+ val w = IntWidth(scala.math.max(1,num_bits - 1))
+ UIntLiteral(BigInt(i),w)
+ }
+ def req_num_bits (i: Int) : Int = {
+ val ix = if (i < 0) ((-1 * i) - 1) else i
+ ceil_log2(ix + 1) + 1
+ }
+ def EQV (e1:Expression,e2:Expression) : Expression =
+ DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe)
+ // TODO: these should be fixed
+ def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = {
+ if (e1 == e2) e1.e1
+ else if ((e1 == we(zero)) | (e2 == we(zero))) zero
+ else if (e1 == we(one)) e2.e1
+ else if (e2 == we(one)) e1.e1
+ else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1)))
+ }
+ def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = {
+ if (e1 == e2) e1.e1
+ else if ((e1 == we(one)) | (e2 == we(one))) one
+ else if (e1 == we(zero)) e2.e1
+ else if (e2 == we(zero)) e1.e1
+ else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1)))
+ }
+ def NOT (e1:WrappedExpression) : Expression = {
+ if (e1 == we(one)) zero
+ else if (e1 == we(zero)) one
+ else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1)))
+ }
-
- //def MUX (p:Expression,e1:Expression,e2:Expression) : Expression = {
- // Mux(p,e1,e2,mux_type(tpe(e1),tpe(e2)))
- //}
+ def create_mask(dt: Type): Type = dt match {
+ case t: VectorType => VectorType(create_mask(t.tpe),t.size)
+ case t: BundleType => BundleType(t.fields.map (f => f.copy(tpe=create_mask(f.tpe))))
+ case t: UIntType => BoolType
+ case t: SIntType => BoolType
+ }
- def create_mask (dt:Type) : Type = {
- dt match {
- case t:VectorType => VectorType(create_mask(t.tpe),t.size)
- case t:BundleType => {
- val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) }
- BundleType(fieldss)
- }
- case t:UIntType => BoolType
- case t:SIntType => BoolType
- }
- }
- def create_exps (n:String, t:Type) : Seq[Expression] =
- create_exps(WRef(n,t,ExpKind(),UNKNOWNGENDER))
- def create_exps (e:Expression) : Seq[Expression] = e match {
- case (e:Mux) =>
- val e1s = create_exps(e.tval)
- val e2s = create_exps(e.fval)
- (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2)))
- case (e:ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1)))
- case (e) => tpe(e) match {
- case (_:GroundType) => Seq(e)
- case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
- exps ++ create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip))))
- case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
- exps ++ create_exps(WSubIndex(e,i,t.tpe,gender(e))))
+ def create_exps(n: String, t: Type): Seq[Expression] =
+ create_exps(WRef(n, t, ExpKind(), UNKNOWNGENDER))
+ def create_exps(e: Expression): Seq[Expression] = e match {
+ case (e: Mux) =>
+ val e1s = create_exps(e.tval)
+ val e2s = create_exps(e.fval)
+ e1s zip e2s map {case (e1, e2) =>
+ Mux(e.cond, e1, e2, mux_type_and_widths(e1,e2))
}
- }
- def get_flip (t:Type, i:Int, f:Orientation) : Orientation = {
- if (i >= get_size(t)) error("Shouldn't be here")
- val x = t match {
- case (t:UIntType) => f
- case (t:SIntType) => f
- case ClockType => f
- case (t:BundleType) => {
- var n = i
- var ret:Option[Orientation] = None
- t.fields.foreach { x => {
- if (n < get_size(x.tpe)) {
- ret match {
- case None => ret = Some(get_flip(x.tpe,n,times(x.flip,f)))
- case ret => {}
- }
- } else { n = n - get_size(x.tpe) }
- }}
- ret.asInstanceOf[Some[Orientation]].get
- }
- case (t:VectorType) => {
- var n = i
- var ret:Option[Orientation] = None
- for (j <- 0 until t.size) {
- if (n < get_size(t.tpe)) {
- ret = Some(get_flip(t.tpe,n,f))
- } else {
- n = n - get_size(t.tpe)
- }
+ case (e: ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe))
+ case (e) => e.tpe match {
+ case (_: GroundType) => Seq(e)
+ case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_exps(WSubField(e, f.name, f.tpe,times(gender(e), f.flip))))
+ case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_exps(WSubIndex(e, i, t.tpe,gender(e))))
+ }
+ }
+ def get_flip(t: Type, i: Int, f: Orientation): Orientation = {
+ if (i >= get_size(t)) error("Shouldn't be here")
+ t match {
+ case (_: GroundType) => f
+ case (t: BundleType) =>
+ val (_, flip) = ((t.fields foldLeft (i, None: Option[Orientation])){
+ case ((n, ret), x) if n < get_size(x.tpe) => ret match {
+ case None => (n, Some(get_flip(x.tpe,n,times(x.flip,f))))
+ case Some(_) => (n, ret)
}
- ret.asInstanceOf[Some[Orientation]].get
- }
- }
- x
- }
-
- def get_point (e:Expression) : Int = {
- e match {
- case (e:WRef) => 0
- case (e:WSubField) => {
- var i = 0
- tpe(e.exp).asInstanceOf[BundleType].fields.find { f => {
- val b = f.name == e.name
- if (!b) { i = i + get_size(f.tpe)}
- b
- }}
- i
- }
- case (e:WSubIndex) => e.value * get_size(e.tpe)
- case (e:WSubAccess) => get_point(e.exp)
- }
+ case ((n, ret), x) => (n - get_size(x.tpe), ret)
+ })
+ flip.get
+ case (t: VectorType) =>
+ val (_, flip) = (((0 until t.size) foldLeft (i, None: Option[Orientation])){
+ case ((n, ret), x) if n < get_size(t.tpe) => ret match {
+ case None => (n, Some(get_flip(t.tpe,n,f)))
+ case Some(_) => (n, ret)
+ }
+ case ((n, ret), x) => (n - get_size(t.tpe), ret)
+ })
+ flip.get
+ }
}
+
+ def get_point (e:Expression) : Int = e match {
+ case (e: WRef) => 0
+ case (e: WSubField) => e.exp.tpe match {case b: BundleType =>
+ (b.fields takeWhile (_.name != e.name) foldLeft 0)(
+ (point, f) => point + get_size(f.tpe))
+ }
+ case (e: WSubIndex) => e.value * get_size(e.tpe)
+ case (e: WSubAccess) => get_point(e.exp)
+ }
- /** Returns true if t, or any subtype, contains a flipped field
- * @param t [[firrtl.ir.Type]]
- * @return if t contains [[firrtl.ir.Flip]]
- */
- def hasFlip(t: Type): Boolean = {
- var has = false
- def findFlip(t: Type): Type = t map (findFlip) match {
- case t: BundleType =>
- for (f <- t.fields) { if (f.flip == Flip) has = true }
- t
- case t: Type => t
- }
- findFlip(t)
- has
- }
+ /** Returns true if t, or any subtype, contains a flipped field
+ * @param t [[firrtl.ir.Type]]
+ * @return if t contains [[firrtl.ir.Flip]]
+ */
+ def hasFlip(t: Type): Boolean = t match {
+ case t: BundleType =>
+ (t.fields exists (_.flip == Flip)) ||
+ (t.fields exists (f => hasFlip(f.tpe)))
+ case t: VectorType => hasFlip(t.tpe)
+ case _ => false
+ }
//============== TYPES ================
- def mux_type (e1:Expression,e2:Expression) : Type = mux_type(tpe(e1),tpe(e2))
- def mux_type (t1:Type,t2:Type) : Type = {
- if (wt(t1) == wt(t2)) {
- (t1,t2) match {
- case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth)
- case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth)
- case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size)
- case (t1:BundleType,t2:BundleType) =>
- BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => {
- Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe))
- }))
- }
- } else UnknownType
- }
- def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2))
- def mux_type_and_widths (t1:Type,t2:Type) : Type = {
- def wmax (w1:Width,w2:Width) : Width = {
- (w1,w2) match {
- case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width.max(w2.width))
- case (w1,w2) => MaxWidth(Seq(w1,w2))
- }
- }
- val wt1 = new WrappedType(t1)
- val wt2 = new WrappedType(t2)
- if (wt1 == wt2) {
- (t1,t2) match {
- case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width))
- case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width))
- case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size)
- case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))})
- }
- } else UnknownType
- }
- def module_type (m:DefModule) : Type = {
- BundleType(m.ports.map(p => p.toField))
- }
- def sub_type (v:Type) : Type = {
- v match {
- case v:VectorType => v.tpe
- case v => UnknownType
- }
- }
- def field_type (v:Type,s:String) : Type = {
- v match {
- case v:BundleType => {
- val ft = v.fields.find(p => p.name == s)
- ft match {
- case ft:Some[Field] => ft.get.tpe
- case ft => UnknownType
- }
- }
- case v => UnknownType
- }
- }
+ def mux_type (e1:Expression, e2:Expression) : Type = mux_type(e1.tpe, e2.tpe)
+ def mux_type (t1:Type, t2:Type) : Type = (t1,t2) match {
+ case (t1:UIntType, t2:UIntType) => UIntType(UnknownWidth)
+ case (t1:SIntType, t2:SIntType) => SIntType(UnknownWidth)
+ case (t1:VectorType, t2:VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size)
+ case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map {
+ case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe))
+ })
+ case _ => UnknownType
+ }
+ def mux_type_and_widths (e1:Expression,e2:Expression) : Type =
+ mux_type_and_widths(e1.tpe, e2.tpe)
+ def mux_type_and_widths (t1:Type, t2:Type) : Type = {
+ def wmax (w1:Width, w2:Width) : Width = (w1,w2) match {
+ case (w1:IntWidth, w2:IntWidth) => IntWidth(w1.width max w2.width)
+ case (w1, w2) => MaxWidth(Seq(w1, w2))
+ }
+ (t1, t2) match {
+ case (t1:UIntType, t2:UIntType) => UIntType(wmax(t1.width, t2.width))
+ case (t1:SIntType, t2:SIntType) => SIntType(wmax(t1.width, t2.width))
+ case (t1:VectorType, t2:VectorType) => VectorType(
+ mux_type_and_widths(t1.tpe, t2.tpe), t1.size)
+ case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map {
+ case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe, f2.tpe))
+ })
+ case _ => UnknownType
+ }
+ }
+
+ def module_type(m: DefModule): Type = BundleType(m.ports map {
+ case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
+ })
+ def sub_type(v: Type): Type = v match {
+ case v: VectorType => v.tpe
+ case v => UnknownType
+ }
+ def field_type(v:Type, s: String) : Type = v match {
+ case v: BundleType => v.fields find (_.name == s) match {
+ case Some(f) => f.tpe
+ case None => UnknownType
+ }
+ case v => UnknownType
+ }
////=====================================
- def widthBANG (t:Type) : Width = {
- t match {
- case g: GroundType => g.width
- case t => error("No width!")
- }
- }
- def long_BANG (t:Type) : Long = {
- (t) match {
- case g: GroundType =>
- g.width match {
- case IntWidth(x) => x.toLong
- case _ => throw new FIRRTLException(s"Expecting IntWidth, got: ${g.width}")
- }
- case (t:BundleType) => {
- var w = 0
- for (f <- t.fields) { w = w + long_BANG(f.tpe).toInt }
- w
- }
- case (t:VectorType) => t.size * long_BANG(t.tpe)
- }
- }
-// =================================
- def error(str:String) = throw new FIRRTLException(str)
+ def widthBANG (t:Type) : Width = t match {
+ case g: GroundType => g.width
+ case t => error("No width!")
+ }
+ def long_BANG(t: Type): Long = t match {
+ case (g: GroundType) => g.width match {
+ case IntWidth(x) => x.toLong
+ case _ => error(s"Expecting IntWidth, got: ${g.width}")
+ }
+ case (t: BundleType) => (t.fields foldLeft 0)((w, f) =>
+ w + long_BANG(f.tpe).toInt)
+ case (t: VectorType) => t.size * long_BANG(t.tpe)
+ }
- implicit class FirrtlNodeUtils(node: FirrtlNode) {
- def getType(): Type =
- node match {
- case e: Expression => e.getType
- case s: Statement => s.getType
- //case f: Field => f.getType
- case t: Type => t.getType
- case p: Port => p.getType
- case _ => UnknownType
- }
- }
+// =================================
+ def error(str: String) = throw new FIRRTLException(str)
//// =============== EXPANSION FUNCTIONS ================
- def get_size (t:Type) : Int = {
- t match {
- case (t:BundleType) => {
- var sum = 0
- for (f <- t.fields) {
- sum = sum + get_size(f.tpe)
- }
- sum
- }
- case (t:VectorType) => t.size * get_size(t.tpe)
- case (t) => 1
- }
- }
- def get_valid_points (t1:Type, t2:Type, flip1:Orientation, flip2:Orientation) : Seq[(Int,Int)] = {
- //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
- (t1,t2) match {
- case (t1:UIntType,t2:UIntType) => if (flip1 == flip2) Seq((0, 0)) else Seq()
- case (t1:SIntType,t2:SIntType) => if (flip1 == flip2) Seq((0, 0)) else Seq()
- case (t1:BundleType,t2:BundleType) => {
- val points = ArrayBuffer[(Int,Int)]()
- var ilen = 0
- var jlen = 0
- for (i <- 0 until t1.fields.size) {
- for (j <- 0 until t2.fields.size) {
- val f1 = t1.fields(i)
- val f2 = t2.fields(j)
- if (f1.name == f2.name) {
- val ls = get_valid_points(f1.tpe,f2.tpe,times(flip1, f1.flip),times(flip2, f2.flip))
- for (x <- ls) {
- points += ((x._1 + ilen, x._2 + jlen))
- }
- }
- jlen = jlen + get_size(t2.fields(j).tpe)
- }
- ilen = ilen + get_size(t1.fields(i).tpe)
- jlen = 0
- }
- points
- }
- case (t1:VectorType,t2:VectorType) => {
- val points = ArrayBuffer[(Int,Int)]()
- var ilen = 0
- var jlen = 0
- for (i <- 0 until scala.math.min(t1.size,t2.size)) {
- val ls = get_valid_points(t1.tpe,t2.tpe,flip1,flip2)
- for (x <- ls) {
- val y = ((x._1 + ilen), (x._2 + jlen))
- points += y
- }
- ilen = ilen + get_size(t1.tpe)
- jlen = jlen + get_size(t2.tpe)
- }
- points
- }
- case (ClockType,ClockType) => if (flip1 == flip2) Seq((0, 0)) else Seq()
- }
- }
+ def get_size(t: Type): Int = t match {
+ case (t: BundleType) => (t.fields foldLeft 0)(
+ (sum, f) => sum + get_size(f.tpe))
+ case (t: VectorType) => t.size * get_size(t.tpe)
+ case (t) => 1
+ }
+
+ def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = {
+ //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2])
+ (t1, t2) match {
+ case (t1: UIntType, t2: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil
+ case (t1: SIntType, t2: SIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil
+ case (t1: BundleType, t2: BundleType) =>
+ def emptyMap = Map[String, (Type, Orientation, Int)]()
+ val t1_fields = ((t1.fields foldLeft (emptyMap, 0)){case ((map, ilen), f1) =>
+ (map + (f1.name -> (f1.tpe, f1.flip, ilen)), ilen + get_size(f1.tpe))})._1
+ ((t2.fields foldLeft (Seq[(Int, Int)](), 0)){case ((points, jlen), f2) =>
+ t1_fields get f2.name match {
+ case None => (points, jlen + get_size(f2.tpe))
+ case Some((f1_tpe, f1_flip, ilen))=>
+ val f1_times = times(flip1, f1_flip)
+ val f2_times = times(flip2, f2.flip)
+ val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times)
+ (points ++ (ls map {case (x, y) => (x + ilen, y + jlen)}), jlen + get_size(f2.tpe))
+ }
+ })._1
+ case (t1: VectorType, t2: VectorType) =>
+ val size = math.min(t1.size, t2.size)
+ (((0 until size) foldLeft (Seq[(Int, Int)](), 0, 0)){case ((points, ilen, jlen), _) =>
+ val ls = get_valid_points(t1.tpe, t2.tpe, flip1, flip2)
+ (points ++ (ls map {case (x, y) => ((x + ilen), (y + jlen))}),
+ ilen + get_size(t1.tpe), jlen + get_size(t2.tpe))
+ })._1
+ case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil
+ case _ => error("shouldn't be here")
+ }
+ }
+
// =========== GENDER/FLIP UTILS ============
- def swap (g:Gender) : Gender = {
- g match {
- case UNKNOWNGENDER => UNKNOWNGENDER
- case MALE => FEMALE
- case FEMALE => MALE
- case BIGENDER => BIGENDER
- }
- }
- def swap (d:Direction) : Direction = {
- d match {
- case Output => Input
- case Input => Output
- }
- }
- def swap (f:Orientation) : Orientation = {
- f match {
- case Default => Flip
- case Flip => Default
- }
- }
- def to_dir (g:Gender) : Direction = {
- g match {
- case MALE => Input
- case FEMALE => Output
- }
- }
- def to_gender (d:Direction) : Gender = {
- d match {
- case Input => MALE
- case Output => FEMALE
- }
- }
- def toGender(f: Orientation): Gender = f match {
- case Default => FEMALE
- case Flip => MALE
+ def swap(g: Gender) : Gender = g match {
+ case UNKNOWNGENDER => UNKNOWNGENDER
+ case MALE => FEMALE
+ case FEMALE => MALE
+ case BIGENDER => BIGENDER
+ }
+ def swap(d: Direction) : Direction = d match {
+ case Output => Input
+ case Input => Output
+ }
+ def swap(f: Orientation) : Orientation = f match {
+ case Default => Flip
+ case Flip => Default
+ }
+ def to_dir(g: Gender): Direction = g match {
+ case MALE => Input
+ case FEMALE => Output
}
- def toFlip(g: Gender): Orientation = g match {
+ def to_gender(d: Direction): Gender = d match {
+ case Input => MALE
+ case Output => FEMALE
+ }
+ def to_flip(d: Direction): Orientation = d match {
+ case Input => Flip
+ case Output => Default
+ }
+ def to_flip(g: Gender): Orientation = g match {
case MALE => Flip
case FEMALE => Default
}
- def field_flip (v:Type,s:String) : Orientation = {
- v match {
- case v:BundleType => {
- val ft = v.fields.find {p => p.name == s}
- ft match {
- case ft:Some[Field] => ft.get.flip
- case ft => Default
- }
- }
- case v => Default
- }
- }
- def get_field (v:Type,s:String) : Field = {
- v match {
- case v:BundleType => {
- val ft = v.fields.find {p => p.name == s}
- ft match {
- case ft:Some[Field] => ft.get
- case ft => error("Shouldn't be here"); Field("blah",Default,UnknownType)
- }
- }
- case v => error("Shouldn't be here"); Field("blah",Default,UnknownType)
- }
- }
- def times (flip:Orientation, d:Direction) : Direction = times(flip, d)
- def times (d:Direction,flip:Orientation) : Direction = {
- flip match {
- case Default => d
- case Flip => swap(d)
- }
- }
- def times (g: Gender, d: Direction): Direction = times(d, g)
- def times (d: Direction, g: Gender): Direction = g match {
- case FEMALE => d
- case MALE => swap(d) // MALE == INPUT == REVERSE
- }
+ def field_flip(v:Type, s:String) : Orientation = v match {
+ case (v:BundleType) => v.fields find (_.name == s) match {
+ case Some(ft) => ft.flip
+ case None => Default
+ }
+ case v => Default
+ }
+ def get_field(v:Type, s:String) : Field = v match {
+ case (v:BundleType) => v.fields find (_.name == s) match {
+ case Some(ft) => ft
+ case None => error("Shouldn't be here")
+ }
+ case v => error("Shouldn't be here")
+ }
- def times (g:Gender,flip:Orientation) : Gender = times(flip, g)
- def times (flip:Orientation, g:Gender) : Gender = {
- flip match {
- case Default => g
- case Flip => swap(g)
- }
- }
- def times (f1:Orientation, f2:Orientation) : Orientation = {
- f2 match {
- case Default => f1
- case Flip => swap(f1)
- }
- }
+ def times(flip: Orientation, d: Direction): Direction = times(flip, d)
+ def times(d: Direction,flip: Orientation): Direction = flip match {
+ case Default => d
+ case Flip => swap(d)
+ }
+ def times(g: Gender, d: Direction): Direction = times(d, g)
+ def times(d: Direction, g: Gender): Direction = g match {
+ case FEMALE => d
+ case MALE => swap(d) // MALE == INPUT == REVERSE
+ }
+ def times(g: Gender,flip: Orientation): Gender = times(flip, g)
+ def times(flip: Orientation, g: Gender): Gender = flip match {
+ case Default => g
+ case Flip => swap(g)
+ }
+ def times(f1: Orientation, f2: Orientation): Orientation = f2 match {
+ case Default => f1
+ case Flip => swap(f1)
+ }
// =========== ACCESSORS =========
- def info (s:Statement) : Info = {
- s match {
- case s:DefWire => s.info
- case s:DefRegister => s.info
- case s:DefInstance => s.info
- case s:WDefInstance => s.info
- case s:DefMemory => s.info
- case s:DefNode => s.info
- case s:Conditionally => s.info
- case s:PartialConnect => s.info
- case s:Connect => s.info
- case s:IsInvalid => s.info
- case s:Stop => s.info
- case s:Print => s.info
- case s:Block => NoInfo
- case EmptyStmt => NoInfo
- }
- }
- def gender (e:Expression) : Gender = {
- e match {
- case e:WRef => e.gender
- case e:WSubField => e.gender
- case e:WSubIndex => e.gender
- case e:WSubAccess => e.gender
- case e:DoPrim => MALE
- case e:UIntLiteral => MALE
- case e:SIntLiteral => MALE
- case e:Mux => MALE
- case e:ValidIf => MALE
- case e:WInvalid => MALE
- case e => println(e); error("Shouldn't be here")
- }}
- def get_gender (s:Statement) : Gender =
- s match {
- case s:DefWire => BIGENDER
- case s:DefRegister => BIGENDER
- case s:WDefInstance => MALE
- case s:DefNode => MALE
- case s:DefInstance => MALE
- case s:DefMemory => MALE
- case s:Block => UNKNOWNGENDER
- case s:Connect => UNKNOWNGENDER
- case s:PartialConnect => UNKNOWNGENDER
- case s:Stop => UNKNOWNGENDER
- case s:Print => UNKNOWNGENDER
- case EmptyStmt => UNKNOWNGENDER
- case s:IsInvalid => UNKNOWNGENDER
- }
- def get_gender (p:Port) : Gender =
- if (p.direction == Input) MALE else FEMALE
- def kind (e:Expression) : Kind =
- e match {
- case e:WRef => e.kind
- case e:WSubField => kind(e.exp)
- case e:WSubIndex => kind(e.exp)
- case e => ExpKind()
- }
- def tpe (e:Expression) : Type =
- e match {
- case e:Reference => e.tpe
- case e:SubField => e.tpe
- case e:SubIndex => e.tpe
- case e:SubAccess => e.tpe
- case e:WRef => e.tpe
- case e:WSubField => e.tpe
- case e:WSubIndex => e.tpe
- case e:WSubAccess => e.tpe
- case e:DoPrim => e.tpe
- case e:Mux => e.tpe
- case e:ValidIf => e.tpe
- case e:UIntLiteral => UIntType(e.width)
- case e:SIntLiteral => SIntType(e.width)
- case e:WVoid => UnknownType
- case e:WInvalid => UnknownType
- }
- def get_type (s:Statement) : Type = {
- s match {
- case s:DefWire => s.tpe
- case s:DefRegister => s.tpe
- case s:DefNode => tpe(s.value)
- case s:DefMemory => {
- val depth = s.depth
- val addr = Field("addr",Default,UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1))))
- val en = Field("en",Default,BoolType)
- val clk = Field("clk",Default,ClockType)
- val def_data = Field("data",Default,s.dataType)
- val rev_data = Field("data",Flip,s.dataType)
- val mask = Field("mask",Default,create_mask(s.dataType))
- val wmode = Field("wmode",Default,UIntType(IntWidth(1)))
- val rdata = Field("rdata",Flip,s.dataType)
- val wdata = Field("wdata",Default,s.dataType)
- val wmask = Field("wmask",Default,create_mask(s.dataType))
- val read_type = BundleType(Seq(rev_data,addr,en,clk))
- val write_type = BundleType(Seq(def_data,mask,addr,en,clk))
- val readwrite_type = BundleType(Seq(wmode,rdata,wdata,wmask,addr,en,clk))
-
- val mem_fields = ArrayBuffer[Field]()
- s.readers.foreach {x => mem_fields += Field(x,Flip,read_type)}
- s.writers.foreach {x => mem_fields += Field(x,Flip,write_type)}
- s.readwriters.foreach {x => mem_fields += Field(x,Flip,readwrite_type)}
- BundleType(mem_fields)
- }
- case s:DefInstance => UnknownType
- case s:WDefInstance => s.tpe
- case _ => UnknownType
- }}
- def get_name (s:Statement) : String = {
- s match {
- case s:DefWire => s.name
- case s:DefRegister => s.name
- case s:DefNode => s.name
- case s:DefMemory => s.name
- case s:DefInstance => s.name
- case s:WDefInstance => s.name
- case _ => error("Shouldn't be here"); "blah"
- }}
- def get_info (s:Statement) : Info = {
- s match {
- case s:DefWire => s.info
- case s:DefRegister => s.info
- case s:DefInstance => s.info
- case s:WDefInstance => s.info
- case s:DefMemory => s.info
- case s:DefNode => s.info
- case s:Conditionally => s.info
- case s:PartialConnect => s.info
- case s:Connect => s.info
- case s:IsInvalid => s.info
- case s:Stop => s.info
- case s:Print => s.info
- case _ => NoInfo
- }}
+ def kind(e: Expression): Kind = e match {
+ case e: WRef => e.kind
+ case e: WSubField => kind(e.exp)
+ case e: WSubIndex => kind(e.exp)
+ case e: WSubAccess => kind(e.exp)
+ case e => ExpKind()
+ }
+ def gender (e: Expression): Gender = e match {
+ case e: WRef => e.gender
+ case e: WSubField => e.gender
+ case e: WSubIndex => e.gender
+ case e: WSubAccess => e.gender
+ case e: DoPrim => MALE
+ case e: UIntLiteral => MALE
+ case e: SIntLiteral => MALE
+ case e: Mux => MALE
+ case e: ValidIf => MALE
+ case e: WInvalid => MALE
+ case e => println(e); error("Shouldn't be here")
+ }
+ def get_gender(s:Statement): Gender = s match {
+ case s: DefWire => BIGENDER
+ case s: DefRegister => BIGENDER
+ case s: WDefInstance => MALE
+ case s: DefNode => MALE
+ case s: DefInstance => MALE
+ case s: DefMemory => MALE
+ case s: Block => UNKNOWNGENDER
+ case s: Connect => UNKNOWNGENDER
+ case s: PartialConnect => UNKNOWNGENDER
+ case s: Stop => UNKNOWNGENDER
+ case s: Print => UNKNOWNGENDER
+ case s: IsInvalid => UNKNOWNGENDER
+ case EmptyStmt => UNKNOWNGENDER
+ }
+ def get_gender(p: Port): Gender = if (p.direction == Input) MALE else FEMALE
+ def get_type(s: Statement): Type = s match {
+ case s: DefWire => s.tpe
+ case s: DefRegister => s.tpe
+ case s: DefNode => s.value.tpe
+ case s: DefMemory =>
+ val depth = s.depth
+ val addr = Field("addr", Default, UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1))))
+ val en = Field("en", Default, BoolType)
+ val clk = Field("clk", Default, ClockType)
+ val def_data = Field("data", Default, s.dataType)
+ val rev_data = Field("data", Flip, s.dataType)
+ val mask = Field("mask", Default, create_mask(s.dataType))
+ val wmode = Field("wmode", Default, UIntType(IntWidth(1)))
+ val rdata = Field("rdata", Flip, s.dataType)
+ val wdata = Field("wdata", Default, s.dataType)
+ val wmask = Field("wmask", Default, create_mask(s.dataType))
+ val read_type = BundleType(Seq(rev_data, addr, en, clk))
+ val write_type = BundleType(Seq(def_data, mask, addr, en, clk))
+ val readwrite_type = BundleType(Seq(wmode, rdata, wdata, wmask, addr, en, clk))
+ BundleType(
+ (s.readers map (Field(_, Flip, read_type))) ++
+ (s.writers map (Field(_, Flip, write_type))) ++
+ (s.readwriters map (Field(_, Flip, readwrite_type)))
+ )
+ case s: WDefInstance => s.tpe
+ case _ => UnknownType
+ }
+ def get_name(s: Statement): String = s match {
+ case s: HasName => s.name
+ case _ => error("Shouldn't be here")
+ }
+ def get_info(s: Statement): Info = s match {
+ case s: HasInfo => s.info
+ case _ => NoInfo
+ }
/** Splits an Expression into root Ref and tail
*
@@ -680,11 +504,12 @@ object Utils extends LazyLogging {
case None =>
getRootDecl(root.name)(m.body) match {
case Some(decl) => decl
- case None => throw new DeclarationNotFoundException(s"[module ${m.name}] Reference ${expr.serialize} not declared!")
+ case None => throw new DeclarationNotFoundException(
+ s"[module ${m.name}] Reference ${expr.serialize} not declared!")
}
}
rootDecl
- case e => throw new FIRRTLException(s"getDeclaration does not support Expressions of type ${e.getClass}")
+ case e => error(s"getDeclaration does not support Expressions of type ${e.getClass}")
}
}
@@ -699,7 +524,6 @@ object Utils extends LazyLogging {
def apply_s (s:Statement) : Statement = s map (apply_s) map (apply_e) map (apply_t)
apply_s(s)
}
- val ONE = IntWidth(1)
//def digits (s:String) : Boolean {
// val digits = "0123456789"
// var yes:Boolean = true
@@ -750,322 +574,79 @@ object Utils extends LazyLogging {
// to-stmt(body(m))
// map(to-port,ports(m))
// sym-hash
- implicit class StmtUtils(stmt: Statement) {
- def getType(): Type =
- stmt match {
- case s: DefWire => s.tpe
- case s: DefRegister => s.tpe
- case s: DefMemory => s.dataType
- case _ => UnknownType
- }
+ val v_keywords = Set(
+ "alias", "always", "always_comb", "always_ff", "always_latch",
+ "and", "assert", "assign", "assume", "attribute", "automatic",
- def getInfo: Info =
- stmt match {
- case s: DefWire => s.info
- case s: DefRegister => s.info
- case s: DefInstance => s.info
- case s: DefMemory => s.info
- case s: DefNode => s.info
- case s: Conditionally => s.info
- case s: PartialConnect => s.info
- case s: Connect => s.info
- case s: IsInvalid => s.info
- case s: Stop => s.info
- case s: Print => s.info
- case _ => NoInfo
- }
- }
+ "before", "begin", "bind", "bins", "binsof", "bit", "break",
+ "buf", "bufif0", "bufif1", "byte",
- implicit class FlipUtils(f: Orientation) {
- def flip(): Orientation = {
- f match {
- case Flip => Default
- case Default => Flip
- }
- }
-
- def toDirection(): Direction = {
- f match {
- case Default => Output
- case Flip => Input
- }
- }
- }
+ "case", "casex", "casez", "cell", "chandle", "class", "clocking",
+ "cmos", "config", "const", "constraint", "context", "continue",
+ "cover", "covergroup", "coverpoint", "cross",
- implicit class FieldUtils(field: Field) {
- def flip(): Field = Field(field.name, field.flip.flip, field.tpe)
+ "deassign", "default", "defparam", "design", "disable", "dist", "do",
- def getType(): Type = field.tpe
- def toPort(info: Info = NoInfo): Port =
- Port(info, field.name, field.flip.toDirection, field.tpe)
- }
+ "edge", "else", "end", "endattribute", "endcase", "endclass",
+ "endclocking", "endconfig", "endfunction", "endgenerate",
+ "endgroup", "endinterface", "endmodule", "endpackage",
+ "endprimitive", "endprogram", "endproperty", "endspecify",
+ "endsequence", "endtable", "endtask",
+ "enum", "event", "expect", "export", "extends", "extern",
- implicit class TypeUtils(t: Type) {
- def isGround: Boolean = t match {
- case (_: UIntType | _: SIntType | ClockType) => true
- case (_: BundleType | _: VectorType) => false
- }
- def isAggregate: Boolean = !t.isGround
+ "final", "first_match", "for", "force", "foreach", "forever",
+ "fork", "forkjoin", "function",
- def getType(): Type =
- t match {
- case v: VectorType => v.tpe
- case tpe: Type => UnknownType
- }
+ "generate", "genvar",
- def wipeWidth(): Type =
- t match {
- case t: UIntType => UIntType(UnknownWidth)
- case t: SIntType => SIntType(UnknownWidth)
- case _ => t
- }
- }
+ "highz0", "highz1",
- implicit class DirectionUtils(d: Direction) {
- def toFlip(): Orientation = {
- d match {
- case Input => Flip
- case Output => Default
- }
- }
- }
-
- implicit class PortUtils(p: Port) {
- def getType(): Type = p.tpe
- def toField(): Field = Field(p.name, p.direction.toFlip, p.tpe)
- }
+ "if", "iff", "ifnone", "ignore_bins", "illegal_bins", "import",
+ "incdir", "include", "initial", "initvar", "inout", "input",
+ "inside", "instance", "int", "integer", "interconnect",
+ "interface", "intersect",
+
+ "join", "join_any", "join_none", "large", "liblist", "library",
+ "local", "localparam", "logic", "longint",
+
+ "macromodule", "matches", "medium", "modport", "module",
+
+ "nand", "negedge", "new", "nmos", "nor", "noshowcancelled",
+ "not", "notif0", "notif1", "null",
+
+ "or", "output",
+
+ "package", "packed", "parameter", "pmos", "posedge",
+ "primitive", "priority", "program", "property", "protected",
+ "pull0", "pull1", "pulldown", "pullup",
+ "pulsestyle_onevent", "pulsestyle_ondetect", "pure",
+
+ "rand", "randc", "randcase", "randsequence", "rcmos",
+ "real", "realtime", "ref", "reg", "release", "repeat",
+ "return", "rnmos", "rpmos", "rtran", "rtranif0", "rtranif1",
+
+ "scalared", "sequence", "shortint", "shortreal", "showcancelled",
+ "signed", "small", "solve", "specify", "specparam", "static",
+ "strength", "string", "strong0", "strong1", "struct", "super",
+ "supply0", "supply1",
+
+ "table", "tagged", "task", "this", "throughout", "time", "timeprecision",
+ "timeunit", "tran", "tranif0", "tranif1", "tri", "tri0", "tri1", "triand",
+ "trior", "trireg", "type","typedef",
+
+ "union", "unique", "unsigned", "use",
+
+ "var", "vectored", "virtual", "void",
+
+ "wait", "wait_order", "wand", "weak0", "weak1", "while",
+ "wildcard", "wire", "with", "within", "wor",
+ "xnor", "xor",
- val v_keywords = Map[String,Boolean]() +
- ("alias" -> true) +
- ("always" -> true) +
- ("always_comb" -> true) +
- ("always_ff" -> true) +
- ("always_latch" -> true) +
- ("and" -> true) +
- ("assert" -> true) +
- ("assign" -> true) +
- ("assume" -> true) +
- ("attribute" -> true) +
- ("automatic" -> true) +
- ("before" -> true) +
- ("begin" -> true) +
- ("bind" -> true) +
- ("bins" -> true) +
- ("binsof" -> true) +
- ("bit" -> true) +
- ("break" -> true) +
- ("buf" -> true) +
- ("bufif0" -> true) +
- ("bufif1" -> true) +
- ("byte" -> true) +
- ("case" -> true) +
- ("casex" -> true) +
- ("casez" -> true) +
- ("cell" -> true) +
- ("chandle" -> true) +
- ("class" -> true) +
- ("clocking" -> true) +
- ("cmos" -> true) +
- ("config" -> true) +
- ("const" -> true) +
- ("constraint" -> true) +
- ("context" -> true) +
- ("continue" -> true) +
- ("cover" -> true) +
- ("covergroup" -> true) +
- ("coverpoint" -> true) +
- ("cross" -> true) +
- ("deassign" -> true) +
- ("default" -> true) +
- ("defparam" -> true) +
- ("design" -> true) +
- ("disable" -> true) +
- ("dist" -> true) +
- ("do" -> true) +
- ("edge" -> true) +
- ("else" -> true) +
- ("end" -> true) +
- ("endattribute" -> true) +
- ("endcase" -> true) +
- ("endclass" -> true) +
- ("endclocking" -> true) +
- ("endconfig" -> true) +
- ("endfunction" -> true) +
- ("endgenerate" -> true) +
- ("endgroup" -> true) +
- ("endinterface" -> true) +
- ("endmodule" -> true) +
- ("endpackage" -> true) +
- ("endprimitive" -> true) +
- ("endprogram" -> true) +
- ("endproperty" -> true) +
- ("endspecify" -> true) +
- ("endsequence" -> true) +
- ("endtable" -> true) +
- ("endtask" -> true) +
- ("enum" -> true) +
- ("event" -> true) +
- ("expect" -> true) +
- ("export" -> true) +
- ("extends" -> true) +
- ("extern" -> true) +
- ("final" -> true) +
- ("first_match" -> true) +
- ("for" -> true) +
- ("force" -> true) +
- ("foreach" -> true) +
- ("forever" -> true) +
- ("fork" -> true) +
- ("forkjoin" -> true) +
- ("function" -> true) +
- ("generate" -> true) +
- ("genvar" -> true) +
- ("highz0" -> true) +
- ("highz1" -> true) +
- ("if" -> true) +
- ("iff" -> true) +
- ("ifnone" -> true) +
- ("ignore_bins" -> true) +
- ("illegal_bins" -> true) +
- ("import" -> true) +
- ("incdir" -> true) +
- ("include" -> true) +
- ("initial" -> true) +
- ("initvar" -> true) +
- ("inout" -> true) +
- ("input" -> true) +
- ("inside" -> true) +
- ("instance" -> true) +
- ("int" -> true) +
- ("integer" -> true) +
- ("interconnect" -> true) +
- ("interface" -> true) +
- ("intersect" -> true) +
- ("join" -> true) +
- ("join_any" -> true) +
- ("join_none" -> true) +
- ("large" -> true) +
- ("liblist" -> true) +
- ("library" -> true) +
- ("local" -> true) +
- ("localparam" -> true) +
- ("logic" -> true) +
- ("longint" -> true) +
- ("macromodule" -> true) +
- ("matches" -> true) +
- ("medium" -> true) +
- ("modport" -> true) +
- ("module" -> true) +
- ("nand" -> true) +
- ("negedge" -> true) +
- ("new" -> true) +
- ("nmos" -> true) +
- ("nor" -> true) +
- ("noshowcancelled" -> true) +
- ("not" -> true) +
- ("notif0" -> true) +
- ("notif1" -> true) +
- ("null" -> true) +
- ("or" -> true) +
- ("output" -> true) +
- ("package" -> true) +
- ("packed" -> true) +
- ("parameter" -> true) +
- ("pmos" -> true) +
- ("posedge" -> true) +
- ("primitive" -> true) +
- ("priority" -> true) +
- ("program" -> true) +
- ("property" -> true) +
- ("protected" -> true) +
- ("pull0" -> true) +
- ("pull1" -> true) +
- ("pulldown" -> true) +
- ("pullup" -> true) +
- ("pulsestyle_onevent" -> true) +
- ("pulsestyle_ondetect" -> true) +
- ("pure" -> true) +
- ("rand" -> true) +
- ("randc" -> true) +
- ("randcase" -> true) +
- ("randsequence" -> true) +
- ("rcmos" -> true) +
- ("real" -> true) +
- ("realtime" -> true) +
- ("ref" -> true) +
- ("reg" -> true) +
- ("release" -> true) +
- ("repeat" -> true) +
- ("return" -> true) +
- ("rnmos" -> true) +
- ("rpmos" -> true) +
- ("rtran" -> true) +
- ("rtranif0" -> true) +
- ("rtranif1" -> true) +
- ("scalared" -> true) +
- ("sequence" -> true) +
- ("shortint" -> true) +
- ("shortreal" -> true) +
- ("showcancelled" -> true) +
- ("signed" -> true) +
- ("small" -> true) +
- ("solve" -> true) +
- ("specify" -> true) +
- ("specparam" -> true) +
- ("static" -> true) +
- ("strength" -> true) +
- ("string" -> true) +
- ("strong0" -> true) +
- ("strong1" -> true) +
- ("struct" -> true) +
- ("super" -> true) +
- ("supply0" -> true) +
- ("supply1" -> true) +
- ("table" -> true) +
- ("tagged" -> true) +
- ("task" -> true) +
- ("this" -> true) +
- ("throughout" -> true) +
- ("time" -> true) +
- ("timeprecision" -> true) +
- ("timeunit" -> true) +
- ("tran" -> true) +
- ("tranif0" -> true) +
- ("tranif1" -> true) +
- ("tri" -> true) +
- ("tri0" -> true) +
- ("tri1" -> true) +
- ("triand" -> true) +
- ("trior" -> true) +
- ("trireg" -> true) +
- ("type" -> true) +
- ("typedef" -> true) +
- ("union" -> true) +
- ("unique" -> true) +
- ("unsigned" -> true) +
- ("use" -> true) +
- ("var" -> true) +
- ("vectored" -> true) +
- ("virtual" -> true) +
- ("void" -> true) +
- ("wait" -> true) +
- ("wait_order" -> true) +
- ("wand" -> true) +
- ("weak0" -> true) +
- ("weak1" -> true) +
- ("while" -> true) +
- ("wildcard" -> true) +
- ("wire" -> true) +
- ("with" -> true) +
- ("within" -> true) +
- ("wor" -> true) +
- ("xnor" -> true) +
- ("xor" -> true) +
- ("SYNTHESIS" -> true) +
- ("PRINTF_COND" -> true) +
- ("VCS" -> true)
+ "SYNTHESIS",
+ "PRINTF_COND",
+ "VCS")
}
object MemoizedHash {
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index 60a49bac..e0e7c57a 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -105,7 +105,7 @@ object CheckChirrtl extends Pass with LazyLogging {
e
}
def checkChirrtlS(s: Statement): Statement = {
- sinfo = s.getInfo
+ sinfo = get_info(s)
def checkName(name: String): String = {
if (names.contains(name)) errors.append(new NotUniqueException(name))
else names(name) = true
@@ -138,7 +138,7 @@ object CheckChirrtl extends Pass with LazyLogging {
for (p <- m.ports) {
sinfo = p.info
names(p.name) = true
- val tpe = p.getType
+ val tpe = p.tpe
tpe map (checkChirrtlT)
tpe map (checkChirrtlW)
}
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index 9ee20c0a..6e49ce93 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -241,7 +241,7 @@ object CheckHighForm extends Pass with LazyLogging {
else names(name) = true
name
}
- sinfo = s.getInfo
+ sinfo = get_info(s)
s map (checkName)
s map (checkHighFormT)
@@ -276,7 +276,7 @@ object CheckHighForm extends Pass with LazyLogging {
for (p <- m.ports) {
// FIXME should we set sinfo here?
names(p.name) = true
- val tpe = p.getType
+ val tpe = p.tpe
tpe map (checkHighFormT)
tpe map (checkHighFormW)
}
@@ -336,27 +336,36 @@ object CheckTypes extends Pass with LazyLogging {
def all_same_type (ls:Seq[Expression]) : Unit = {
var error = false
for (x <- ls) {
- if (wt(tpe(ls.head)) != wt(tpe(x))) error = true
+ if (wt(ls.head.tpe) != wt(x.tpe)) error = true
}
if (error) errors.append(new OpNotAllSameType(info,e.op.serialize))
}
def all_ground (ls:Seq[Expression]) : Unit = {
var error = false
for (x <- ls ) {
- if (!(tpe(x).typeof[UIntType] || tpe(x).typeof[SIntType])) error = true
+ x.tpe match {
+ case _: UIntType | _: SIntType =>
+ case _ => error = true
+ }
}
if (error) errors.append(new OpNotGround(info,e.op.serialize))
}
def all_uint (ls:Seq[Expression]) : Unit = {
var error = false
for (x <- ls ) {
- if (!(tpe(x).typeof[UIntType])) error = true
+ x.tpe match {
+ case _: UIntType =>
+ case _ => error = true
+ }
}
if (error) errors.append(new OpNotAllUInt(info,e.op.serialize))
}
def is_uint (x:Expression) : Unit = {
var error = false
- if (!(tpe(x).typeof[UIntType])) error = true
+ x.tpe match {
+ case _: UIntType =>
+ case _ => error = true
+ }
if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize))
}
@@ -417,7 +426,7 @@ object CheckTypes extends Pass with LazyLogging {
(e map (check_types_e(info))) match {
case (e:WRef) => e
case (e:WSubField) => {
- (tpe(e.exp)) match {
+ (e.exp.tpe) match {
case (t:BundleType) => {
val ft = t.fields.find(p => p.name == e.name)
if (ft == None) errors.append(new SubfieldNotInBundle(info,e.name))
@@ -426,7 +435,7 @@ object CheckTypes extends Pass with LazyLogging {
}
}
case (e:WSubIndex) => {
- (tpe(e.exp)) match {
+ (e.exp.tpe) match {
case (t:VectorType) => {
if (e.value >= t.size) errors.append(new IndexTooLarge(info,e.value))
}
@@ -434,24 +443,30 @@ object CheckTypes extends Pass with LazyLogging {
}
}
case (e:WSubAccess) => {
- (tpe(e.exp)) match {
+ (e.exp.tpe) match {
case (t:VectorType) => false
case (t) => errors.append(new IndexOnNonVector(info))
}
- (tpe(e.index)) match {
+ (e.index.tpe) match {
case (t:UIntType) => false
case (t) => errors.append(new AccessIndexNotUInt(info))
}
}
case (e:DoPrim) => check_types_primop(e,errors,info)
case (e:Mux) => {
- if (wt(tpe(e.tval)) != wt(tpe(e.fval))) errors.append(new MuxSameType(info))
- if (!passive(tpe(e))) errors.append(new MuxPassiveTypes(info))
- if (!(tpe(e.cond).typeof[UIntType])) errors.append(new MuxCondUInt(info))
+ if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info))
+ if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info))
+ e.cond.tpe match {
+ case _: UIntType =>
+ case _ => errors.append(new MuxCondUInt(info))
+ }
}
case (e:ValidIf) => {
- if (!passive(tpe(e))) errors.append(new ValidIfPassiveTypes(info))
- if (!(tpe(e.cond).typeof[UIntType])) errors.append(new ValidIfCondUInt(info))
+ if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info))
+ e.cond.tpe match {
+ case _: UIntType =>
+ case _ => errors.append(new ValidIfCondUInt(info))
+ }
}
case (_:UIntLiteral | _:SIntLiteral) => false
}
@@ -484,22 +499,22 @@ object CheckTypes extends Pass with LazyLogging {
def check_types_s (s:Statement) : Statement = {
s map (check_types_e(get_info(s))) match {
- case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.expr))) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
- case (s:DefRegister) => if (wt(s.tpe) != wt(tpe(s.init))) errors.append(new InvalidRegInit(s.info))
- case (s:PartialConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.expr),Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
+ case (s:Connect) => if (wt(s.loc.tpe) != wt(s.expr.tpe)) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
+ case (s:DefRegister) => if (wt(s.tpe) != wt(s.init.tpe)) errors.append(new InvalidRegInit(s.info))
+ case (s:PartialConnect) => if (!bulk_equals(s.loc.tpe,s.expr.tpe,Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize))
case (s:Stop) => {
- if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info))
- if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
+ if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info))
+ if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
}
case (s:Print)=> {
for (x <- s.args ) {
- if (wt(tpe(x)) != wt(ut()) && wt(tpe(x)) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info))
+ if (wt(x.tpe) != wt(ut()) && wt(x.tpe) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info))
}
- if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info))
- if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
+ if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info))
+ if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info))
}
- case (s:Conditionally) => if (wt(tpe(s.pred)) != wt(ut()) ) errors.append(new PredNotUInt(s.info))
- case (s:DefNode) => if (!passive(tpe(s.value)) ) errors.append(new NodePassiveType(s.info))
+ case (s:Conditionally) => if (wt(s.pred.tpe) != wt(ut()) ) errors.append(new PredNotUInt(s.info))
+ case (s:DefNode) => if (!passive(s.value.tpe) ) errors.append(new NodePassiveType(s.info))
case (s) => false
}
s map (check_types_s)
@@ -571,7 +586,7 @@ object CheckGenders extends Pass {
fQ
}
- val has_flipQ = flipQ(tpe(e))
+ val has_flipQ = flipQ(e.tpe)
//println(e)
//println(gender)
//println(desired)
@@ -597,7 +612,7 @@ object CheckGenders extends Pass {
(e) match {
case (e:WRef) => genders(e.name)
case (e:WSubField) =>
- val f = tpe(e.exp).as[BundleType].get.fields.find(f => f.name == e.name).get
+ val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get
times(get_gender(e.exp,genders),f.flip)
case (e:WSubIndex) => get_gender(e.exp,genders)
case (e:WSubAccess) => get_gender(e.exp,genders)
@@ -735,7 +750,7 @@ object CheckWidths extends Pass {
}
def check_width_s (s:Statement) : Statement = {
s map (check_width_s) map (check_width_e(get_info(s)))
- def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t)
+ def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t)
s map (tm)
}
diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala
index 57782a3c..2e8b53f3 100644
--- a/src/main/scala/firrtl/passes/ConstProp.scala
+++ b/src/main/scala/firrtl/passes/ConstProp.scala
@@ -129,7 +129,7 @@ object ConstProp extends Pass {
private def foldComparison(e: DoPrim) = {
def foldIfZeroedArg(x: Expression): Expression = {
- def isUInt(e: Expression): Boolean = tpe(e) match {
+ def isUInt(e: Expression): Boolean = e.tpe match {
case UIntType(_) => true
case _ => false
}
@@ -163,7 +163,7 @@ object ConstProp extends Pass {
def range(e: Expression): Range = e match {
case UIntLiteral(value, _) => Range(value, value)
case SIntLiteral(value, _) => Range(value, value)
- case _ => tpe(e) match {
+ case _ => e.tpe match {
case SIntType(IntWidth(width)) => Range(
min = BigInt(0) - BigInt(2).pow(width.toInt - 1),
max = BigInt(2).pow(width.toInt - 1) - BigInt(1)
@@ -226,7 +226,7 @@ object ConstProp extends Pass {
case Pad => e.args(0) match {
case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0)))
case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0)))
- case _ if long_BANG(tpe(e.args(0))) == e.consts(0) => e.args(0)
+ case _ if long_BANG(e.args(0).tpe) == e.consts(0) => e.args(0)
case _ => e
}
case Bits => e.args(0) match {
@@ -234,9 +234,9 @@ object ConstProp extends Pass {
val hi = e.consts(0).toInt
val lo = e.consts(1).toInt
require(hi >= lo)
- UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e)))
+ UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(e.tpe))
}
- case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match {
+ case x if long_BANG(e.tpe) == long_BANG(x.tpe) => x.tpe match {
case t: UIntType => x
case _ => asUInt(x, e.tpe)
}
diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala
index 921693c7..3d26298a 100644
--- a/src/main/scala/firrtl/passes/ExpandWhens.scala
+++ b/src/main/scala/firrtl/passes/ExpandWhens.scala
@@ -131,8 +131,8 @@ object ExpandWhens extends Pass {
val falseValue = altNetlist.getOrElse(lvalue, defaultValue)
(trueValue, falseValue) match {
case (WInvalid(), WInvalid()) => WInvalid()
- case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, tpe(fv))
- case (tv, WInvalid()) => ValidIf(s.pred, tv, tpe(tv))
+ case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe)
+ case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe)
case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv))
}
case None =>
diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala
index 7793c85c..a8fda1bf 100644
--- a/src/main/scala/firrtl/passes/Inline.scala
+++ b/src/main/scala/firrtl/passes/Inline.scala
@@ -5,7 +5,6 @@ package passes
import scala.collection.mutable
import firrtl.Mappers.{ExpMap,StmtMap}
-import firrtl.Utils.WithAs
import firrtl.ir._
import firrtl.passes.{PassException,PassExceptions}
import Annotations.{Loose, Unstable, Annotation, TransID, Named, ModuleName, ComponentName, CircuitName, AnnotationMap}
diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala
index 585598a8..a4c584ed 100644
--- a/src/main/scala/firrtl/passes/LowerTypes.scala
+++ b/src/main/scala/firrtl/passes/LowerTypes.scala
@@ -105,15 +105,15 @@ object LowerTypes extends Pass {
require(tail.isEmpty) // there can't be a tail for these
val memType = memDataTypeMap(mem.name)
- if (memType.isGround) {
- Seq(e)
- } else {
- val exps = create_exps(mem.name, memType)
- exps map { e =>
- val loMemName = loweredName(e)
- val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
- mergeRef(loMem, mergeRef(port, field))
- }
+ memType match {
+ case _: GroundType => Seq(e)
+ case _ =>
+ val exps = create_exps(mem.name, memType)
+ exps map { e =>
+ val loMemName = loweredName(e)
+ val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER)
+ mergeRef(loMem, mergeRef(port, field))
+ }
}
// Fields that need not be replicated for each
// eg. mem.reader.data[0].a
@@ -138,7 +138,7 @@ object LowerTypes extends Pass {
case k: InstanceKind =>
val (root, tail) = splitRef(e)
val name = loweredName(tail)
- WSubField(root, name, tpe(e), gender(e))
+ WSubField(root, name, e.tpe, gender(e))
case k: MemKind =>
val exps = lowerTypesMemExp(e)
if (exps.length > 1)
@@ -146,7 +146,7 @@ object LowerTypes extends Pass {
" to be expanded!")
exps(0)
case k =>
- WRef(loweredName(e), tpe(e), kind(e), gender(e))
+ WRef(loweredName(e), e.tpe, kind(e), gender(e))
}
case e: Mux => e map (lowerTypesExp)
case e: ValidIf => e map (lowerTypesExp)
@@ -158,26 +158,26 @@ object LowerTypes extends Pass {
s map lowerTypesStmt match {
case s: DefWire =>
sinfo = s.info
- if (s.tpe.isGround) {
- s
- } else {
- val exps = create_exps(s.name, s.tpe)
- val stmts = exps map (e => DefWire(s.info, loweredName(e), tpe(e)))
- Block(stmts)
+ s.tpe match {
+ case _: GroundType => s
+ case _ =>
+ val exps = create_exps(s.name, s.tpe)
+ val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe))
+ Block(stmts)
}
case s: DefRegister =>
sinfo = s.info
- if (s.tpe.isGround) {
- s map lowerTypesExp
- } else {
- val es = create_exps(s.name, s.tpe)
- val inits = create_exps(s.init) map (lowerTypesExp)
- val clock = lowerTypesExp(s.clock)
- val reset = lowerTypesExp(s.reset)
- val stmts = es zip inits map { case (e, i) =>
- DefRegister(s.info, loweredName(e), tpe(e), clock, reset, i)
- }
- Block(stmts)
+ s.tpe match {
+ case _: GroundType => s map lowerTypesExp
+ case _ =>
+ val es = create_exps(s.name, s.tpe)
+ val inits = create_exps(s.init) map (lowerTypesExp)
+ val clock = lowerTypesExp(s.clock)
+ val reset = lowerTypesExp(s.reset)
+ val stmts = es zip inits map { case (e, i) =>
+ DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i)
+ }
+ Block(stmts)
}
// Could instead just save the type of each Module as it gets processed
case s: WDefInstance =>
@@ -188,7 +188,7 @@ object LowerTypes extends Pass {
val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE)))
exps map ( e =>
// Flip because inst genders are reversed from Module type
- Field(loweredName(e), toFlip(gender(e)).flip, tpe(e))
+ Field(loweredName(e), swap(to_flip(gender(e))), e.tpe)
)
}
WDefInstance(s.info, s.name, s.module, BundleType(fieldsx))
@@ -197,16 +197,16 @@ object LowerTypes extends Pass {
case s: DefMemory =>
sinfo = s.info
memDataTypeMap += (s.name -> s.dataType)
- if (s.dataType.isGround) {
- s
- } else {
- val exps = create_exps(s.name, s.dataType)
- val stmts = exps map { e =>
- DefMemory(s.info, loweredName(e), tpe(e), s.depth,
- s.writeLatency, s.readLatency, s.readers, s.writers,
- s.readwriters)
- }
- Block(stmts)
+ s.dataType match {
+ case _: GroundType => s
+ case _ =>
+ val exps = create_exps(s.name, s.dataType)
+ val stmts = exps map { e =>
+ DefMemory(s.info, loweredName(e), e.tpe, s.depth,
+ s.writeLatency, s.readLatency, s.readers, s.writers,
+ s.readwriters)
+ }
+ Block(stmts)
}
// wire foo : { a , b }
// node x = foo
@@ -217,7 +217,7 @@ object LowerTypes extends Pass {
// node y = x_a
case s: DefNode =>
sinfo = s.info
- val names = create_exps(s.name, tpe(s.value)) map (lowerTypesExp)
+ val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp)
val exps = create_exps(s.value) map (lowerTypesExp)
val stmts = names zip exps map { case (n, e) =>
DefNode(s.info, loweredName(n), e)
@@ -249,7 +249,7 @@ object LowerTypes extends Pass {
// Lower Ports
val portsx = m.ports flatMap { p =>
val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction)))
- exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), tpe(e)) )
+ exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) )
}
m match {
case m: ExtModule => m.copy(ports = portsx)
diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala
index 0cabc293..f2117761 100644
--- a/src/main/scala/firrtl/passes/PadWidths.scala
+++ b/src/main/scala/firrtl/passes/PadWidths.scala
@@ -2,7 +2,7 @@ package firrtl
package passes
import firrtl.Mappers.{ExpMap, StmtMap}
-import firrtl.Utils.{tpe, long_BANG}
+import firrtl.Utils.long_BANG
import firrtl.PrimOps._
import firrtl.ir._
@@ -10,10 +10,10 @@ import firrtl.ir._
object PadWidths extends Pass {
def name = "Pad Widths"
private def width(t: Type): Int = long_BANG(t).toInt
- private def width(e: Expression): Int = width(tpe(e))
+ private def width(e: Expression): Int = width(e.tpe)
// Returns an expression with the correct integer width
private def fixup(i: Int)(e: Expression) = {
- def tx = tpe(e) match {
+ def tx = e.tpe match {
case t: UIntType => UIntType(IntWidth(i))
case t: SIntType => SIntType(IntWidth(i))
// default case should never be reached
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 7b4f9aa2..6b6dc811 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -103,7 +103,7 @@ object ResolveKinds extends Pass {
def resolve (body:Statement) = {
def resolve_expr (e:Expression):Expression = {
e match {
- case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender)
+ case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender)
case e => e map (resolve_expr)
}
}
@@ -170,11 +170,11 @@ object InferTypes extends Pass {
val types = LinkedHashMap[String,Type]()
def infer_types_e (e:Expression) : Expression = {
e map (infer_types_e) match {
- case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value))
+ case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe)
case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender)
- case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender)
- case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(tpe(e.exp)),e.gender)
- case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender)
+ case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender)
+ case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender)
+ case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender)
case e:DoPrim => set_primop_type(e)
case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval))
case e:UIntLiteral => e
@@ -246,7 +246,7 @@ object ResolveGenders extends Pass {
case e:WRef => WRef(e.name,e.tpe,e.kind,g)
case e:WSubField => {
val expx =
- field_flip(tpe(e.exp),e.name) match {
+ field_flip(e.exp.tpe,e.name) match {
case Default => resolve_e(g)(e.exp)
case Flip => resolve_e(swap(g))(e.exp)
}
@@ -474,7 +474,7 @@ object InferWidths extends Pass {
case (t:SIntType) => t.width
case ClockType => IntWidth(1)
case (t) => error("No width!"); IntWidth(-1) } }
- def width_BANG (e:Expression) : Width = width_BANG(tpe(e))
+ def width_BANG (e:Expression) : Width = width_BANG(e.tpe)
def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = {
def evaluate(w: Width): Width = {
@@ -549,40 +549,40 @@ object InferWidths extends Pass {
def get_constraints_e (e:Expression) : Expression = {
(e map (get_constraints_e)) match {
case (e:Mux) => {
- constrain(width_BANG(e.cond),ONE)
- constrain(ONE,width_BANG(e.cond))
+ constrain(width_BANG(e.cond),IntWidth(1))
+ constrain(IntWidth(1),width_BANG(e.cond))
e }
case (e) => e }}
def get_constraints (s:Statement) : Statement = {
(s map (get_constraints_e)) match {
case (s:Connect) => {
- val n = get_size(tpe(s.loc))
+ val n = get_size(s.loc.tpe)
val ce_loc = create_exps(s.loc)
val ce_exp = create_exps(s.expr)
for (i <- 0 until n) {
val locx = ce_loc(i)
val expx = ce_exp(i)
- get_flip(tpe(s.loc),i,Default) match {
+ get_flip(s.loc.tpe,i,Default) match {
case Default => constrain(width_BANG(locx),width_BANG(expx))
case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
s }
case (s:PartialConnect) => {
- val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
for (x <- ls) {
val locx = create_exps(s.loc)(x._1)
val expx = create_exps(s.expr)(x._2)
- get_flip(tpe(s.loc),x._1,Default) match {
+ get_flip(s.loc.tpe,x._1,Default) match {
case Default => constrain(width_BANG(locx),width_BANG(expx))
case Flip => constrain(width_BANG(expx),width_BANG(locx)) }}
s }
case (s:DefRegister) => {
- constrain(width_BANG(s.reset),ONE)
- constrain(ONE,width_BANG(s.reset))
- get_constraints_t(s.tpe,tpe(s.init),Default)
+ constrain(width_BANG(s.reset),IntWidth(1))
+ constrain(IntWidth(1),width_BANG(s.reset))
+ get_constraints_t(s.tpe,s.init.tpe,Default)
s }
case (s:Conditionally) => {
- v += WGeq(width_BANG(s.pred),ONE)
- v += WGeq(ONE,width_BANG(s.pred))
+ v += WGeq(width_BANG(s.pred),IntWidth(1))
+ v += WGeq(IntWidth(1),width_BANG(s.pred))
s map (get_constraints) }
case (s) => s map (get_constraints) }}
@@ -661,7 +661,7 @@ object ExpandConnects extends Pass {
e map (set_gender) match {
case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name))
case (e:WSubField) => {
- val f = get_field(tpe(e.exp),e.name)
+ val f = get_field(e.exp.tpe,e.name)
val genderx = times(gender(e.exp),f.flip)
WSubField(e.exp,e.name,e.tpe,genderx)
}
@@ -677,7 +677,7 @@ object ExpandConnects extends Pass {
case (s:DefMemory) => { genders(s.name) = MALE; s }
case (s:DefNode) => { genders(s.name) = MALE; s }
case (s:IsInvalid) => {
- val n = get_size(tpe(s.expr))
+ val n = get_size(s.expr.tpe)
val invalids = ArrayBuffer[Statement]()
val exps = create_exps(s.expr)
for (i <- 0 until n) {
@@ -696,14 +696,14 @@ object ExpandConnects extends Pass {
} else Block(invalids)
}
case (s:Connect) => {
- val n = get_size(tpe(s.loc))
+ val n = get_size(s.loc.tpe)
val connects = ArrayBuffer[Statement]()
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
for (i <- 0 until n) {
val locx = locs(i)
val expx = exps(i)
- val sx = get_flip(tpe(s.loc),i,Default) match {
+ val sx = get_flip(s.loc.tpe,i,Default) match {
case Default => Connect(s.info,locx,expx)
case Flip => Connect(s.info,expx,locx)
}
@@ -712,14 +712,14 @@ object ExpandConnects extends Pass {
Block(connects)
}
case (s:PartialConnect) => {
- val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
val connects = ArrayBuffer[Statement]()
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
ls.foreach { x => {
val locx = locs(x._1)
val expx = exps(x._2)
- val sx = get_flip(tpe(s.loc),x._1,Default) match {
+ val sx = get_flip(s.loc.tpe,x._1,Default) match {
case Default => Connect(s.info,locx,expx)
case Flip => Connect(s.info,expx,locx)
}
@@ -755,7 +755,7 @@ object Legalize extends Pass {
def legalizeShiftRight (e: DoPrim): Expression = e.op match {
case Shr => {
val amount = e.consts(0).toInt
- val width = long_BANG(tpe(e.args(0)))
+ val width = long_BANG(e.args(0).tpe)
lazy val msb = width - 1
if (amount >= width) {
e.tpe match {
@@ -771,9 +771,9 @@ object Legalize extends Pass {
case _ => e
}
def legalizeConnect(c: Connect): Statement = {
- val t = tpe(c.loc)
+ val t = c.loc.tpe
val w = long_BANG(t)
- if (w >= long_BANG(tpe(c.expr))) c
+ if (w >= long_BANG(c.expr.tpe)) c
else {
val newType = t match {
case _: UIntType => UIntType(IntWidth(w))
@@ -811,8 +811,8 @@ object VerilogWrap extends Pass {
if (e.op == Tail) {
(a0()) match {
case (e0:DoPrim) => {
- if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),tpe(e))
- else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),tpe(e))
+ if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),e.tpe)
+ else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),e.tpe)
else e
}
case (e0) => e
@@ -913,12 +913,12 @@ object CInferTypes extends Pass {
def infer_types_e (e:Expression) : Expression = {
e map infer_types_e match {
case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType))
- case (e:SubField) => SubField(e.expr,e.name,field_type(tpe(e.expr),e.name))
- case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(tpe(e.expr)))
- case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(tpe(e.expr)))
+ case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name))
+ case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe))
+ case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe))
case (e:DoPrim) => set_primop_type(e)
case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval))
- case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value))
+ case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe)
case (_:UIntLiteral | _:SIntLiteral) => e
}
}
@@ -1067,8 +1067,8 @@ object RemoveCHIRRTL extends Pass {
val e2s = create_exps(e.fval)
(e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2)))
case (e:ValidIf) =>
- create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1)))
- case (e) => (tpe(e)) match {
+ create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe))
+ case (e) => (e.tpe) match {
case (_:GroundType) => Seq(e)
case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
exps ++ create_exps(SubField(e,f.name,f.tpe)))
@@ -1276,7 +1276,7 @@ object RemoveCHIRRTL extends Pass {
case Some(en) => stmts += Connect(s.info,en,one)
}
if (has_write_mport) {
- val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default)
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
val locs = create_exps(get_mask(s.loc))
for (x <- ls ) {
val locx = locs(x._1)
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index a3ce49f7..880d6b1c 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -76,7 +76,7 @@ object RemoveAccesses extends Pass {
def onStmt(s: Statement): Statement = {
def create_temp(e: Expression): (Statement, Expression) = {
val n = namespace.newTemp
- (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
+ (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e)))
}
/** Replaces a subaccess in a given male expression
@@ -94,9 +94,9 @@ object RemoveAccesses extends Pass {
stmts += wire
rs.zipWithIndex foreach {
case (x, i) if i < temps.size =>
- stmts += Connect(info(s),getTemp(i),x.base)
+ stmts += Connect(get_info(s),getTemp(i),x.base)
case (x, i) =>
- stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt)
+ stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt)
}
temp
}
diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala
index 1c9674e1..3b6021ed 100644
--- a/src/main/scala/firrtl/passes/SplitExpressions.scala
+++ b/src/main/scala/firrtl/passes/SplitExpressions.scala
@@ -2,7 +2,7 @@ package firrtl
package passes
import firrtl.Mappers.{ExpMap, StmtMap}
-import firrtl.Utils.{tpe, kind, gender, info}
+import firrtl.Utils.{kind, gender, get_info}
import firrtl.ir._
import scala.collection.mutable
@@ -20,18 +20,18 @@ object SplitExpressions extends Pass {
def split(e: Expression): Expression = e match {
case e: DoPrim => {
val name = namespace.newTemp
- v += DefNode(info(s), name, e)
- WRef(name, tpe(e), kind(e), gender(e))
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
}
case e: Mux => {
val name = namespace.newTemp
- v += DefNode(info(s), name, e)
- WRef(name, tpe(e), kind(e), gender(e))
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
}
case e: ValidIf => {
val name = namespace.newTemp
- v += DefNode(info(s), name, e)
- WRef(name, tpe(e), kind(e), gender(e))
+ v += DefNode(get_info(s), name, e)
+ WRef(name, e.tpe, kind(e), gender(e))
}
case e => e
}
diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala
index b1a20fdd..d034719a 100644
--- a/src/main/scala/firrtl/passes/Uniquify.scala
+++ b/src/main/scala/firrtl/passes/Uniquify.scala
@@ -109,8 +109,9 @@ object Uniquify extends Pass {
val newName = findValidPrefix(f.name, Seq(""), namespace)
namespace += newName
Field(newName, f.flip, f.tpe)
- } map { f =>
- if (f.tpe.isAggregate) {
+ } map { f => f.tpe match {
+ case _: GroundType => f
+ case _ =>
val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet())
val elts = enumerateNames(tpe)
// Need leading _ for findValidPrefix, it doesn't add _ for checks
@@ -123,8 +124,6 @@ object Uniquify extends Pass {
}
namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e)))
Field(prefix, f.flip, tpe)
- } else {
- f
}
}
BundleType(newFields)
@@ -349,7 +348,9 @@ object Uniquify extends Pass {
def uniquifyPorts(m: DefModule): DefModule = {
def uniquifyPorts(ports: Seq[Port]): Seq[Port] = {
- val portsType = BundleType(ports map (_.toField))
+ val portsType = BundleType(ports map {
+ case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe)
+ })
val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet())
val localMap = createNameMapping(portsType, uniquePortsType)
portNameMap += (m.name -> localMap)