aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdam Izraelevitz2020-04-10 12:47:33 -0700
committerGitHub2020-04-10 19:47:33 +0000
commit54ff9451f285cc18bca0ab519e013ff8326538b8 (patch)
treefff7384be0131e52249b245fe78e18f413fc0c61
parent632930723adc2f78d4d6445acf5f8bcc250a6c0c (diff)
Split Passes.scala into separate files (#1496)
* Split Passes.scala into separate files * Add imports of implicit things Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
-rw-r--r--src/main/scala/firrtl/passes/ExpandConnects.scala86
-rw-r--r--src/main/scala/firrtl/passes/Legalize.scala89
-rw-r--r--src/main/scala/firrtl/passes/Pass.scala39
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala367
-rw-r--r--src/main/scala/firrtl/passes/PullMuxes.scala47
-rw-r--r--src/main/scala/firrtl/passes/ToWorkingIR.scala28
-rw-r--r--src/main/scala/firrtl/passes/VerilogPrep.scala106
7 files changed, 395 insertions, 367 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala
new file mode 100644
index 00000000..250c9ce0
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ExpandConnects.scala
@@ -0,0 +1,86 @@
+package firrtl.passes
+
+import firrtl.Utils.{create_exps, flow, get_field, get_valid_points, times, to_flip, to_flow}
+import firrtl.ir._
+import firrtl.options.{PreservesAll, Dependency}
+import firrtl.{DuplexFlow, Flow, SinkFlow, SourceFlow, Transform, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex}
+import firrtl.Mappers._
+
+object ExpandConnects extends Pass with PreservesAll[Transform] {
+
+ override val prerequisites =
+ Seq( Dependency(PullMuxes),
+ Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped
+
+ def run(c: Circuit): Circuit = {
+ def expand_connects(m: Module): Module = {
+ val flows = collection.mutable.LinkedHashMap[String,Flow]()
+ def expand_s(s: Statement): Statement = {
+ def set_flow(e: Expression): Expression = e map set_flow match {
+ case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name))
+ case ex: WSubField =>
+ val f = get_field(ex.expr.tpe, ex.name)
+ val flowx = times(flow(ex.expr), f.flip)
+ WSubField(ex.expr, ex.name, ex.tpe, flowx)
+ case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
+ case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr))
+ case ex => ex
+ }
+ s match {
+ case sx: DefWire => flows(sx.name) = DuplexFlow; sx
+ case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
+ case sx: WDefInstance => flows(sx.name) = SourceFlow; sx
+ case sx: DefMemory => flows(sx.name) = SourceFlow; sx
+ case sx: DefNode => flows(sx.name) = SourceFlow; sx
+ case sx: IsInvalid =>
+ val invalids = create_exps(sx.expr).flatMap { case expx =>
+ flow(set_flow(expx)) match {
+ case DuplexFlow => Some(IsInvalid(sx.info, expx))
+ case SinkFlow => Some(IsInvalid(sx.info, expx))
+ case _ => None
+ }
+ }
+ invalids.size match {
+ case 0 => EmptyStmt
+ case 1 => invalids.head
+ case _ => Block(invalids)
+ }
+ case sx: Connect =>
+ val locs = create_exps(sx.loc)
+ val exps = create_exps(sx.expr)
+ Block(locs.zip(exps).map { case (locx, expx) =>
+ to_flip(flow(locx)) match {
+ case Default => Connect(sx.info, locx, expx)
+ case Flip => Connect(sx.info, expx, locx)
+ }
+ })
+ case sx: PartialConnect =>
+ val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
+ val locs = create_exps(sx.loc)
+ val exps = create_exps(sx.expr)
+ val stmts = ls map { case (x, y) =>
+ locs(x).tpe match {
+ case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
+ case _ =>
+ to_flip(flow(locs(x))) match {
+ case Default => Connect(sx.info, locs(x), exps(y))
+ case Flip => Connect(sx.info, exps(y), locs(x))
+ }
+ }
+ }
+ Block(stmts)
+ case sx => sx map expand_s
+ }
+ }
+
+ m.ports.foreach { p => flows(p.name) = to_flow(p.direction) }
+ Module(m.info, m.name, m.ports, expand_s(m.body))
+ }
+
+ val modulesx = c.modules.map {
+ case (m: ExtModule) => m
+ case (m: Module) => expand_connects(m)
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala
new file mode 100644
index 00000000..37556769
--- /dev/null
+++ b/src/main/scala/firrtl/passes/Legalize.scala
@@ -0,0 +1,89 @@
+package firrtl.passes
+
+import firrtl.PrimOps._
+import firrtl.Utils.{BoolType, error, zero}
+import firrtl.ir._
+import firrtl.options.{PreservesAll, Dependency}
+import firrtl.transforms.ConstantPropagation
+import firrtl.{Transform, bitWidth}
+import firrtl.Mappers._
+
+// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
+// TODO replace UInt with zero-width wire instead
+object Legalize extends Pass with PreservesAll[Transform] {
+
+ override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
+
+ override val optionalPrerequisites = Seq.empty
+
+ override val dependents = Seq.empty
+
+ private def legalizeShiftRight(e: DoPrim): Expression = {
+ require(e.op == Shr)
+ e.args.head match {
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
+ case _ =>
+ val amount = e.consts.head.toInt
+ val width = bitWidth(e.args.head.tpe)
+ lazy val msb = width - 1
+ if (amount >= width) {
+ e.tpe match {
+ case UIntType(_) => zero
+ case SIntType(_) =>
+ val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
+ DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
+ case t => error(s"Unsupported type $t for Primop Shift Right")
+ }
+ } else {
+ e
+ }
+ }
+ }
+ private def legalizeBitExtract(expr: DoPrim): Expression = {
+ expr.args.head match {
+ case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
+ case _ => expr
+ }
+ }
+ private def legalizePad(expr: DoPrim): Expression = expr.args.head match {
+ case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
+ UIntLiteral(value, IntWidth(expr.consts.head))
+ case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
+ SIntLiteral(value, IntWidth(expr.consts.head))
+ case _ => expr
+ }
+ private def legalizeConnect(c: Connect): Statement = {
+ val t = c.loc.tpe
+ val w = bitWidth(t)
+ if (w >= bitWidth(c.expr.tpe)) {
+ c
+ } else {
+ val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
+ val expr = t match {
+ case UIntType(_) => bits
+ case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
+ case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
+ }
+ Connect(c.info, c.loc, expr)
+ }
+ }
+ def run (c: Circuit): Circuit = {
+ def legalizeE(expr: Expression): Expression = expr map legalizeE match {
+ case prim: DoPrim => prim.op match {
+ case Shr => legalizeShiftRight(prim)
+ case Pad => legalizePad(prim)
+ case Bits | Head | Tail => legalizeBitExtract(prim)
+ case _ => prim
+ }
+ case e => e // respect pre-order traversal
+ }
+ def legalizeS (s: Statement): Statement = {
+ val legalizedStmt = s match {
+ case c: Connect => legalizeConnect(c)
+ case _ => s
+ }
+ legalizedStmt map legalizeS map legalizeE
+ }
+ c copy (modules = c.modules map (_ map legalizeS))
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Pass.scala b/src/main/scala/firrtl/passes/Pass.scala
new file mode 100644
index 00000000..4673a8e1
--- /dev/null
+++ b/src/main/scala/firrtl/passes/Pass.scala
@@ -0,0 +1,39 @@
+package firrtl.passes
+
+import firrtl.Utils.error
+import firrtl.ir.Circuit
+import firrtl.{CircuitForm, CircuitState, FirrtlUserException, Transform, UnknownForm}
+
+/** [[Pass]] is simple transform that is generally part of a larger [[Transform]]
+ * Has an [[UnknownForm]], because larger [[Transform]] should specify form
+ */
+trait Pass extends Transform {
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
+ def run(c: Circuit): Circuit
+ def execute(state: CircuitState): CircuitState = {
+ val result = (state.form, inputForm) match {
+ case (_, UnknownForm) => run(state.circuit)
+ case (UnknownForm, _) => run(state.circuit)
+ case (x, y) if x > y =>
+ error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
+ case _ => run(state.circuit)
+ }
+ CircuitState(result, outputForm, state.annotations, state.renames)
+ }
+}
+
+// Error handling
+class PassException(message: String) extends FirrtlUserException(message)
+class PassExceptions(val exceptions: Seq[PassException]) extends FirrtlUserException("\n" + exceptions.mkString("\n"))
+class Errors {
+ val errors = collection.mutable.ArrayBuffer[PassException]()
+ def append(pe: PassException) = errors.append(pe)
+ def trigger() = errors.size match {
+ case 0 =>
+ case 1 => throw errors.head
+ case _ =>
+ append(new PassException(s"${errors.length} errors detected!"))
+ throw new PassExceptions(errors)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
deleted file mode 100644
index a8d37758..00000000
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ /dev/null
@@ -1,367 +0,0 @@
-// See LICENSE for license details.
-
-package firrtl.passes
-
-import firrtl._
-import firrtl.ir._
-import firrtl.Utils._
-import firrtl.Mappers._
-import firrtl.PrimOps._
-import firrtl.options.{Dependency, PreservesAll}
-import firrtl.transforms.ConstantPropagation
-
-import scala.collection.mutable
-
-/** [[Pass]] is simple transform that is generally part of a larger [[Transform]]
- * Has an [[UnknownForm]], because larger [[Transform]] should specify form
- */
-trait Pass extends Transform {
- def inputForm: CircuitForm = UnknownForm
- def outputForm: CircuitForm = UnknownForm
- def run(c: Circuit): Circuit
- def execute(state: CircuitState): CircuitState = {
- val result = (state.form, inputForm) match {
- case (_, UnknownForm) => run(state.circuit)
- case (UnknownForm, _) => run(state.circuit)
- case (x, y) if x > y =>
- error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}")
- case _ => run(state.circuit)
- }
- CircuitState(result, outputForm, state.annotations, state.renames)
- }
-}
-
-// Error handling
-class PassException(message: String) extends FirrtlUserException(message)
-class PassExceptions(val exceptions: Seq[PassException]) extends FirrtlUserException("\n" + exceptions.mkString("\n"))
-class Errors {
- val errors = collection.mutable.ArrayBuffer[PassException]()
- def append(pe: PassException) = errors.append(pe)
- def trigger() = errors.size match {
- case 0 =>
- case 1 => throw errors.head
- case _ =>
- append(new PassException(s"${errors.length} errors detected!"))
- throw new PassExceptions(errors)
- }
-}
-
-// These should be distributed into separate files
-object ToWorkingIR extends Pass with PreservesAll[Transform] {
-
- override val prerequisites = firrtl.stage.Forms.MinimalHighForm
-
- def toExp(e: Expression): Expression = e map toExp match {
- case ex: Reference => WRef(ex.name, ex.tpe, UnknownKind, UnknownFlow)
- case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UnknownFlow)
- case ex: SubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, UnknownFlow)
- case ex: SubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, UnknownFlow)
- case ex => ex // This might look like a case to use case _ => e, DO NOT!
- }
-
- def toStmt(s: Statement): Statement = s map toExp match {
- case sx: DefInstance => WDefInstance(sx.info, sx.name, sx.module, UnknownType)
- case sx => sx map toStmt
- }
-
- def run (c:Circuit): Circuit =
- c copy (modules = c.modules map (_ map toStmt))
-}
-
-object PullMuxes extends Pass with PreservesAll[Transform] {
-
- override val prerequisites = firrtl.stage.Forms.Deduped
-
- def run(c: Circuit): Circuit = {
- def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match {
- case ex: WSubField => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubField(exx.tval, ex.name, ex.tpe, ex.flow),
- WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubIndex => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow),
- WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex: WSubAccess => ex.expr match {
- case exx: Mux => Mux(exx.cond,
- WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow),
- WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe)
- case exx: ValidIf => ValidIf(exx.cond,
- WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe)
- case _ => ex // case exx => exx causes failed tests
- }
- case ex => ex
- }
- def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e
- val modulesx = c.modules.map {
- case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
- case (m:ExtModule) => m
- }
- Circuit(c.info, modulesx, c.main)
- }
-}
-
-object ExpandConnects extends Pass with PreservesAll[Transform] {
-
- override val prerequisites =
- Seq( Dependency(PullMuxes),
- Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped
-
- def run(c: Circuit): Circuit = {
- def expand_connects(m: Module): Module = {
- val flows = collection.mutable.LinkedHashMap[String,Flow]()
- def expand_s(s: Statement): Statement = {
- def set_flow(e: Expression): Expression = e map set_flow match {
- case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name))
- case ex: WSubField =>
- val f = get_field(ex.expr.tpe, ex.name)
- val flowx = times(flow(ex.expr), f.flip)
- WSubField(ex.expr, ex.name, ex.tpe, flowx)
- case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr))
- case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr))
- case ex => ex
- }
- s match {
- case sx: DefWire => flows(sx.name) = DuplexFlow; sx
- case sx: DefRegister => flows(sx.name) = DuplexFlow; sx
- case sx: WDefInstance => flows(sx.name) = SourceFlow; sx
- case sx: DefMemory => flows(sx.name) = SourceFlow; sx
- case sx: DefNode => flows(sx.name) = SourceFlow; sx
- case sx: IsInvalid =>
- val invalids = create_exps(sx.expr).flatMap { case expx =>
- flow(set_flow(expx)) match {
- case DuplexFlow => Some(IsInvalid(sx.info, expx))
- case SinkFlow => Some(IsInvalid(sx.info, expx))
- case _ => None
- }
- }
- invalids.size match {
- case 0 => EmptyStmt
- case 1 => invalids.head
- case _ => Block(invalids)
- }
- case sx: Connect =>
- val locs = create_exps(sx.loc)
- val exps = create_exps(sx.expr)
- Block(locs.zip(exps).map { case (locx, expx) =>
- to_flip(flow(locx)) match {
- case Default => Connect(sx.info, locx, expx)
- case Flip => Connect(sx.info, expx, locx)
- }
- })
- case sx: PartialConnect =>
- val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default)
- val locs = create_exps(sx.loc)
- val exps = create_exps(sx.expr)
- val stmts = ls map { case (x, y) =>
- locs(x).tpe match {
- case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y)))
- case _ =>
- to_flip(flow(locs(x))) match {
- case Default => Connect(sx.info, locs(x), exps(y))
- case Flip => Connect(sx.info, exps(y), locs(x))
- }
- }
- }
- Block(stmts)
- case sx => sx map expand_s
- }
- }
-
- m.ports.foreach { p => flows(p.name) = to_flow(p.direction) }
- Module(m.info, m.name, m.ports, expand_s(m.body))
- }
-
- val modulesx = c.modules.map {
- case (m: ExtModule) => m
- case (m: Module) => expand_connects(m)
- }
- Circuit(c.info, modulesx, c.main)
- }
-}
-
-
-// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts
-// TODO replace UInt with zero-width wire instead
-object Legalize extends Pass with PreservesAll[Transform] {
-
- override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes)
-
- override val optionalPrerequisites = Seq.empty
-
- override val dependents = Seq.empty
-
- private def legalizeShiftRight(e: DoPrim): Expression = {
- require(e.op == Shr)
- e.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e)
- case _ =>
- val amount = e.consts.head.toInt
- val width = bitWidth(e.args.head.tpe)
- lazy val msb = width - 1
- if (amount >= width) {
- e.tpe match {
- case UIntType(_) => zero
- case SIntType(_) =>
- val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType)
- DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1)))
- case t => error(s"Unsupported type $t for Primop Shift Right")
- }
- } else {
- e
- }
- }
- }
- private def legalizeBitExtract(expr: DoPrim): Expression = {
- expr.args.head match {
- case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr)
- case _ => expr
- }
- }
- private def legalizePad(expr: DoPrim): Expression = expr.args.head match {
- case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- UIntLiteral(value, IntWidth(expr.consts.head))
- case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head =>
- SIntLiteral(value, IntWidth(expr.consts.head))
- case _ => expr
- }
- private def legalizeConnect(c: Connect): Statement = {
- val t = c.loc.tpe
- val w = bitWidth(t)
- if (w >= bitWidth(c.expr.tpe)) {
- c
- } else {
- val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w)))
- val expr = t match {
- case UIntType(_) => bits
- case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
- case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t)
- }
- Connect(c.info, c.loc, expr)
- }
- }
- def run (c: Circuit): Circuit = {
- def legalizeE(expr: Expression): Expression = expr map legalizeE match {
- case prim: DoPrim => prim.op match {
- case Shr => legalizeShiftRight(prim)
- case Pad => legalizePad(prim)
- case Bits | Head | Tail => legalizeBitExtract(prim)
- case _ => prim
- }
- case e => e // respect pre-order traversal
- }
- def legalizeS (s: Statement): Statement = {
- val legalizedStmt = s match {
- case c: Connect => legalizeConnect(c)
- case _ => s
- }
- legalizedStmt map legalizeS map legalizeE
- }
- c copy (modules = c.modules map (_ map legalizeS))
- }
-}
-
-/** Makes changes to the Firrtl AST to make Verilog emission easier
- *
- * - For each instance, adds wires to connect to each port
- * - Note that no Namespace is required because Uniquify ensures that there will be no
- * collisions with the lowered names of instance ports
- * - Also removes Attaches where a single Port OR Wire connects to 1 or more instance ports
- * - These are expressed in the portCons of WDefInstConnectors
- *
- * @note The result of this pass is NOT legal Firrtl
- */
-object VerilogPrep extends Pass with PreservesAll[Transform] {
-
- override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
- Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
- Dependency[firrtl.transforms.FixAddingNegativeLiterals],
- Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
- Dependency[firrtl.transforms.InlineBitExtractionsTransform],
- Dependency[firrtl.transforms.InlineCastsTransform],
- Dependency[firrtl.transforms.LegalizeClocksTransform],
- Dependency[firrtl.transforms.FlattenRegUpdate],
- Dependency(passes.VerilogModulusCleanup),
- Dependency[firrtl.transforms.VerilogRename] )
-
- override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
-
- override val dependents = Seq.empty
-
- type AttachSourceMap = Map[WrappedExpression, Expression]
-
- // Finds attaches with only a single source (Port or Wire)
- // - Creates a map of attached expressions to their source
- // - Removes the Attach
- private def collectAndRemoveAttach(m: DefModule): (DefModule, AttachSourceMap) = {
- val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression]
- lazy val namespace = Namespace(m)
-
- def onStmt(stmt: Statement): Statement = stmt map onStmt match {
- case attach: Attach =>
- val wires = attach.exprs groupBy kind
- val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty)
- val instPorts = wires.getOrElse(InstanceKind, Seq.empty)
- // Sanity check (Should be caught by CheckTypes)
- assert(sources.size + instPorts.size == attach.exprs.size)
-
- sources match {
- case Seq() => // Zero sources, can add a wire to connect and remove
- val name = namespace.newTemp
- val wire = DefWire(NoInfo, name, instPorts.head.tpe)
- val ref = WRef(wire)
- for (inst <- instPorts) sourceMap(inst) = ref
- wire // Replace the attach with new source wire definition
- case Seq(source) => // One source can be removed
- assert(!sourceMap.contains(source)) // should have been merged
- for (inst <- instPorts) sourceMap(inst) = source
- EmptyStmt
- case moreThanOne =>
- attach
- }
- case s => s
- }
-
- (m map onStmt, sourceMap.toMap)
- }
-
- def run(c: Circuit): Circuit = {
- def lowerE(e: Expression): Expression = e match {
- case (_: WRef | _: WSubField) if kind(e) == InstanceKind =>
- WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e))
- case _ => e map lowerE
- }
-
- def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match {
- case WDefInstance(info, name, module, tpe) =>
- val portRefs = create_exps(WRef(name, tpe, ExpKind, SourceFlow))
- val (portCons, wires) = portRefs.map { p =>
- attachMap.get(p) match {
- // If it has a source in attachMap use that
- case Some(ref) => (p -> ref, None)
- // If no source, create a wire corresponding to the port and connect it up
- case None =>
- val wire = DefWire(info, LowerTypes.loweredName(p), p.tpe)
- (p -> WRef(wire), Some(wire))
- }
- }.unzip
- val newInst = WDefInstanceConnector(info, name, module, tpe, portCons)
- Block(wires.flatten :+ newInst)
- case other => other map lowerS(attachMap) map lowerE
- }
-
- val modulesx = c.modules map { mod =>
- val (modx, attachMap) = collectAndRemoveAttach(mod)
- modx map lowerS(attachMap)
- }
- c.copy(modules = modulesx)
- }
-}
diff --git a/src/main/scala/firrtl/passes/PullMuxes.scala b/src/main/scala/firrtl/passes/PullMuxes.scala
new file mode 100644
index 00000000..8befd9fa
--- /dev/null
+++ b/src/main/scala/firrtl/passes/PullMuxes.scala
@@ -0,0 +1,47 @@
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.options.PreservesAll
+import firrtl.{Transform, WSubAccess, WSubField, WSubIndex}
+
+object PullMuxes extends Pass with PreservesAll[Transform] {
+
+ override val prerequisites = firrtl.stage.Forms.Deduped
+
+ def run(c: Circuit): Circuit = {
+ def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match {
+ case ex: WSubField => ex.expr match {
+ case exx: Mux => Mux(exx.cond,
+ WSubField(exx.tval, ex.name, ex.tpe, ex.flow),
+ WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe)
+ case exx: ValidIf => ValidIf(exx.cond,
+ WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubIndex => ex.expr match {
+ case exx: Mux => Mux(exx.cond,
+ WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow),
+ WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe)
+ case exx: ValidIf => ValidIf(exx.cond,
+ WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex: WSubAccess => ex.expr match {
+ case exx: Mux => Mux(exx.cond,
+ WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow),
+ WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe)
+ case exx: ValidIf => ValidIf(exx.cond,
+ WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe)
+ case _ => ex // case exx => exx causes failed tests
+ }
+ case ex => ex
+ }
+ def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e
+ val modulesx = c.modules.map {
+ case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body))
+ case (m:ExtModule) => m
+ }
+ Circuit(c.info, modulesx, c.main)
+ }
+}
diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala
new file mode 100644
index 00000000..109654ee
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala
@@ -0,0 +1,28 @@
+package firrtl.passes
+
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.options.{PreservesAll}
+import firrtl.{Transform, UnknownFlow, UnknownKind, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex}
+
+// These should be distributed into separate files
+object ToWorkingIR extends Pass with PreservesAll[Transform] {
+
+ override val prerequisites = firrtl.stage.Forms.MinimalHighForm
+
+ def toExp(e: Expression): Expression = e map toExp match {
+ case ex: Reference => WRef(ex.name, ex.tpe, UnknownKind, UnknownFlow)
+ case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UnknownFlow)
+ case ex: SubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, UnknownFlow)
+ case ex: SubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, UnknownFlow)
+ case ex => ex // This might look like a case to use case _ => e, DO NOT!
+ }
+
+ def toStmt(s: Statement): Statement = s map toExp match {
+ case sx: DefInstance => WDefInstance(sx.info, sx.name, sx.module, UnknownType)
+ case sx => sx map toStmt
+ }
+
+ def run (c:Circuit): Circuit =
+ c copy (modules = c.modules map (_ map toStmt))
+}
diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala
new file mode 100644
index 00000000..776c0f5f
--- /dev/null
+++ b/src/main/scala/firrtl/passes/VerilogPrep.scala
@@ -0,0 +1,106 @@
+package firrtl.passes
+
+import firrtl.Utils.{create_exps, flow, kind, toWrappedExpression}
+import firrtl.ir._
+import firrtl.Mappers._
+import firrtl.options.{Dependency, PreservesAll}
+import firrtl._
+
+import scala.collection.mutable
+
+/** Makes changes to the Firrtl AST to make Verilog emission easier
+ *
+ * - For each instance, adds wires to connect to each port
+ * - Note that no Namespace is required because Uniquify ensures that there will be no
+ * collisions with the lowered names of instance ports
+ * - Also removes Attaches where a single Port OR Wire connects to 1 or more instance ports
+ * - These are expressed in the portCons of WDefInstConnectors
+ *
+ * @note The result of this pass is NOT legal Firrtl
+ */
+object VerilogPrep extends Pass with PreservesAll[Transform] {
+
+ override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++
+ Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper],
+ Dependency[firrtl.transforms.FixAddingNegativeLiterals],
+ Dependency[firrtl.transforms.ReplaceTruncatingArithmetic],
+ Dependency[firrtl.transforms.InlineBitExtractionsTransform],
+ Dependency[firrtl.transforms.InlineCastsTransform],
+ Dependency[firrtl.transforms.LegalizeClocksTransform],
+ Dependency[firrtl.transforms.FlattenRegUpdate],
+ Dependency(passes.VerilogModulusCleanup),
+ Dependency[firrtl.transforms.VerilogRename] )
+
+ override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized
+
+ override val dependents = Seq.empty
+
+ type AttachSourceMap = Map[WrappedExpression, Expression]
+
+ // Finds attaches with only a single source (Port or Wire)
+ // - Creates a map of attached expressions to their source
+ // - Removes the Attach
+ private def collectAndRemoveAttach(m: DefModule): (DefModule, AttachSourceMap) = {
+ val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression]
+ lazy val namespace = Namespace(m)
+
+ def onStmt(stmt: Statement): Statement = stmt map onStmt match {
+ case attach: Attach =>
+ val wires = attach.exprs groupBy kind
+ val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty)
+ val instPorts = wires.getOrElse(InstanceKind, Seq.empty)
+ // Sanity check (Should be caught by CheckTypes)
+ assert(sources.size + instPorts.size == attach.exprs.size)
+
+ sources match {
+ case Seq() => // Zero sources, can add a wire to connect and remove
+ val name = namespace.newTemp
+ val wire = DefWire(NoInfo, name, instPorts.head.tpe)
+ val ref = WRef(wire)
+ for (inst <- instPorts) sourceMap(inst) = ref
+ wire // Replace the attach with new source wire definition
+ case Seq(source) => // One source can be removed
+ assert(!sourceMap.contains(source)) // should have been merged
+ for (inst <- instPorts) sourceMap(inst) = source
+ EmptyStmt
+ case moreThanOne =>
+ attach
+ }
+ case s => s
+ }
+
+ (m map onStmt, sourceMap.toMap)
+ }
+
+ def run(c: Circuit): Circuit = {
+ def lowerE(e: Expression): Expression = e match {
+ case (_: WRef | _: WSubField) if kind(e) == InstanceKind =>
+ WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e))
+ case _ => e map lowerE
+ }
+
+ def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match {
+ case WDefInstance(info, name, module, tpe) =>
+ val portRefs = create_exps(WRef(name, tpe, ExpKind, SourceFlow))
+ val (portCons, wires) = portRefs.map { p =>
+ attachMap.get(p) match {
+ // If it has a source in attachMap use that
+ case Some(ref) => (p -> ref, None)
+ // If no source, create a wire corresponding to the port and connect it up
+ case None =>
+ val wire = DefWire(info, LowerTypes.loweredName(p), p.tpe)
+ (p -> WRef(wire), Some(wire))
+ }
+ }.unzip
+ val newInst = WDefInstanceConnector(info, name, module, tpe, portCons)
+ Block(wires.flatten :+ newInst)
+ case other => other map lowerS(attachMap) map lowerE
+ }
+
+ val modulesx = c.modules map { mod =>
+ val (modx, attachMap) = collectAndRemoveAttach(mod)
+ modx map lowerS(attachMap)
+ }
+ c.copy(modules = modulesx)
+ }
+}