aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Utils.scala18
-rw-r--r--src/main/scala/firrtl/passes/RemoveValidIf.scala8
-rw-r--r--src/main/scala/firrtl/transforms/RemoveWires.scala5
-rw-r--r--src/test/scala/firrtlTests/RemoveWiresSpec.scala14
4 files changed, 37 insertions, 8 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index a52e451f..921ec60b 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -250,6 +250,24 @@ object Utils extends LazyLogging {
val one = UIntLiteral(1)
val zero = UIntLiteral(0)
+ private val ClockZero = DoPrim(PrimOps.AsClock, Seq(zero), Seq.empty, ClockType)
+ private val AsyncZero = DoPrim(PrimOps.AsAsyncReset, Seq(zero), Nil, AsyncResetType)
+
+ /** Returns an [[firrtl.ir.Expression Expression]] equal to zero for a given [[firrtl.ir.GroundType GroundType]]
+ * @note Does not support [[firrtl.ir.AnalogType AnalogType]] nor [[firrtl.ir.IntervalType IntervalType]]
+ */
+ def getGroundZero(tpe: GroundType): Expression = tpe match {
+ case u: UIntType => UIntLiteral(0, u.width)
+ case s: SIntType => SIntLiteral(0, s.width)
+ case f: FixedType => FixedLiteral(0, f.width, f.point)
+ // Default reset type is Bool
+ case ResetType => Utils.zero
+ case ClockType => ClockZero
+ case AsyncResetType => AsyncZero
+ // TODO Support IntervalType
+ case other => throwInternalError(s"Unexpected type $other")
+ }
+
def create_exps(n: String, t: Type): Seq[Expression] =
create_exps(WRef(n, t, ExpKind, UnknownFlow))
def create_exps(e: Expression): Seq[Expression] = e match {
diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala
index bc8e8c1b..03214f83 100644
--- a/src/main/scala/firrtl/passes/RemoveValidIf.scala
+++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala
@@ -21,12 +21,8 @@ object RemoveValidIf extends Pass {
* @note Accepts [[firrtl.ir.Type Type]] but dyanmically expects [[firrtl.ir.GroundType GroundType]]
*/
def getGroundZero(tpe: Type): Expression = tpe match {
- case _: UIntType => UIntZero
- case _: SIntType => SIntZero
- case ClockType => ClockZero
- case _: FixedType => FixedZero
- case AsyncResetType => AsyncZero
- case other => throwInternalError(s"Unexpected type $other")
+ case g: GroundType => Utils.getGroundZero(g)
+ case other => throwInternalError(s"Unexpected type $other")
}
override def prerequisites = firrtl.stage.Forms.LowForm
diff --git a/src/main/scala/firrtl/transforms/RemoveWires.scala b/src/main/scala/firrtl/transforms/RemoveWires.scala
index 440989f4..ee03ad30 100644
--- a/src/main/scala/firrtl/transforms/RemoveWires.scala
+++ b/src/main/scala/firrtl/transforms/RemoveWires.scala
@@ -10,6 +10,7 @@ import firrtl.traversals.Foreachers._
import firrtl.WrappedExpression._
import firrtl.graph.{CyclicException, MutableDiGraph}
import firrtl.options.Dependency
+import firrtl.Utils.getGroundZero
import scala.collection.mutable
import scala.util.{Failure, Success, Try}
@@ -128,8 +129,8 @@ class RemoveWires extends Transform with DependencyAPIMigration {
case invalid @ IsInvalid(info, expr) =>
kind(expr) match {
case WireKind =>
- val width = expr.tpe match { case GroundType(width) => width } // LowFirrtl
- netlist(we(expr)) = (Seq(ValidIf(Utils.zero, UIntLiteral(BigInt(0), width), expr.tpe)), info)
+ val (tpe, width) = expr.tpe match { case g: GroundType => (g, g.width) } // LowFirrtl
+ netlist(we(expr)) = (Seq(ValidIf(Utils.zero, getGroundZero(tpe), tpe)), info)
case _ => otherStmts += invalid
}
case other @ (_: Print | _: Stop | _: Attach | _: Verification) =>
diff --git a/src/test/scala/firrtlTests/RemoveWiresSpec.scala b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
index 48eaaa65..58d42710 100644
--- a/src/test/scala/firrtlTests/RemoveWiresSpec.scala
+++ b/src/test/scala/firrtlTests/RemoveWiresSpec.scala
@@ -6,6 +6,7 @@ import firrtl._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.testutils._
+import FirrtlCheckers._
import collection.mutable
@@ -187,4 +188,17 @@ class RemoveWiresSpec extends FirrtlFlatSpec {
firrtl.passes.CheckHighForm.execute(result)
}
+ it should "give nodes made from invalid wires the correct type" in {
+ val result = compileBody(
+ s"""|input a : SInt<4>
+ |input sel : UInt<1>
+ |output z : SInt<4>
+ |wire w : SInt<4>
+ |w is invalid
+ |z <= mux(sel, a, w)
+ |""".stripMargin
+ )
+ result should containLine("""node w = validif(UInt<1>("h0"), SInt<4>("h0"))""")
+ }
+
}