diff options
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 18 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveValidIf.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/RemoveWires.scala | 5 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/RemoveWiresSpec.scala | 14 |
4 files changed, 37 insertions, 8 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index a52e451f..921ec60b 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -250,6 +250,24 @@ object Utils extends LazyLogging { val one = UIntLiteral(1) val zero = UIntLiteral(0) + private val ClockZero = DoPrim(PrimOps.AsClock, Seq(zero), Seq.empty, ClockType) + private val AsyncZero = DoPrim(PrimOps.AsAsyncReset, Seq(zero), Nil, AsyncResetType) + + /** Returns an [[firrtl.ir.Expression Expression]] equal to zero for a given [[firrtl.ir.GroundType GroundType]] + * @note Does not support [[firrtl.ir.AnalogType AnalogType]] nor [[firrtl.ir.IntervalType IntervalType]] + */ + def getGroundZero(tpe: GroundType): Expression = tpe match { + case u: UIntType => UIntLiteral(0, u.width) + case s: SIntType => SIntLiteral(0, s.width) + case f: FixedType => FixedLiteral(0, f.width, f.point) + // Default reset type is Bool + case ResetType => Utils.zero + case ClockType => ClockZero + case AsyncResetType => AsyncZero + // TODO Support IntervalType + case other => throwInternalError(s"Unexpected type $other") + } + def create_exps(n: String, t: Type): Seq[Expression] = create_exps(WRef(n, t, ExpKind, UnknownFlow)) def create_exps(e: Expression): Seq[Expression] = e match { diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index bc8e8c1b..03214f83 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -21,12 +21,8 @@ object RemoveValidIf extends Pass { * @note Accepts [[firrtl.ir.Type Type]] but dyanmically expects [[firrtl.ir.GroundType GroundType]] */ def getGroundZero(tpe: Type): Expression = tpe match { - case _: UIntType => UIntZero - case _: SIntType => SIntZero - case ClockType => ClockZero - case _: FixedType => FixedZero - case AsyncResetType => AsyncZero - case other => throwInternalError(s"Unexpected type $other") + case g: GroundType => Utils.getGroundZero(g) + case other => throwInternalError(s"Unexpected type $other") } override def prerequisites = firrtl.stage.Forms.LowForm diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala index 440989f4..ee03ad30 100644 --- a/src/main/scala/firrtl/transforms/RemoveWires.scala +++ b/src/main/scala/firrtl/transforms/RemoveWires.scala @@ -10,6 +10,7 @@ import firrtl.traversals.Foreachers._ import firrtl.WrappedExpression._ import firrtl.graph.{CyclicException, MutableDiGraph} import firrtl.options.Dependency +import firrtl.Utils.getGroundZero import scala.collection.mutable import scala.util.{Failure, Success, Try} @@ -128,8 +129,8 @@ class RemoveWires extends Transform with DependencyAPIMigration { case invalid @ IsInvalid(info, expr) => kind(expr) match { case WireKind => - val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl - netlist(we(expr)) = (Seq(ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe)), info) + val (tpe, width) = expr.tpe match { case g: GroundType => (g, g.width) } // LowFirrtl + netlist(we(expr)) = (Seq(ValidIf(Utils.zero, getGroundZero(tpe), tpe)), info) case _ => otherStmts += invalid } case other @ (_: Print | _: Stop | _: Attach | _: Verification) => diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala index 48eaaa65..58d42710 100644 --- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala +++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala @@ -6,6 +6,7 @@ import firrtl._ import firrtl.ir._ import firrtl.Mappers._ import firrtl.testutils._ +import FirrtlCheckers._ import collection.mutable @@ -187,4 +188,17 @@ class RemoveWiresSpec extends FirrtlFlatSpec { firrtl.passes.CheckHighForm.execute(result) } + it should "give nodes made from invalid wires the correct type" in { + val result = compileBody( + s"""|input a : SInt<4> + |input sel : UInt<1> + |output z : SInt<4> + |wire w : SInt<4> + |w is invalid + |z <= mux(sel, a, w) + |""".stripMargin + ) + result should containLine("""node w = validif(UInt<1>("h0"), SInt<4>("h0"))""") + } + } |
