diff options
| author | azidar | 2016-02-04 09:23:19 -0800 |
|---|---|---|
| committer | azidar | 2016-02-09 18:57:06 -0800 |
| commit | b32acb9a52a426087226284f4a1e2890cbdadc00 (patch) | |
| tree | e1771c82f9e707d95b507e67455a1e7fbbffea6a /src | |
| parent | ddeac42c426dbda9000eef1b74f8d5032c55f58f (diff) | |
Added Expand Whens pass
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/IR.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 45 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 35 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 211 | ||||
| -rw-r--r-- | src/main/stanza/ir-utils.stanza | 4 |
5 files changed, 283 insertions, 15 deletions
diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala index 3656ef22..85870ab9 100644 --- a/src/main/scala/firrtl/IR.scala +++ b/src/main/scala/firrtl/IR.scala @@ -61,7 +61,7 @@ case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression case class UIntValue(value: BigInt, width: Width) extends Expression case class SIntValue(value: BigInt, width: Width) extends Expression -case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression +case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression trait Stmt extends AST case class DefWire(info: Info, name: String, tpe: Type) extends Stmt @@ -114,4 +114,3 @@ case class ExModule(info: Info, name: String, ports: Seq[Port]) extends Module case class Circuit(info: Info, modules: Seq[Module], main: String) extends AST - diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index be17c61e..406e393c 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -23,7 +23,12 @@ object Utils { // Is there a more elegant way to do this? private type FlagMap = Map[String, Boolean] private val FlagMap = Map[String, Boolean]().withDefaultValue(false) - + implicit class WithAs[T](x: T) { + import scala.reflect._ + def as[O: ClassTag]: Option[O] = x match { + case o: O => Some(o) + case _ => None } } + implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) def ceil_log2(x: BigInt): BigInt = (x-1).bitLength def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt val gen_names = Map[String,Int]() @@ -68,10 +73,13 @@ object Utils { else if (e2 == zero) e1 else DoPrim(OR_OP,Seq(e1,e2),Seq(),UIntType(IntWidth(1))) } - - def EQV (e1:Expression,e2:Expression) : Expression = { - DoPrim(EQUAL_OP,Seq(e1,e2),Seq(),tpe(e1)) + def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(EQUAL_OP,Seq(e1,e2),Seq(),tpe(e1)) } + def NOT (e1:Expression) : Expression = { + if (e1 == one) zero + else if (e1 == zero) one + else DoPrim(EQUAL_OP,Seq(e1,zero),Seq(),UIntType(IntWidth(1))) } + //def MUX (p:Expression,e1:Expression,e2:Expression) : Expression = { // Mux(p,e1,e2,mux_type(tpe(e1),tpe(e2))) @@ -486,6 +494,35 @@ object Utils { case s:DefInstance => UnknownType() case _ => UnknownType() }} + def get_name (s:Stmt) : String = { + s match { + case s:DefWire => s.name + case s:DefPoison => s.name + case s:DefRegister => s.name + case s:DefNode => s.name + case s:DefMemory => s.name + case s:DefInstance => s.name + case s:WDefInstance => s.name + case _ => error("Shouldn't be here"); "blah" + }} + def get_info (s:Stmt) : Info = { + s match { + case s:DefWire => s.info + case s:DefPoison => s.info + case s:DefRegister => s.info + case s:DefInstance => s.info + case s:WDefInstance => s.info + case s:DefMemory => s.info + case s:DefNode => s.info + case s:Conditionally => s.info + case s:BulkConnect => s.info + case s:Connect => s.info + case s:IsInvalid => s.info + case s:Stop => s.info + case s:Print => s.info + case _ => error("Shouldn't be here"); NoInfo + }} + // =============== MAPPERS =================== def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt = diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 6fc57b8e..35fcb93a 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -27,9 +27,42 @@ case class WSubIndex(exp:Expression,value:Int,tpe:Type,gender:Gender) extends Ex case class WSubAccess(exp:Expression,index:Expression,tpe:Type,gender:Gender) extends Expression case class WVoid() extends Expression case class WInvalid() extends Expression - case class WDefInstance(info:Info,name:String,module:String,tpe:Type) extends Stmt +class WrappedExpression (val e1:Expression) { + override def equals (we:Any) = { + we match { + case (we:WrappedExpression) => { + (e1,we.e1) match { + case (e1:UIntValue,e2:UIntValue) => if (e1.value == e2.value) true else false + // TODO is this necessary? width(e1) == width(e2) + case (e1:SIntValue,e2:SIntValue) => if (e1.value == e2.value) true else false + // TODO is this necessary? width(e1) == width(e2) + case (e1:WRef,e2:WRef) => e1.name equals e2.name + case (e1:WSubField,e2:WSubField) => (e1.name equals e2.name) && (e1.exp == e2.exp) + case (e1:WSubIndex,e2:WSubIndex) => (e1.value == e2.value) && (e1.exp == e2.exp) + case (e1:WSubAccess,e2:WSubAccess) => (e1.index == e2.index) && (e1.exp == e2.exp) + case (e1:WVoid,e2:WVoid) => true + case (e1:WInvalid,e2:WInvalid) => true + case (e1:DoPrim,e2:DoPrim) => { + var are_equal = e1.op == e2.op + (e1.args,e2.args).zipped.foreach{ (x,y) => { if (x != y) are_equal = false }} + (e1.consts,e2.consts).zipped.foreach{ (x,y) => { if (x != y) are_equal = false }} + are_equal + } + case (e1:Mux,e2:Mux) => (e1.cond == e2.cond) && (e1.tval == e2.tval) && (e1.fval == e2.fval) + case (e1:ValidIf,e2:ValidIf) => (e1.cond == e2.cond) && (e1.value == e2.value) + case (e1,e2) => false + } + } + case _ => false + } + } + override def hashCode = e1.serialize().hashCode + override def toString = e1.serialize() +} + + case class VarWidth(name:String) extends Width case class PlusWidth(arg1:Width,arg2:Width) extends Width case class MinusWidth(arg1:Width,arg2:Width) extends Width diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 591f4c99..7cd4fdcf 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -31,7 +31,7 @@ trait StanzaPass extends LazyLogging { val fromStanza = Files.createTempFile(Paths.get(""), n, ".fir") Files.write(toStanza, c.serialize.getBytes) - val cmd = Seq("firrtl-stanza", "-i", toStanza.toString, "-o", fromStanza.toString, "-b", "firrtl") ++ + val cmd = Seq("firrtl-stanza", "-i", toStanza.toString, "-o", fromStanza.toString, "-b", "firrtl", "-p", "c") ++ stanzaPasses.flatMap(x=>Seq("-x", x)) logger.debug(cmd.mkString(" ")) val ret = cmd.! @@ -44,12 +44,18 @@ trait StanzaPass extends LazyLogging { } object PassUtils extends LazyLogging { - val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,ResolveGenders,PullMuxes,ExpandConnects,RemoveAccesses) + val listOfPasses: Seq[Pass] = Seq(ToWorkingIR,ResolveKinds,ResolveGenders,PullMuxes,ExpandConnects,RemoveAccesses,ExpandWhens) lazy val mapNameToPass: Map[String, Pass] = listOfPasses.map(p => p.name -> p).toMap - def executePasses(c: Circuit, passes: Seq[Pass]): Circuit = { + def executePasses(c: Circuit, passes: Seq[Pass]): Circuit = { if (passes.isEmpty) c - else executePasses(passes.head.run(c), passes.tail) + else { + val p = passes.head + val name = p.name + logger.debug(c.serialize()) + logger.debug(s"Starting ${name}") + executePasses(p.run(c), passes.tail) + } } } @@ -623,8 +629,201 @@ object RemoveAccesses extends Pass { } object ExpandWhens extends Pass with StanzaPass { - def name = "Expand Whens" - def run (c:Circuit): Circuit = stanzaPass(c, "expand-whens") + def name = "Expand Whens" + var mname = "" +// ; ========== Expand When Utilz ========== + def add (hash:HashMap[WrappedExpression,Expression],key:WrappedExpression,value:Expression) = { + hash += (key -> value) + } + + def get_entries (hash:HashMap[WrappedExpression,Expression],exps:Seq[Expression]) : HashMap[WrappedExpression,Expression] = { + val hashx = HashMap[WrappedExpression,Expression]() + exps.foreach { e => { + val value = hash.get(e) + value match { + case (value:Some[Expression]) => add(hashx,e,value.get) + case (None) => {} + } + }} + hashx + } + def get_female_refs (n:String,t:Type,g:Gender) : Seq[Expression] = { + val exps = create_exps(WRef(n,t,ExpKind(),g)) + val expsx = ArrayBuffer[Expression]() + def get_gender (t:Type, i:Int, g:Gender) : Gender = { + val f = get_flip(t,i,DEFAULT) + times(g, f) + } + for (i <- 0 until exps.size) { + get_gender(t,i,g) match { + case BIGENDER => expsx += exps(i) + case FEMALE => expsx += exps(i) + case _ => false + } + } + expsx + } + + // ------------ Pass ------------------- + def run (c:Circuit): Circuit = { + def void_all (m:InModule) : InModule = { + mname = m.name + def void_all_s (s:Stmt) : Stmt = { + (s) match { + case (_:DefWire|_:DefRegister|_:WDefInstance|_:DefMemory) => { + val voids = ArrayBuffer[Stmt]() + for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) { + voids += Connect(get_info(s),e,WVoid()) + } + Begin(Seq(s,Begin(voids))) + } + case (s) => sMap(void_all_s _,s) + } + } + val voids = ArrayBuffer[Stmt]() + for (p <- m.ports) { + for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) { + voids += Connect(p.info,e,WVoid()) + } + } + val bodyx = void_all_s(m.body) + voids += bodyx + InModule(m.info,m.name,m.ports,Begin(voids)) + } + def expand_whens (m:InModule) : Tuple2[HashMap[WrappedExpression,Expression],ArrayBuffer[Stmt]] = { + val simlist = ArrayBuffer[Stmt]() + mname = m.name + def expand_whens (netlist:HashMap[WrappedExpression,Expression],p:Expression)(s:Stmt) : Stmt = { + (s) match { + case (s:Connect) => netlist(s.loc) = s.exp + case (s:IsInvalid) => netlist(s.exp) = WInvalid() + case (s:Conditionally) => { + val exps = ArrayBuffer[Expression]() + def prefetch (s:Stmt) : Stmt = { + (s) match { + case (s:Connect) => exps += s.loc; s + case (s) => sMap(prefetch _,s) + } + } + prefetch(s.conseq) + val c_netlist = get_entries(netlist,exps) + expand_whens(c_netlist,AND(p,s.pred))(s.conseq) + expand_whens(netlist,AND(p,NOT(s.pred)))(s.alt) + for (lvalue <- c_netlist.keys) { + val value = netlist.get(lvalue) + (value) match { + case (value:Some[Expression]) => { + val tv = c_netlist(lvalue) + val fv = value.get + val res = (tv,fv) match { + case (tv:WInvalid,fv:WInvalid) => WInvalid() + case (tv:WInvalid,fv) => ValidIf(NOT(s.pred),fv,tpe(fv)) + case (tv,fv:WInvalid) => ValidIf(s.pred,tv,tpe(tv)) + case (tv,fv) => Mux(s.pred,tv,fv,mux_type_and_widths(tv,fv)) + } + netlist(lvalue) = res + } + case (None) => add(netlist,lvalue,c_netlist(lvalue)) + } + } + } + case (s:Print) => { + if (p == one) { + simlist += s + } else { + simlist += Print(s.info,s.string,s.args,s.clk,AND(p,s.en)) + } + } + case (s:Stop) => { + if (p == one) { + simlist += s + } else { + simlist += Stop(s.info,s.ret,s.clk,AND(p,s.en)) + } + } + case (s) => sMap(expand_whens(netlist,p) _, s) + } + s + } + val netlist = HashMap[WrappedExpression,Expression]() + expand_whens(netlist,one)(m.body) + + //println("Netlist:") + //println(netlist) + //println("Simlist:") + //println(simlist) + ( netlist, simlist ) + } + + def create_module (netlist:HashMap[WrappedExpression,Expression],simlist:ArrayBuffer[Stmt],m:InModule) : InModule = { + mname = m.name + val stmts = ArrayBuffer[Stmt]() + val connections = ArrayBuffer[Stmt]() + def replace_void (e:Expression)(rvalue:Expression) : Expression = { + (rvalue) match { + case (rv:WVoid) => e + case (rv) => eMap(replace_void(e) _,rv) + } + } + def create (s:Stmt) : Stmt = { + (s) match { + case (_:DefWire|_:WDefInstance|_:DefMemory) => { + stmts += s + for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) { + val rvalue = netlist(e) + val con = (rvalue) match { + case (rvalue:WInvalid) => IsInvalid(get_info(s),e) + case (rvalue) => Connect(get_info(s),e,rvalue) + } + connections += con + } + } + case (s:DefRegister) => { + stmts += s + for (e <- get_female_refs(get_name(s),get_type(s),get_gender(s))) { + val rvalue = replace_void(e)(netlist(e)) + val con = (rvalue) match { + case (rvalue:WInvalid) => IsInvalid(get_info(s),e) + case (rvalue) => Connect(get_info(s),e,rvalue) + } + connections += con + } + } + case (_:DefPoison|_:DefNode) => stmts += s + case (s) => sMap(create _,s) + } + s + } + create(m.body) + for (p <- m.ports) { + for (e <- get_female_refs(p.name,p.tpe,get_gender(p))) { + val rvalue = netlist(e) + val con = (rvalue) match { + case (rvalue:WInvalid) => IsInvalid(p.info,e) + case (rvalue) => Connect(p.info,e,rvalue) + } + connections += con + } + } + for (x <- simlist) { stmts += x } + InModule(m.info,m.name,m.ports,Begin(Seq(Begin(stmts),Begin(connections)))) + } + + val voided_modules = c.modules.map{ m => { + (m) match { + case (m:ExModule) => m + case (m:InModule) => void_all(m) + } } } + val modulesx = voided_modules.map{ m => { + (m) match { + case (m:ExModule) => m + case (m:InModule) => { + val (netlist, simlist) = expand_whens(m) + create_module(netlist,simlist,m) + } + }}} + Circuit(c.info,modulesx,c.main) + } } object CheckInitialization extends Pass with StanzaPass { diff --git a/src/main/stanza/ir-utils.stanza b/src/main/stanza/ir-utils.stanza index 3fc8b155..22ee228b 100644 --- a/src/main/stanza/ir-utils.stanza +++ b/src/main/stanza/ir-utils.stanza @@ -147,8 +147,8 @@ public defn children (e:Expression) -> List<Expression> : public var mname : Symbol = `blah public defn exp-hash (e:Expression) -> Int : turn-off-debug(false) - val i = symbol-hash(to-symbol(string-join(map(to-string,list(mname `.... e))))) - ;val i = symbol-hash(to-symbol(to-string(e))) + ;val i = symbol-hash(to-symbol(string-join(map(to-string,list(mname `.... e))))) + val i = symbol-hash(to-symbol(to-string(e))) turn-on-debug(false) i |
