diff options
| author | Jack Koenig | 2018-05-15 11:29:43 -0700 |
|---|---|---|
| committer | GitHub | 2018-05-15 11:29:43 -0700 |
| commit | 84b5fc1bc97e014bc03056a3f752c40ec6100701 (patch) | |
| tree | 2af78be6b61fbb82c1261d3d30ab9cabbcf401f4 | |
| parent | abcb22d6c34eb51749e7bc848b437a165bc5b330 (diff) | |
Replace truncating add and sub with addw/subw (#800)
Replaces old VerilogWrap which didn't work with split expressions and was
actually buggy anyway. This functionality reduces unnecessary intermediates in
emitted Verilog.
4 files changed, 135 insertions, 28 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 195f786d..9bb8a466 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -12,7 +12,7 @@ import scala.io.Source import firrtl.ir._ import firrtl.passes._ -import firrtl.transforms.{DeadCodeElimination, FlattenRegUpdate} +import firrtl.transforms._ import firrtl.annotations._ import firrtl.Mappers._ import firrtl.PrimOps._ @@ -659,10 +659,10 @@ class VerilogEmitter extends SeqTransform with Emitter { /** Preamble for every emitted Verilog file */ def transforms = Seq( + new ReplaceTruncatingArithmetic, new FlattenRegUpdate, new DeadCodeElimination, passes.VerilogModulusCleanup, - passes.VerilogWrap, passes.VerilogRename, passes.VerilogPrep) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 61afade6..9bbde5f6 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -252,32 +252,6 @@ object Legalize extends Pass { } } -object VerilogWrap extends Pass { - def vWrapE(e: Expression): Expression = e map vWrapE match { - case e: DoPrim => e.op match { - case Tail => e.args.head match { - case e0: DoPrim => e0.op match { - case Add => DoPrim(Addw, e0.args, Nil, e.tpe) - case Sub => DoPrim(Subw, e0.args, Nil, e.tpe) - case _ => e - } - case _ => e - } - case _ => e - } - case _ => e - } - def vWrapS(s: Statement): Statement = { - s map vWrapS map vWrapE match { - case sx: Print => sx.copy(string = sx.string.verilogFormat) - case sx => sx - } - } - - def run(c: Circuit): Circuit = - c copy (modules = c.modules map (_ map vWrapS)) -} - object VerilogRename extends Pass { def verilogRenameN(n: String): String = if (v_keywords(n)) "%s$".format(n) else n diff --git a/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala new file mode 100644 index 00000000..9c809c5f --- /dev/null +++ b/src/main/scala/firrtl/transforms/ReplaceTruncatingArithmetic.scala @@ -0,0 +1,73 @@ +package firrtl +package transforms + +import firrtl.ir._ +import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.WrappedExpression._ + +import scala.collection.mutable + +object ReplaceTruncatingArithmetic { + + /** Mapping from references to the [[Expression]]s that drive them */ + type Netlist = mutable.HashMap[WrappedExpression, Expression] + + private val SeqBIOne = Seq(BigInt(1)) + + /** Replaces truncating arithmetic in an Expression + * + * @param netlist a '''mutable''' HashMap mapping references to [[DefNode]]s to their connected + * [[Expression]]s. It is '''not''' mutated in this function + * @param expr the Expression being transformed + * @return Returns expr with truncating arithmetic replaced + */ + def onExpr(netlist: Netlist)(expr: Expression): Expression = + expr.map(onExpr(netlist)) match { + case orig @ DoPrim(Tail, Seq(e), SeqBIOne, tailtpe) => + netlist.getOrElse(we(e), e) match { + case DoPrim(Add, args, cs, _) => DoPrim(Addw, args, cs, tailtpe) + case DoPrim(Sub, args, cs, _) => DoPrim(Subw, args, cs, tailtpe) + case _ => orig // Not a candidate + } + case other => other // Not a candidate + } + + /** Replaces truncating arithmetic in a Statement + * + * @param netlist a '''mutable''' HashMap mapping references to [[DefNode]]s to their connected + * [[Expression]]s. This function '''will''' mutate it if stmt contains a [[DefNode]] + * @param stmt the Statement being searched for nodes and transformed + * @return Returns stmt with truncating arithmetic replaced + */ + def onStmt(netlist: Netlist)(stmt: Statement): Statement = + stmt.map(onStmt(netlist)).map(onExpr(netlist)) match { + case node @ DefNode(_, name, value) => + netlist(we(WRef(name))) = value + node + case other => other + } + + /** Replaces truncating arithmetic in a Module */ + def onMod(mod: DefModule): DefModule = mod.map(onStmt(new Netlist)) +} + +/** Replaces non-expanding arithmetic + * + * In the case where the result of `add` or `sub` immediately throws away the expanded msb, this + * transform will replace the operation with a non-expanding operator `addw` or `subw` + * respectively. + * + * @note This replaces some FIRRTL primops with ops that are not actually legal FIRRTL. They are + * useful for emission to languages that support non-expanding arithmetic (like Verilog) + */ +class ReplaceTruncatingArithmetic extends Transform { + def inputForm = LowForm + def outputForm = LowForm + + def execute(state: CircuitState): CircuitState = { + val modulesx = state.circuit.modules.map(ReplaceTruncatingArithmetic.onMod(_)) + state.copy(circuit = state.circuit.copy(modules = modulesx)) + } +} + diff --git a/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala b/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala new file mode 100644 index 00000000..b9c04e99 --- /dev/null +++ b/src/test/scala/firrtlTests/ReplaceTruncatingArithmeticSpec.scala @@ -0,0 +1,60 @@ +// See LICENSE for license details. + +package firrtlTests + +import firrtl._ +import firrtl.ir._ +import FirrtlCheckers._ + +class ReplaceTruncatingArithmeticSpec extends FirrtlFlatSpec { + def compile(input: String): CircuitState = + (new VerilogCompiler).compileAndEmit(CircuitState(parse(input), ChirrtlForm), List.empty) + def compileBody(body: String) = { + val str = """ + |circuit Test : + | module Test : + |""".stripMargin + body.split("\n").mkString(" ", "\n ", "") + compile(str) + } + + "Truncting addition" should "be inferred and emitted in Verilog" in { + val result = compileBody(s""" + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |z <= tail(add(x, y), 1)""".stripMargin + ) + result should containLine (s"assign z = x + y;") + } + it should "be inferred and emitted in Verilog even with an intermediate node" in { + val result = compileBody(s""" + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |node n = add(x, y) + |z <= tail(n, 1)""".stripMargin + ) + result should containLine (s"assign z = x + y;") + } + "Truncting subtraction" should "be inferred and emitted in Verilog" in { + val result = compileBody(s""" + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<8> + |z <= tail(sub(x, y), 1)""".stripMargin + ) + result should containLine (s"assign z = x - y;") + } + "Tailing more than 1" should "not result in a truncating operator" in { + val result = compileBody(s""" + |input x : UInt<8> + |input y : UInt<8> + |output z : UInt<7> + |node n = sub(x, y) + |z <= tail(n, 2)""".stripMargin + ) + result should containLine (s"assign n = x - y;") + result should containLine (s"assign z = n[6:0];") + } + +} |
