aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ExpandConnects.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/ExpandConnects.scala')
-rw-r--r--src/main/scala/firrtl/passes/ExpandConnects.scala86
1 files changed, 86 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala
new file mode 100644
index 00000000..250c9ce0
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ExpandConnects.scala
@@ -0,0 +1,86 @@
+package firrtl.passes
+
+import firrtl.Utils.{create_exps, flow, get_field, get_valid_points, times, to_flip, to_flow}
+import firrtl.ir._
+import firrtl.options.{PreservesAll, Dependency}
+import firrtl.{DuplexFlow, Flow, SinkFlow, SourceFlow, Transform, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex}
+import firrtl.Mappers._
+
+object ExpandConnects extends Pass with PreservesAll[Transform] {
+
+ override val prerequisites =
+ Seq( Dependency(PullMuxes),
+ Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped
+
+ def run(c: Circuit): Circuit = {
+ def expand_connects(m: Module): Module = {
+ val flows = collection.mutable.LinkedHashMap[String,Flow]()
+ def expand_s(s: Statement): Statement = {
+ def set_flow(e: Expression): Expression = e map set_flow match {
+ case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name))
+ case ex: WSubField =>
+ val f = get_field(ex.expr.tpe, ex.name)
+ val flowx = times(flow(ex.expr), f.flip)
+ WSubField(ex.expr, ex.name, ex.tpe, flowx)
+ case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
+ case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr))
+ case ex => ex
+ }
+ s match {
+ case sx: DefWire => flows(sx.name) = DuplexFlow; sx
+ case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
+ case sx: WDefInstance => flows(sx.name) = SourceFlow; sx
+ case sx: DefMemory => flows(sx.name) = SourceFlow; sx
+ case sx: DefNode => flows(sx.name) = SourceFlow; sx
+ case sx: IsInvalid =>
+ val invalids = create_exps(sx.expr).flatMap { case expx =>
+ flow(set_flow(expx)) match {
+ case DuplexFlow => Some(IsInvalid(sx.info, expx))
+ case SinkFlow => Some(IsInvalid(sx.info, expx))
+ case _ => None
+ }
+ }
+ invalids.size match {
+ case 0 => EmptyStmt
+ case 1 => invalids.head
+ case _ => Block(invalids)
+ }
+ case sx: Connect =>
+ val locs = create_exps(sx.loc)
+ val exps = create_exps(sx.expr)
+ Block(locs.zip(exps).map { case (locx, expx) =>
+ to_flip(flow(locx)) match {
+ case Default => Connect(sx.info, locx, expx)
+ case Flip => Connect(sx.info, expx, locx)
+ }
+ })
+ case sx: PartialConnect =>
+ val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
+ val locs = create_exps(sx.loc)
+ val exps = create_exps(sx.expr)
+ val stmts = ls map { case (x, y) =>
+ locs(x).tpe match {
+ case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
+ case _ =>
+ to_flip(flow(locs(x))) match {
+ case Default => Connect(sx.info, locs(x), exps(y))
+ case Flip => Connect(sx.info, exps(y), locs(x))
+ }
+ }
+ }
+ Block(stmts)
+ case sx => sx map expand_s
+ }
+ }
+
+ m.ports.foreach { p => flows(p.name) = to_flow(p.direction) }
+ Module(m.info, m.name, m.ports, expand_s(m.body))
+ }
+
+ val modulesx = c.modules.map {
+ case (m: ExtModule) => m
+ case (m: Module) => expand_connects(m)
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
+}