aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorgrebe2018-05-21 13:09:00 -0700
committerJack Koenig2018-05-21 13:09:00 -0700
commitb1709242b5c7b60e21308642947d292545eb2e37 (patch)
tree68ed90e520135d62cec32f6ca091ee5884be6e70 /src
parenta9529670ebbb2a44697fd14299b37c47d01f6623 (diff)
Fix more problems with zero width things. (#779)
This should close #757. It should also allow for stop() and printf() to be used with zero-width fields.
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/LoweringCompilers.scala3
-rw-r--r--src/main/scala/firrtl/passes/ZeroWidth.scala91
-rw-r--r--src/test/resources/features/ZeroWidthMem.fir9
-rw-r--r--src/test/scala/firrtlTests/ZeroWidthTests.scala36
4 files changed, 118 insertions, 21 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala
index 57aa1533..c4230b90 100644
--- a/src/main/scala/firrtl/LoweringCompilers.scala
+++ b/src/main/scala/firrtl/LoweringCompilers.scala
@@ -69,7 +69,8 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform {
passes.InferWidths,
passes.CheckWidths,
passes.ConvertFixedToSInt,
- passes.ZeroWidth)
+ passes.ZeroWidth,
+ passes.InferTypes)
}
/** Expands all aggregate types into many ground-typed components. Must
diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala
index 5b61c373..a8e6141c 100644
--- a/src/main/scala/firrtl/passes/ZeroWidth.scala
+++ b/src/main/scala/firrtl/passes/ZeroWidth.scala
@@ -2,16 +2,59 @@
package firrtl.passes
-import scala.collection.mutable
import firrtl.PrimOps._
import firrtl.ir._
import firrtl._
import firrtl.Mappers._
-
object ZeroWidth extends Transform {
- def inputForm = UnknownForm
- def outputForm = UnknownForm
+ def inputForm: CircuitForm = UnknownForm
+ def outputForm: CircuitForm = UnknownForm
+
+ private def makeEmptyMemBundle(name: String): Field =
+ Field(name, Flip, BundleType(Seq(
+ Field("addr", Default, UIntType(IntWidth(0))),
+ Field("en", Default, UIntType(IntWidth(0))),
+ Field("clk", Default, UIntType(IntWidth(0))),
+ Field("data", Flip, UIntType(IntWidth(0)))
+ )))
+
+ private def onEmptyMemStmt(s: Statement): Statement = s match {
+ case d @ DefMemory(info, name, tpe, _, _, _, rs, ws, rws, _) => removeZero(tpe) match {
+ case None =>
+ DefWire(info, name, BundleType(
+ rs.map(r => makeEmptyMemBundle(r)) ++
+ ws.map(w => makeEmptyMemBundle(w)) ++
+ rws.map(rw => makeEmptyMemBundle(rw))
+ ))
+ case Some(_) => d
+ }
+ case sx => sx map onEmptyMemStmt
+ }
+
+ private def onModuleEmptyMemStmt(m: DefModule): DefModule = {
+ m match {
+ case ext: ExtModule => ext
+ case in: Module => in.copy(body = onEmptyMemStmt(in.body))
+ }
+ }
+
+ /**
+ * Replace zero width mems before running the rest of the ZeroWidth transform.
+ * Dealing with mems is a bit tricky because the address, en, clk ports
+ * of the memory are not width zero even if data is.
+ *
+ * This replaces memories with a DefWire() bundle that contains the address, en,
+ * clk, and data fields implemented as zero width wires. Running the rest of the ZeroWidth
+ * transform will remove these dangling references properly.
+ *
+ */
+ def executeEmptyMemStmt(state: CircuitState): CircuitState = {
+ val c = state.circuit
+ val result = c.copy(modules = c.modules map onModuleEmptyMemStmt)
+ state.copy(circuit = result)
+ }
+
private val ZERO = BigInt(0)
private def getRemoved(x: IsDeclaration): Seq[String] = {
var removedNames: Seq[String] = Seq.empty
@@ -27,7 +70,7 @@ object ZeroWidth extends Transform {
}
removedNames
}
- private def removeZero(t: Type): Option[Type] = t match {
+ private[passes] def removeZero(t: Type): Option[Type] = t match {
case GroundType(IntWidth(ZERO)) => None
case BundleType(fields) =>
fields map (f => (f, removeZero(f.tpe))) collect {
@@ -60,18 +103,26 @@ object ZeroWidth extends Transform {
}
}
private def onStmt(renames: RenameMap)(s: Statement): Statement = s match {
- case (_: DefWire| _: DefRegister| _: DefMemory) =>
- // List all removed expression names, and delete them from renames
- renames.delete(getRemoved(s.asInstanceOf[IsDeclaration]))
- // Create new types without zero-width wires
- var removed = false
- def applyRemoveZero(t: Type): Type = removeZero(t) match {
- case None => removed = true; t
- case Some(tx) => tx
+ case d @ DefWire(info, name, tpe) =>
+ renames.delete(getRemoved(d))
+ removeZero(tpe) match {
+ case None => EmptyStmt
+ case Some(t) => DefWire(info, name, t)
+ }
+ case d @ DefRegister(info, name, tpe, clock, reset, init) =>
+ renames.delete(getRemoved(d))
+ removeZero(tpe) match {
+ case None => EmptyStmt
+ case Some(t) =>
+ DefRegister(info, name, t, onExp(clock), onExp(reset), onExp(init))
+ }
+ case d: DefMemory =>
+ renames.delete(getRemoved(d))
+ removeZero(d.dataType) match {
+ case None =>
+ Utils.throwInternalError(s"private pass ZeroWidthMemRemove should have removed this memory: $d")
+ case Some(t) => d.copy(dataType = t)
}
- val sxx = (s map onExp) map applyRemoveZero
- // Return new declaration
- if(removed) EmptyStmt else sxx
case Connect(info, loc, exp) => removeZero(loc.tpe) match {
case None => EmptyStmt
case Some(t) => Connect(info, loc, onExp(exp))
@@ -84,7 +135,7 @@ object ZeroWidth extends Transform {
case None => EmptyStmt
case Some(t) => DefNode(info, name, onExp(value))
}
- case sx => sx map onStmt(renames)
+ case sx => sx map onStmt(renames) map onExp
}
private def onModule(renames: RenameMap)(m: DefModule): DefModule = {
renames.setModule(m.name)
@@ -102,10 +153,12 @@ object ZeroWidth extends Transform {
}
}
def execute(state: CircuitState): CircuitState = {
- val c = state.circuit
+ // run executeEmptyMemStmt first to remove zero-width memories
+ // then run InferTypes to update widths for addr, en, clk, etc
+ val c = InferTypes.run(executeEmptyMemStmt(state).circuit)
val renames = RenameMap()
renames.setCircuit(c.main)
- val result = InferTypes.run(c.copy(modules = c.modules map onModule(renames)))
+ val result = c.copy(modules = c.modules map onModule(renames))
CircuitState(result, outputForm, state.annotations, Some(renames))
}
}
diff --git a/src/test/resources/features/ZeroWidthMem.fir b/src/test/resources/features/ZeroWidthMem.fir
index c56f8390..a909b041 100644
--- a/src/test/resources/features/ZeroWidthMem.fir
+++ b/src/test/resources/features/ZeroWidthMem.fir
@@ -12,9 +12,13 @@ circuit ZeroWidthMem :
infer mport ramin = ram[waddr], clock
infer mport ramout = ram[raddr], clock
+ cmem totallyEmptyRam : UInt<0>[16]
+ infer mport emptyRamout = totallyEmptyRam[raddr], clock
+
ramin.0 <= in.0
ramin.1 <= in.1
- out <= ramout
+ out.0 <= ramout.0
+ out.1 <= ramout.1
wire foo : UInt<32>
foo <= UInt<32>("hdeadbeef")
@@ -26,3 +30,6 @@ circuit ZeroWidthMem :
printf(clock, UInt(1), "Assertion failed!\n")
stop(clock, UInt(1), 1) ; Failure!
+ when neq(emptyRamout, UInt<1>("h0")) :
+ stop(clock, UInt(1), 1) ; Failure! empty mem should be zero
+
diff --git a/src/test/scala/firrtlTests/ZeroWidthTests.scala b/src/test/scala/firrtlTests/ZeroWidthTests.scala
index 50385a80..6443e131 100644
--- a/src/test/scala/firrtlTests/ZeroWidthTests.scala
+++ b/src/test/scala/firrtlTests/ZeroWidthTests.scala
@@ -176,6 +176,42 @@ class ZeroWidthTests extends FirrtlFlatSpec {
| node a = cat(x, z)""".stripMargin
(parse(exec(input)).serialize) should be (parse(check).serialize)
}
+ "Stop with type <0>" should "be replaced with UInt(0)" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input x: UInt<1>
+ | input y: UInt<0>
+ | input z: UInt<1>
+ | stop(clk, y, 1)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input x: UInt<1>
+ | input z: UInt<1>
+ | stop(clk, UInt(0), 1)""".stripMargin
+ (parse(exec(input)).serialize) should be (parse(check).serialize)
+ }
+ "Print with type <0>" should "be replaced with UInt(0)" in {
+ val input =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input x: UInt<1>
+ | input y: UInt<0>
+ | input z: UInt<1>
+ | printf(clk, UInt(1), "%d %d %d\n", x, y, z)""".stripMargin
+ val check =
+ """circuit Top :
+ | module Top :
+ | input clk: Clock
+ | input x: UInt<1>
+ | input z: UInt<1>
+ | printf(clk, UInt(1), "%d %d %d\n", x, UInt(0), z)""".stripMargin
+ (parse(exec(input)).serialize) should be (parse(check).serialize)
+ }
}
class ZeroWidthVerilog extends FirrtlFlatSpec {