diff options
| author | Adam Izraelevitz | 2020-04-10 12:47:33 -0700 |
|---|---|---|
| committer | GitHub | 2020-04-10 19:47:33 +0000 |
| commit | 54ff9451f285cc18bca0ab519e013ff8326538b8 (patch) | |
| tree | fff7384be0131e52249b245fe78e18f413fc0c61 | |
| parent | 632930723adc2f78d4d6445acf5f8bcc250a6c0c (diff) | |
Split Passes.scala into separate files (#1496)
* Split Passes.scala into separate files
* Add imports of implicit things
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandConnects.scala | 86 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Legalize.scala | 89 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Pass.scala | 39 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 367 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PullMuxes.scala | 47 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ToWorkingIR.scala | 28 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/VerilogPrep.scala | 106 |
7 files changed, 395 insertions, 367 deletions
diff --git a/src/main/scala/firrtl/passes/ExpandConnects.scala b/src/main/scala/firrtl/passes/ExpandConnects.scala new file mode 100644 index 00000000..250c9ce0 --- /dev/null +++ b/src/main/scala/firrtl/passes/ExpandConnects.scala @@ -0,0 +1,86 @@ +package firrtl.passes + +import firrtl.Utils.{create_exps, flow, get_field, get_valid_points, times, to_flip, to_flow} +import firrtl.ir._ +import firrtl.options.{PreservesAll, Dependency} +import firrtl.{DuplexFlow, Flow, SinkFlow, SourceFlow, Transform, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex} +import firrtl.Mappers._ + +object ExpandConnects extends Pass with PreservesAll[Transform] { + + override val prerequisites = + Seq( Dependency(PullMuxes), + Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped + + def run(c: Circuit): Circuit = { + def expand_connects(m: Module): Module = { + val flows = collection.mutable.LinkedHashMap[String,Flow]() + def expand_s(s: Statement): Statement = { + def set_flow(e: Expression): Expression = e map set_flow match { + case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name)) + case ex: WSubField => + val f = get_field(ex.expr.tpe, ex.name) + val flowx = times(flow(ex.expr), f.flip) + WSubField(ex.expr, ex.name, ex.tpe, flowx) + case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) + case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr)) + case ex => ex + } + s match { + case sx: DefWire => flows(sx.name) = DuplexFlow; sx + case sx: DefRegister => flows(sx.name) = DuplexFlow; sx + case sx: WDefInstance => flows(sx.name) = SourceFlow; sx + case sx: DefMemory => flows(sx.name) = SourceFlow; sx + case sx: DefNode => flows(sx.name) = SourceFlow; sx + case sx: IsInvalid => + val invalids = create_exps(sx.expr).flatMap { case expx => + flow(set_flow(expx)) match { + case DuplexFlow => Some(IsInvalid(sx.info, expx)) + case SinkFlow => Some(IsInvalid(sx.info, expx)) + case _ => None + } + } + invalids.size match { + case 0 => EmptyStmt + case 1 => invalids.head + case _ => Block(invalids) + } + case sx: Connect => + val locs = create_exps(sx.loc) + val exps = create_exps(sx.expr) + Block(locs.zip(exps).map { case (locx, expx) => + to_flip(flow(locx)) match { + case Default => Connect(sx.info, locx, expx) + case Flip => Connect(sx.info, expx, locx) + } + }) + case sx: PartialConnect => + val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default) + val locs = create_exps(sx.loc) + val exps = create_exps(sx.expr) + val stmts = ls map { case (x, y) => + locs(x).tpe match { + case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) + case _ => + to_flip(flow(locs(x))) match { + case Default => Connect(sx.info, locs(x), exps(y)) + case Flip => Connect(sx.info, exps(y), locs(x)) + } + } + } + Block(stmts) + case sx => sx map expand_s + } + } + + m.ports.foreach { p => flows(p.name) = to_flow(p.direction) } + Module(m.info, m.name, m.ports, expand_s(m.body)) + } + + val modulesx = c.modules.map { + case (m: ExtModule) => m + case (m: Module) => expand_connects(m) + } + Circuit(c.info, modulesx, c.main) + } +} diff --git a/src/main/scala/firrtl/passes/Legalize.scala b/src/main/scala/firrtl/passes/Legalize.scala new file mode 100644 index 00000000..37556769 --- /dev/null +++ b/src/main/scala/firrtl/passes/Legalize.scala @@ -0,0 +1,89 @@ +package firrtl.passes + +import firrtl.PrimOps._ +import firrtl.Utils.{BoolType, error, zero} +import firrtl.ir._ +import firrtl.options.{PreservesAll, Dependency} +import firrtl.transforms.ConstantPropagation +import firrtl.{Transform, bitWidth} +import firrtl.Mappers._ + +// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts +// TODO replace UInt with zero-width wire instead +object Legalize extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) + + override val optionalPrerequisites = Seq.empty + + override val dependents = Seq.empty + + private def legalizeShiftRight(e: DoPrim): Expression = { + require(e.op == Shr) + e.args.head match { + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) + case _ => + val amount = e.consts.head.toInt + val width = bitWidth(e.args.head.tpe) + lazy val msb = width - 1 + if (amount >= width) { + e.tpe match { + case UIntType(_) => zero + case SIntType(_) => + val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) + DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) + case t => error(s"Unsupported type $t for Primop Shift Right") + } + } else { + e + } + } + } + private def legalizeBitExtract(expr: DoPrim): Expression = { + expr.args.head match { + case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) + case _ => expr + } + } + private def legalizePad(expr: DoPrim): Expression = expr.args.head match { + case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head => + UIntLiteral(value, IntWidth(expr.consts.head)) + case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head => + SIntLiteral(value, IntWidth(expr.consts.head)) + case _ => expr + } + private def legalizeConnect(c: Connect): Statement = { + val t = c.loc.tpe + val w = bitWidth(t) + if (w >= bitWidth(c.expr.tpe)) { + c + } else { + val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) + val expr = t match { + case UIntType(_) => bits + case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) + case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) + } + Connect(c.info, c.loc, expr) + } + } + def run (c: Circuit): Circuit = { + def legalizeE(expr: Expression): Expression = expr map legalizeE match { + case prim: DoPrim => prim.op match { + case Shr => legalizeShiftRight(prim) + case Pad => legalizePad(prim) + case Bits | Head | Tail => legalizeBitExtract(prim) + case _ => prim + } + case e => e // respect pre-order traversal + } + def legalizeS (s: Statement): Statement = { + val legalizedStmt = s match { + case c: Connect => legalizeConnect(c) + case _ => s + } + legalizedStmt map legalizeS map legalizeE + } + c copy (modules = c.modules map (_ map legalizeS)) + } +} diff --git a/src/main/scala/firrtl/passes/Pass.scala b/src/main/scala/firrtl/passes/Pass.scala new file mode 100644 index 00000000..4673a8e1 --- /dev/null +++ b/src/main/scala/firrtl/passes/Pass.scala @@ -0,0 +1,39 @@ +package firrtl.passes + +import firrtl.Utils.error +import firrtl.ir.Circuit +import firrtl.{CircuitForm, CircuitState, FirrtlUserException, Transform, UnknownForm} + +/** [[Pass]] is simple transform that is generally part of a larger [[Transform]] + * Has an [[UnknownForm]], because larger [[Transform]] should specify form + */ +trait Pass extends Transform { + def inputForm: CircuitForm = UnknownForm + def outputForm: CircuitForm = UnknownForm + def run(c: Circuit): Circuit + def execute(state: CircuitState): CircuitState = { + val result = (state.form, inputForm) match { + case (_, UnknownForm) => run(state.circuit) + case (UnknownForm, _) => run(state.circuit) + case (x, y) if x > y => + error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") + case _ => run(state.circuit) + } + CircuitState(result, outputForm, state.annotations, state.renames) + } +} + +// Error handling +class PassException(message: String) extends FirrtlUserException(message) +class PassExceptions(val exceptions: Seq[PassException]) extends FirrtlUserException("\n" + exceptions.mkString("\n")) +class Errors { + val errors = collection.mutable.ArrayBuffer[PassException]() + def append(pe: PassException) = errors.append(pe) + def trigger() = errors.size match { + case 0 => + case 1 => throw errors.head + case _ => + append(new PassException(s"${errors.length} errors detected!")) + throw new PassExceptions(errors) + } +} diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala deleted file mode 100644 index a8d37758..00000000 --- a/src/main/scala/firrtl/passes/Passes.scala +++ /dev/null @@ -1,367 +0,0 @@ -// See LICENSE for license details. - -package firrtl.passes - -import firrtl._ -import firrtl.ir._ -import firrtl.Utils._ -import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.options.{Dependency, PreservesAll} -import firrtl.transforms.ConstantPropagation - -import scala.collection.mutable - -/** [[Pass]] is simple transform that is generally part of a larger [[Transform]] - * Has an [[UnknownForm]], because larger [[Transform]] should specify form - */ -trait Pass extends Transform { - def inputForm: CircuitForm = UnknownForm - def outputForm: CircuitForm = UnknownForm - def run(c: Circuit): Circuit - def execute(state: CircuitState): CircuitState = { - val result = (state.form, inputForm) match { - case (_, UnknownForm) => run(state.circuit) - case (UnknownForm, _) => run(state.circuit) - case (x, y) if x > y => - error(s"[$name]: Input form must be lower or equal to $inputForm. Got ${state.form}") - case _ => run(state.circuit) - } - CircuitState(result, outputForm, state.annotations, state.renames) - } -} - -// Error handling -class PassException(message: String) extends FirrtlUserException(message) -class PassExceptions(val exceptions: Seq[PassException]) extends FirrtlUserException("\n" + exceptions.mkString("\n")) -class Errors { - val errors = collection.mutable.ArrayBuffer[PassException]() - def append(pe: PassException) = errors.append(pe) - def trigger() = errors.size match { - case 0 => - case 1 => throw errors.head - case _ => - append(new PassException(s"${errors.length} errors detected!")) - throw new PassExceptions(errors) - } -} - -// These should be distributed into separate files -object ToWorkingIR extends Pass with PreservesAll[Transform] { - - override val prerequisites = firrtl.stage.Forms.MinimalHighForm - - def toExp(e: Expression): Expression = e map toExp match { - case ex: Reference => WRef(ex.name, ex.tpe, UnknownKind, UnknownFlow) - case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UnknownFlow) - case ex: SubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, UnknownFlow) - case ex: SubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, UnknownFlow) - case ex => ex // This might look like a case to use case _ => e, DO NOT! - } - - def toStmt(s: Statement): Statement = s map toExp match { - case sx: DefInstance => WDefInstance(sx.info, sx.name, sx.module, UnknownType) - case sx => sx map toStmt - } - - def run (c:Circuit): Circuit = - c copy (modules = c.modules map (_ map toStmt)) -} - -object PullMuxes extends Pass with PreservesAll[Transform] { - - override val prerequisites = firrtl.stage.Forms.Deduped - - def run(c: Circuit): Circuit = { - def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { - case ex: WSubField => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubField(exx.tval, ex.name, ex.tpe, ex.flow), - WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubIndex => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow), - WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex: WSubAccess => ex.expr match { - case exx: Mux => Mux(exx.cond, - WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow), - WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe) - case exx: ValidIf => ValidIf(exx.cond, - WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe) - case _ => ex // case exx => exx causes failed tests - } - case ex => ex - } - def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e - val modulesx = c.modules.map { - case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) - case (m:ExtModule) => m - } - Circuit(c.info, modulesx, c.main) - } -} - -object ExpandConnects extends Pass with PreservesAll[Transform] { - - override val prerequisites = - Seq( Dependency(PullMuxes), - Dependency(ReplaceAccesses) ) ++ firrtl.stage.Forms.Deduped - - def run(c: Circuit): Circuit = { - def expand_connects(m: Module): Module = { - val flows = collection.mutable.LinkedHashMap[String,Flow]() - def expand_s(s: Statement): Statement = { - def set_flow(e: Expression): Expression = e map set_flow match { - case ex: WRef => WRef(ex.name, ex.tpe, ex.kind, flows(ex.name)) - case ex: WSubField => - val f = get_field(ex.expr.tpe, ex.name) - val flowx = times(flow(ex.expr), f.flip) - WSubField(ex.expr, ex.name, ex.tpe, flowx) - case ex: WSubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, flow(ex.expr)) - case ex: WSubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, flow(ex.expr)) - case ex => ex - } - s match { - case sx: DefWire => flows(sx.name) = DuplexFlow; sx - case sx: DefRegister => flows(sx.name) = DuplexFlow; sx - case sx: WDefInstance => flows(sx.name) = SourceFlow; sx - case sx: DefMemory => flows(sx.name) = SourceFlow; sx - case sx: DefNode => flows(sx.name) = SourceFlow; sx - case sx: IsInvalid => - val invalids = create_exps(sx.expr).flatMap { case expx => - flow(set_flow(expx)) match { - case DuplexFlow => Some(IsInvalid(sx.info, expx)) - case SinkFlow => Some(IsInvalid(sx.info, expx)) - case _ => None - } - } - invalids.size match { - case 0 => EmptyStmt - case 1 => invalids.head - case _ => Block(invalids) - } - case sx: Connect => - val locs = create_exps(sx.loc) - val exps = create_exps(sx.expr) - Block(locs.zip(exps).map { case (locx, expx) => - to_flip(flow(locx)) match { - case Default => Connect(sx.info, locx, expx) - case Flip => Connect(sx.info, expx, locx) - } - }) - case sx: PartialConnect => - val ls = get_valid_points(sx.loc.tpe, sx.expr.tpe, Default, Default) - val locs = create_exps(sx.loc) - val exps = create_exps(sx.expr) - val stmts = ls map { case (x, y) => - locs(x).tpe match { - case AnalogType(_) => Attach(sx.info, Seq(locs(x), exps(y))) - case _ => - to_flip(flow(locs(x))) match { - case Default => Connect(sx.info, locs(x), exps(y)) - case Flip => Connect(sx.info, exps(y), locs(x)) - } - } - } - Block(stmts) - case sx => sx map expand_s - } - } - - m.ports.foreach { p => flows(p.name) = to_flow(p.direction) } - Module(m.info, m.name, m.ports, expand_s(m.body)) - } - - val modulesx = c.modules.map { - case (m: ExtModule) => m - case (m: Module) => expand_connects(m) - } - Circuit(c.info, modulesx, c.main) - } -} - - -// Replace shr by amount >= arg width with 0 for UInts and MSB for SInts -// TODO replace UInt with zero-width wire instead -object Legalize extends Pass with PreservesAll[Transform] { - - override val prerequisites = firrtl.stage.Forms.MidForm :+ Dependency(LowerTypes) - - override val optionalPrerequisites = Seq.empty - - override val dependents = Seq.empty - - private def legalizeShiftRight(e: DoPrim): Expression = { - require(e.op == Shr) - e.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.foldShiftRight(e) - case _ => - val amount = e.consts.head.toInt - val width = bitWidth(e.args.head.tpe) - lazy val msb = width - 1 - if (amount >= width) { - e.tpe match { - case UIntType(_) => zero - case SIntType(_) => - val bits = DoPrim(Bits, e.args, Seq(msb, msb), BoolType) - DoPrim(AsSInt, Seq(bits), Seq.empty, SIntType(IntWidth(1))) - case t => error(s"Unsupported type $t for Primop Shift Right") - } - } else { - e - } - } - } - private def legalizeBitExtract(expr: DoPrim): Expression = { - expr.args.head match { - case _: UIntLiteral | _: SIntLiteral => ConstantPropagation.constPropBitExtract(expr) - case _ => expr - } - } - private def legalizePad(expr: DoPrim): Expression = expr.args.head match { - case UIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - UIntLiteral(value, IntWidth(expr.consts.head)) - case SIntLiteral(value, IntWidth(width)) if width < expr.consts.head => - SIntLiteral(value, IntWidth(expr.consts.head)) - case _ => expr - } - private def legalizeConnect(c: Connect): Statement = { - val t = c.loc.tpe - val w = bitWidth(t) - if (w >= bitWidth(c.expr.tpe)) { - c - } else { - val bits = DoPrim(Bits, Seq(c.expr), Seq(w - 1, 0), UIntType(IntWidth(w))) - val expr = t match { - case UIntType(_) => bits - case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) - case FixedType(_, IntWidth(p)) => DoPrim(AsFixedPoint, Seq(bits), Seq(p), t) - } - Connect(c.info, c.loc, expr) - } - } - def run (c: Circuit): Circuit = { - def legalizeE(expr: Expression): Expression = expr map legalizeE match { - case prim: DoPrim => prim.op match { - case Shr => legalizeShiftRight(prim) - case Pad => legalizePad(prim) - case Bits | Head | Tail => legalizeBitExtract(prim) - case _ => prim - } - case e => e // respect pre-order traversal - } - def legalizeS (s: Statement): Statement = { - val legalizedStmt = s match { - case c: Connect => legalizeConnect(c) - case _ => s - } - legalizedStmt map legalizeS map legalizeE - } - c copy (modules = c.modules map (_ map legalizeS)) - } -} - -/** Makes changes to the Firrtl AST to make Verilog emission easier - * - * - For each instance, adds wires to connect to each port - * - Note that no Namespace is required because Uniquify ensures that there will be no - * collisions with the lowered names of instance ports - * - Also removes Attaches where a single Port OR Wire connects to 1 or more instance ports - * - These are expressed in the portCons of WDefInstConnectors - * - * @note The result of this pass is NOT legal Firrtl - */ -object VerilogPrep extends Pass with PreservesAll[Transform] { - - override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ - Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], - Dependency[firrtl.transforms.FixAddingNegativeLiterals], - Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], - Dependency[firrtl.transforms.InlineBitExtractionsTransform], - Dependency[firrtl.transforms.InlineCastsTransform], - Dependency[firrtl.transforms.LegalizeClocksTransform], - Dependency[firrtl.transforms.FlattenRegUpdate], - Dependency(passes.VerilogModulusCleanup), - Dependency[firrtl.transforms.VerilogRename] ) - - override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized - - override val dependents = Seq.empty - - type AttachSourceMap = Map[WrappedExpression, Expression] - - // Finds attaches with only a single source (Port or Wire) - // - Creates a map of attached expressions to their source - // - Removes the Attach - private def collectAndRemoveAttach(m: DefModule): (DefModule, AttachSourceMap) = { - val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression] - lazy val namespace = Namespace(m) - - def onStmt(stmt: Statement): Statement = stmt map onStmt match { - case attach: Attach => - val wires = attach.exprs groupBy kind - val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty) - val instPorts = wires.getOrElse(InstanceKind, Seq.empty) - // Sanity check (Should be caught by CheckTypes) - assert(sources.size + instPorts.size == attach.exprs.size) - - sources match { - case Seq() => // Zero sources, can add a wire to connect and remove - val name = namespace.newTemp - val wire = DefWire(NoInfo, name, instPorts.head.tpe) - val ref = WRef(wire) - for (inst <- instPorts) sourceMap(inst) = ref - wire // Replace the attach with new source wire definition - case Seq(source) => // One source can be removed - assert(!sourceMap.contains(source)) // should have been merged - for (inst <- instPorts) sourceMap(inst) = source - EmptyStmt - case moreThanOne => - attach - } - case s => s - } - - (m map onStmt, sourceMap.toMap) - } - - def run(c: Circuit): Circuit = { - def lowerE(e: Expression): Expression = e match { - case (_: WRef | _: WSubField) if kind(e) == InstanceKind => - WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e)) - case _ => e map lowerE - } - - def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match { - case WDefInstance(info, name, module, tpe) => - val portRefs = create_exps(WRef(name, tpe, ExpKind, SourceFlow)) - val (portCons, wires) = portRefs.map { p => - attachMap.get(p) match { - // If it has a source in attachMap use that - case Some(ref) => (p -> ref, None) - // If no source, create a wire corresponding to the port and connect it up - case None => - val wire = DefWire(info, LowerTypes.loweredName(p), p.tpe) - (p -> WRef(wire), Some(wire)) - } - }.unzip - val newInst = WDefInstanceConnector(info, name, module, tpe, portCons) - Block(wires.flatten :+ newInst) - case other => other map lowerS(attachMap) map lowerE - } - - val modulesx = c.modules map { mod => - val (modx, attachMap) = collectAndRemoveAttach(mod) - modx map lowerS(attachMap) - } - c.copy(modules = modulesx) - } -} diff --git a/src/main/scala/firrtl/passes/PullMuxes.scala b/src/main/scala/firrtl/passes/PullMuxes.scala new file mode 100644 index 00000000..8befd9fa --- /dev/null +++ b/src/main/scala/firrtl/passes/PullMuxes.scala @@ -0,0 +1,47 @@ +package firrtl.passes + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.options.PreservesAll +import firrtl.{Transform, WSubAccess, WSubField, WSubIndex} + +object PullMuxes extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.Deduped + + def run(c: Circuit): Circuit = { + def pull_muxes_e(e: Expression): Expression = e map pull_muxes_e match { + case ex: WSubField => ex.expr match { + case exx: Mux => Mux(exx.cond, + WSubField(exx.tval, ex.name, ex.tpe, ex.flow), + WSubField(exx.fval, ex.name, ex.tpe, ex.flow), ex.tpe) + case exx: ValidIf => ValidIf(exx.cond, + WSubField(exx.value, ex.name, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubIndex => ex.expr match { + case exx: Mux => Mux(exx.cond, + WSubIndex(exx.tval, ex.value, ex.tpe, ex.flow), + WSubIndex(exx.fval, ex.value, ex.tpe, ex.flow), ex.tpe) + case exx: ValidIf => ValidIf(exx.cond, + WSubIndex(exx.value, ex.value, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex: WSubAccess => ex.expr match { + case exx: Mux => Mux(exx.cond, + WSubAccess(exx.tval, ex.index, ex.tpe, ex.flow), + WSubAccess(exx.fval, ex.index, ex.tpe, ex.flow), ex.tpe) + case exx: ValidIf => ValidIf(exx.cond, + WSubAccess(exx.value, ex.index, ex.tpe, ex.flow), ex.tpe) + case _ => ex // case exx => exx causes failed tests + } + case ex => ex + } + def pull_muxes(s: Statement): Statement = s map pull_muxes map pull_muxes_e + val modulesx = c.modules.map { + case (m:Module) => Module(m.info, m.name, m.ports, pull_muxes(m.body)) + case (m:ExtModule) => m + } + Circuit(c.info, modulesx, c.main) + } +} diff --git a/src/main/scala/firrtl/passes/ToWorkingIR.scala b/src/main/scala/firrtl/passes/ToWorkingIR.scala new file mode 100644 index 00000000..109654ee --- /dev/null +++ b/src/main/scala/firrtl/passes/ToWorkingIR.scala @@ -0,0 +1,28 @@ +package firrtl.passes + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.options.{PreservesAll} +import firrtl.{Transform, UnknownFlow, UnknownKind, WDefInstance, WRef, WSubAccess, WSubField, WSubIndex} + +// These should be distributed into separate files +object ToWorkingIR extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.MinimalHighForm + + def toExp(e: Expression): Expression = e map toExp match { + case ex: Reference => WRef(ex.name, ex.tpe, UnknownKind, UnknownFlow) + case ex: SubField => WSubField(ex.expr, ex.name, ex.tpe, UnknownFlow) + case ex: SubIndex => WSubIndex(ex.expr, ex.value, ex.tpe, UnknownFlow) + case ex: SubAccess => WSubAccess(ex.expr, ex.index, ex.tpe, UnknownFlow) + case ex => ex // This might look like a case to use case _ => e, DO NOT! + } + + def toStmt(s: Statement): Statement = s map toExp match { + case sx: DefInstance => WDefInstance(sx.info, sx.name, sx.module, UnknownType) + case sx => sx map toStmt + } + + def run (c:Circuit): Circuit = + c copy (modules = c.modules map (_ map toStmt)) +} diff --git a/src/main/scala/firrtl/passes/VerilogPrep.scala b/src/main/scala/firrtl/passes/VerilogPrep.scala new file mode 100644 index 00000000..776c0f5f --- /dev/null +++ b/src/main/scala/firrtl/passes/VerilogPrep.scala @@ -0,0 +1,106 @@ +package firrtl.passes + +import firrtl.Utils.{create_exps, flow, kind, toWrappedExpression} +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.options.{Dependency, PreservesAll} +import firrtl._ + +import scala.collection.mutable + +/** Makes changes to the Firrtl AST to make Verilog emission easier + * + * - For each instance, adds wires to connect to each port + * - Note that no Namespace is required because Uniquify ensures that there will be no + * collisions with the lowered names of instance ports + * - Also removes Attaches where a single Port OR Wire connects to 1 or more instance ports + * - These are expressed in the portCons of WDefInstConnectors + * + * @note The result of this pass is NOT legal Firrtl + */ +object VerilogPrep extends Pass with PreservesAll[Transform] { + + override val prerequisites = firrtl.stage.Forms.LowFormMinimumOptimized ++ + Seq( Dependency[firrtl.transforms.BlackBoxSourceHelper], + Dependency[firrtl.transforms.FixAddingNegativeLiterals], + Dependency[firrtl.transforms.ReplaceTruncatingArithmetic], + Dependency[firrtl.transforms.InlineBitExtractionsTransform], + Dependency[firrtl.transforms.InlineCastsTransform], + Dependency[firrtl.transforms.LegalizeClocksTransform], + Dependency[firrtl.transforms.FlattenRegUpdate], + Dependency(passes.VerilogModulusCleanup), + Dependency[firrtl.transforms.VerilogRename] ) + + override val optionalPrerequisites = firrtl.stage.Forms.LowFormOptimized + + override val dependents = Seq.empty + + type AttachSourceMap = Map[WrappedExpression, Expression] + + // Finds attaches with only a single source (Port or Wire) + // - Creates a map of attached expressions to their source + // - Removes the Attach + private def collectAndRemoveAttach(m: DefModule): (DefModule, AttachSourceMap) = { + val sourceMap = mutable.HashMap.empty[WrappedExpression, Expression] + lazy val namespace = Namespace(m) + + def onStmt(stmt: Statement): Statement = stmt map onStmt match { + case attach: Attach => + val wires = attach.exprs groupBy kind + val sources = wires.getOrElse(PortKind, Seq.empty) ++ wires.getOrElse(WireKind, Seq.empty) + val instPorts = wires.getOrElse(InstanceKind, Seq.empty) + // Sanity check (Should be caught by CheckTypes) + assert(sources.size + instPorts.size == attach.exprs.size) + + sources match { + case Seq() => // Zero sources, can add a wire to connect and remove + val name = namespace.newTemp + val wire = DefWire(NoInfo, name, instPorts.head.tpe) + val ref = WRef(wire) + for (inst <- instPorts) sourceMap(inst) = ref + wire // Replace the attach with new source wire definition + case Seq(source) => // One source can be removed + assert(!sourceMap.contains(source)) // should have been merged + for (inst <- instPorts) sourceMap(inst) = source + EmptyStmt + case moreThanOne => + attach + } + case s => s + } + + (m map onStmt, sourceMap.toMap) + } + + def run(c: Circuit): Circuit = { + def lowerE(e: Expression): Expression = e match { + case (_: WRef | _: WSubField) if kind(e) == InstanceKind => + WRef(LowerTypes.loweredName(e), e.tpe, kind(e), flow(e)) + case _ => e map lowerE + } + + def lowerS(attachMap: AttachSourceMap)(s: Statement): Statement = s match { + case WDefInstance(info, name, module, tpe) => + val portRefs = create_exps(WRef(name, tpe, ExpKind, SourceFlow)) + val (portCons, wires) = portRefs.map { p => + attachMap.get(p) match { + // If it has a source in attachMap use that + case Some(ref) => (p -> ref, None) + // If no source, create a wire corresponding to the port and connect it up + case None => + val wire = DefWire(info, LowerTypes.loweredName(p), p.tpe) + (p -> WRef(wire), Some(wire)) + } + }.unzip + val newInst = WDefInstanceConnector(info, name, module, tpe, portCons) + Block(wires.flatten :+ newInst) + case other => other map lowerS(attachMap) map lowerE + } + + val modulesx = c.modules map { mod => + val (modx, attachMap) = collectAndRemoveAttach(mod) + modx map lowerS(attachMap) + } + c.copy(modules = modulesx) + } +} |
