aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorazidar2016-02-04 09:23:19 -0800
committerazidar2016-02-09 18:57:06 -0800
commitb32acb9a52a426087226284f4a1e2890cbdadc00 (patch)
treee1771c82f9e707d95b507e67455a1e7fbbffea6a /src
parentddeac42c426dbda9000eef1b74f8d5032c55f58f (diff)
Added Expand Whens pass
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/IR.scala3
-rw-r--r--src/main/scala/firrtl/Utils.scala45
-rw-r--r--src/main/scala/firrtl/WIR.scala35
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala211
-rw-r--r--src/main/stanza/ir-utils.stanza4
5 files changed, 283 insertions, 15 deletions
diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala
index 3656ef22..85870ab9 100644
--- a/src/main/scala/firrtl/IR.scala
+++ b/src/main/scala/firrtl/IR.scala
@@ -61,7 +61,7 @@ case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type)
case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression
case class UIntValue(value: BigInt, width: Width) extends Expression
case class SIntValue(value: BigInt, width: Width) extends Expression
-case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression
+case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression
trait Stmt extends AST
case class DefWire(info: Info, name: String, tpe: Type) extends Stmt
@@ -114,4 +114,3 @@ case class ExModule(info: Info, name: String, ports: Seq[Port]) extends Module
case class Circuit(info: Info, modules: Seq[Module], main: String) extends AST
-
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index be17c61e..406e393c 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -23,7 +23,12 @@ object Utils {
// Is there a more elegant way to do this?
private type FlagMap = Map[String, Boolean]
private val FlagMap = Map[String, Boolean]().withDefaultValue(false)
-
+ implicit class WithAs[T](x: T) {
+ import scala.reflect._
+ def as[O: ClassTag]: Option[O] = x match {
+ case o: O => Some(o)
+ case _ => None } }
+ implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x)
def ceil_log2(x: BigInt): BigInt = (x-1).bitLength
def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt
val gen_names = Map[String,Int]()
@@ -68,10 +73,13 @@ object Utils {
else if (e2 == zero) e1
else DoPrim(OR_OP,Seq(e1,e2),Seq(),UIntType(IntWidth(1)))
}
-
- def EQV (e1:Expression,e2:Expression) : Expression = {
- DoPrim(EQUAL_OP,Seq(e1,e2),Seq(),tpe(e1))
+ def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(EQUAL_OP,Seq(e1,e2),Seq(),tpe(e1)) }
+ def NOT (e1:Expression) : Expression = {
+ if (e1 == one) zero
+ else if (e1 == zero) one
+ else DoPrim(EQUAL_OP,Seq(e1,zero),Seq(),UIntType(IntWidth(1)))
}
+
//def MUX (p:Expression,e1:Expression,e2:Expression) : Expression = {
// Mux(p,e1,e2,mux_type(tpe(e1),tpe(e2)))
@@ -486,6 +494,35 @@ object Utils {
case s:DefInstance => UnknownType()
case _ => UnknownType()
}}
+ def get_name (s:Stmt) : String = {
+ s match {
+ case s:DefWire => s.name
+ case s:DefPoison => s.name
+ case s:DefRegister => s.name
+ case s:DefNode => s.name
+ case s:DefMemory => s.name
+ case s:DefInstance => s.name
+ case s:WDefInstance => s.name
+ case _ => error("Shouldn't be here"); "blah"
+ }}
+ def get_info (s:Stmt) : Info = {
+ s match {
+ case s:DefWire => s.info
+ case s:DefPoison => s.info
+ case s:DefRegister => s.info
+ case s:DefInstance => s.info
+ case s:WDefInstance => s.info
+ case s:DefMemory => s.info
+ case s:DefNode => s.info
+ case s:Conditionally => s.info
+ case s:BulkConnect => s.info
+ case s:Connect => s.info
+ case s:IsInvalid => s.info
+ case s:Stop => s.info
+ case s:Print => s.info
+ case _ => error("Shouldn't be here"); NoInfo
+ }}
+
// =============== MAPPERS ===================
def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt =
diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala
index 6fc57b8e..35fcb93a 100644
--- a/src/main/scala/firrtl/WIR.scala
+++ b/src/main/scala/firrtl/WIR.scala
@@ -27,9 +27,42 @@ case class WSubIndex(exp:Expression,value:Int,tpe:Type,gender:Gender) extends Ex
case class WSubAccess(exp:Expression,index:Expression,tpe:Type,gender:Gender) extends Expression
case class WVoid() extends Expression
case class WInvalid() extends Expression
-
case class WDefInstance(info:Info,name:String,module:String,tpe:Type) extends Stmt
+class WrappedExpression (val e1:Expression) {
+ override def equals (we:Any) = {
+ we match {
+ case (we:WrappedExpression) => {
+ (e1,we.e1) match {
+ case (e1:UIntValue,e2:UIntValue) => if (e1.value == e2.value) true else false
+ // TODO is this necessary? width(e1) == width(e2)
+ case (e1:SIntValue,e2:SIntValue) => if (e1.value == e2.value) true else false
+ // TODO is this necessary? width(e1) == width(e2)
+ case (e1:WRef,e2:WRef) => e1.name equals e2.name
+ case (e1:WSubField,e2:WSubField) => (e1.name equals e2.name) && (e1.exp == e2.exp)
+ case (e1:WSubIndex,e2:WSubIndex) => (e1.value == e2.value) && (e1.exp == e2.exp)
+ case (e1:WSubAccess,e2:WSubAccess) => (e1.index == e2.index) && (e1.exp == e2.exp)
+ case (e1:WVoid,e2:WVoid) => true
+ case (e1:WInvalid,e2:WInvalid) => true
+ case (e1:DoPrim,e2:DoPrim) => {
+ var are_equal = e1.op == e2.op
+ (e1.args,e2.args).zipped.foreach{ (x,y) => { if (x != y) are_equal = false }}
+ (e1.consts,e2.consts).zipped.foreach{ (x,y) => { if (x != y) are_equal = false }}
+ are_equal
+ }
+ case (e1:Mux,e2:Mux) => (e1.cond == e2.cond) && (e1.tval == e2.tval) && (e1.fval == e2.fval)
+ case (e1:ValidIf,e2:ValidIf) => (e1.cond == e2.cond) && (e1.value == e2.value)
+ case (e1,e2) => false
+ }
+ }
+ case _ => false
+ }
+ }
+ override def hashCode = e1.serialize().hashCode
+ override def toString = e1.serialize()
+}
+
+
case class VarWidth(name:String) extends Width
case class PlusWidth(arg1:Width,arg2:Width) extends Width
case class MinusWidth(arg1:Width,arg2:Width) extends Width
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 591f4c99..7cd4fdcf 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -31,7 +31,7 @@ trait StanzaPass extends LazyLogging {
val fromStanza = Files.createTempFile(Paths.get(""), n, ".fir")
Files.write(toStanza, c.serialize.getBytes)
- val cmd = Seq("firrtl-stanza", "-i", toStanza.toString, "-o", fromStanza.toString, "-b", "firrtl") ++
+ val cmd = Seq("firrtl-stanza", "-i", toStanza.toString, "-o", fromStanza.toString, "-b", "firrtl", "-p", "c") ++
stanzaPasses.flatMap(x=>Seq("-x", x))
logger.debug(cmd.mkString(" "))
val ret = cmd.!
@@ -44,12 +44,18 @@ trait StanzaPass extends LazyLogging {
}
object PassUtils extends LazyLogging {
- val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,ResolveGenders,PullMuxes,ExpandConnects,RemoveAccesses)
+ val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,ResolveGenders,PullMuxes,ExpandConnects,RemoveAccesses,ExpandWhens)
lazy val mapNameToPass: Map[String, Pass] = listOfPasses.map(p => p.name -> p).toMap
- def executePasses(c: Circuit, passes: Seq[Pass]): Circuit = {
+ def executePasses(c: Circuit, passes: Seq[Pass]): Circuit = {
if (passes.isEmpty) c
- else executePasses(passes.head.run(c), passes.tail)
+ else {
+ val p = passes.head
+ val name = p.name
+ logger.debug(c.serialize())
+ logger.debug(s"Starting ${name}")
+ executePasses(p.run(c), passes.tail)
+ }
}
}
@@ -623,8 +629,201 @@ object RemoveAccesses extends Pass {
}
object ExpandWhens extends Pass with StanzaPass {
- def name = "Expand Whens"
- def run (c:Circuit): Circuit = stanzaPass(c, "expand-whens")
+ def name = "Expand Whens"
+ var mname = ""
+// ; ========== Expand When Utilz ==========
+ def add (hash:HashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = {
+ hash += (key -> value)
+ }
+
+ def get_entries (hash:HashMap[WrappedExpression,Expression],exps:Seq[Expression]) : HashMap[WrappedExpression,Expression] = {
+ val hashx = HashMap[WrappedExpression,Expression]()
+ exps.foreach { e => {
+ val value = hash.get(e)
+ value match {
+ case (value:Some[Expression]) => add(hashx,e,value.get)
+ case (None) => {}
+ }
+ }}
+ hashx
+ }
+ def get_female_refs (n:String,t:Type,g:Gender) : Seq[Expression] = {
+ val exps = create_exps(WRef(n,t,ExpKind(),g))
+ val expsx = ArrayBuffer[Expression]()
+ def get_gender (t:Type, i:Int, g:Gender) : Gender = {
+ val f = get_flip(t,i,DEFAULT)
+ times(g, f)
+ }
+ for (i <- 0 until exps.size) {
+ get_gender(t,i,g) match {
+ case BIGENDER => expsx += exps(i)
+ case FEMALE => expsx += exps(i)
+ case _ => false
+ }
+ }
+ expsx
+ }
+
+ // ------------ Pass -------------------
+ def run (c:Circuit): Circuit = {
+ def void_all (m:InModule) : InModule = {
+ mname = m.name
+ def void_all_s (s:Stmt) : Stmt = {
+ (s) match {
+ case (_:DefWire|_:DefRegister|_:WDefInstance|_:DefMemory) => {
+ val voids = ArrayBuffer[Stmt]()
+ for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) {
+ voids += Connect(get_info(s),e,WVoid())
+ }
+ Begin(Seq(s,Begin(voids)))
+ }
+ case (s) => sMap(void_all_s _,s)
+ }
+ }
+ val voids = ArrayBuffer[Stmt]()
+ for (p <- m.ports) {
+ for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) {
+ voids += Connect(p.info,e,WVoid())
+ }
+ }
+ val bodyx = void_all_s(m.body)
+ voids += bodyx
+ InModule(m.info,m.name,m.ports,Begin(voids))
+ }
+ def expand_whens (m:InModule) : Tuple2[HashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = {
+ val simlist = ArrayBuffer[Stmt]()
+ mname = m.name
+ def expand_whens (netlist:HashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = {
+ (s) match {
+ case (s:Connect) => netlist(s.loc) = s.exp
+ case (s:IsInvalid) => netlist(s.exp) = WInvalid()
+ case (s:Conditionally) => {
+ val exps = ArrayBuffer[Expression]()
+ def prefetch (s:Stmt) : Stmt = {
+ (s) match {
+ case (s:Connect) => exps += s.loc; s
+ case (s) => sMap(prefetch _,s)
+ }
+ }
+ prefetch(s.conseq)
+ val c_netlist = get_entries(netlist,exps)
+ expand_whens(c_netlist,AND(p,s.pred))(s.conseq)
+ expand_whens(netlist,AND(p,NOT(s.pred)))(s.alt)
+ for (lvalue <- c_netlist.keys) {
+ val value = netlist.get(lvalue)
+ (value) match {
+ case (value:Some[Expression]) => {
+ val tv = c_netlist(lvalue)
+ val fv = value.get
+ val res = (tv,fv) match {
+ case (tv:WInvalid,fv:WInvalid) => WInvalid()
+ case (tv:WInvalid,fv) => ValidIf(NOT(s.pred),fv,tpe(fv))
+ case (tv,fv:WInvalid) => ValidIf(s.pred,tv,tpe(tv))
+ case (tv,fv) => Mux(s.pred,tv,fv,mux_type_and_widths(tv,fv))
+ }
+ netlist(lvalue) = res
+ }
+ case (None) => add(netlist,lvalue,c_netlist(lvalue))
+ }
+ }
+ }
+ case (s:Print) => {
+ if (p == one) {
+ simlist += s
+ } else {
+ simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en))
+ }
+ }
+ case (s:Stop) => {
+ if (p == one) {
+ simlist += s
+ } else {
+ simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en))
+ }
+ }
+ case (s) => sMap(expand_whens(netlist,p) _, s)
+ }
+ s
+ }
+ val netlist = HashMap[WrappedExpression,Expression]()
+ expand_whens(netlist,one)(m.body)
+
+ //println("Netlist:")
+ //println(netlist)
+ //println("Simlist:")
+ //println(simlist)
+ ( netlist, simlist )
+ }
+
+ def create_module (netlist:HashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = {
+ mname = m.name
+ val stmts = ArrayBuffer[Stmt]()
+ val connections = ArrayBuffer[Stmt]()
+ def replace_void (e:Expression)(rvalue:Expression) : Expression = {
+ (rvalue) match {
+ case (rv:WVoid) => e
+ case (rv) => eMap(replace_void(e) _,rv)
+ }
+ }
+ def create (s:Stmt) : Stmt = {
+ (s) match {
+ case (_:DefWire|_:WDefInstance|_:DefMemory) => {
+ stmts += s
+ for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) {
+ val rvalue = netlist(e)
+ val con = (rvalue) match {
+ case (rvalue:WInvalid) => IsInvalid(get_info(s),e)
+ case (rvalue) => Connect(get_info(s),e,rvalue)
+ }
+ connections += con
+ }
+ }
+ case (s:DefRegister) => {
+ stmts += s
+ for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) {
+ val rvalue = replace_void(e)(netlist(e))
+ val con = (rvalue) match {
+ case (rvalue:WInvalid) => IsInvalid(get_info(s),e)
+ case (rvalue) => Connect(get_info(s),e,rvalue)
+ }
+ connections += con
+ }
+ }
+ case (_:DefPoison|_:DefNode) => stmts += s
+ case (s) => sMap(create _,s)
+ }
+ s
+ }
+ create(m.body)
+ for (p <- m.ports) {
+ for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) {
+ val rvalue = netlist(e)
+ val con = (rvalue) match {
+ case (rvalue:WInvalid) => IsInvalid(p.info,e)
+ case (rvalue) => Connect(p.info,e,rvalue)
+ }
+ connections += con
+ }
+ }
+ for (x <- simlist) { stmts += x }
+ InModule(m.info,m.name,m.ports,Begin(Seq(Begin(stmts),Begin(connections))))
+ }
+
+ val voided_modules = c.modules.map{ m => {
+ (m) match {
+ case (m:ExModule) => m
+ case (m:InModule) => void_all(m)
+ } } }
+ val modulesx = voided_modules.map{ m => {
+ (m) match {
+ case (m:ExModule) => m
+ case (m:InModule) => {
+ val (netlist, simlist) = expand_whens(m)
+ create_module(netlist,simlist,m)
+ }
+ }}}
+ Circuit(c.info,modulesx,c.main)
+ }
}
object CheckInitialization extends Pass with StanzaPass {
diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza
index 3fc8b155..22ee228b 100644
--- a/src/main/stanza/ir-utils.stanza
+++ b/src/main/stanza/ir-utils.stanza
@@ -147,8 +147,8 @@ public defn children (e:Expression) -> List<Expression> :
public var mname : Symbol = `blah
public defn exp-hash (e:Expression) -> Int :
turn-off-debug(false)
- val i = symbol-hash(to-symbol(string-join(map(to-string,list(mname `.... e)))))
- ;val i = symbol-hash(to-symbol(to-string(e)))
+ ;val i = symbol-hash(to-symbol(string-join(map(to-string,list(mname `.... e)))))
+ val i = symbol-hash(to-symbol(to-string(e)))
turn-on-debug(false)
i