aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoredwardcwang2021-04-11 12:13:16 -0400
committerGitHub2021-04-11 16:13:16 +0000
commit344083ba3bdc30a25d8f2ecdf490749db9c36e4a (patch)
treefd353e2927934ccdd680066c76243acae41258aa
parent9a3dcf761e40b7ac36f9c867d0a36692d4d74c0c (diff)
smt: use existing bitWidth API (#2175)
* bitWidth: add scaladoc * smt: use existing bitWidth API
-rw-r--r--src/main/scala/firrtl/Utils.scala11
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala15
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala19
3 files changed, 25 insertions, 20 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 3d0f19b8..a58b6997 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -51,7 +51,18 @@ object getWidth {
def apply(e: Expression): Width = apply(e.tpe)
}
+/**
+ * Helper object for computing the width of a firrtl type.
+ */
object bitWidth {
+
+ /**
+ * Compute the width of a firrtl type.
+ * For example, a Vec of 4 UInts of width 8 should have a width of 32.
+ *
+ * @param dt firrtl type
+ * @return Width of the given type
+ */
def apply(dt: Type): BigInt = widthOf(dt)
private def widthOf(dt: Type): BigInt = dt match {
case t: VectorType => t.size * bitWidth(t.tpe)
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index 13e0c312..099b6712 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -8,19 +8,10 @@ import firrtl.PrimOps
import firrtl.passes.CheckWidths.WidthTooBig
private trait TranslationContext {
- def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe))
+ def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, firrtl.bitWidth(tpe).toInt)
}
private object FirrtlExpressionSemantics {
- def getWidth(tpe: ir.Type): Int = tpe match {
- case ir.UIntType(ir.IntWidth(w)) => w.toInt
- case ir.SIntType(ir.IntWidth(w)) => w.toInt
- case ir.ClockType => 1
- case ir.ResetType => 1
- case ir.AnalogType(ir.IntWidth(w)) => w.toInt
- case other => throw new RuntimeException(s"Cannot handle type $other")
- }
-
def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = {
val eSMT = e match {
case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts)
@@ -183,5 +174,7 @@ private object FirrtlExpressionSemantics {
case _: ir.SIntType => true
case _ => false
}
- private def getWidth(e: ir.Expression): Int = getWidth(e.tpe)
+
+ // Helper function
+ private def getWidth(e: ir.Expression): Int = firrtl.bitWidth(e.tpe).toInt
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index fea92c75..cfab61b9 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -4,6 +4,7 @@
package firrtl.backends.experimental.smt
import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation}
+import firrtl.bitWidth
import FirrtlExpressionSemantics.getWidth
import firrtl.backends.experimental.smt.random._
import firrtl.graph.MutableDiGraph
@@ -258,7 +259,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
val inputs = connects.filter(_._1.startsWith(m.name)).toMap
// derive the type of the memory from the dataType and depth
- val dataWidth = getWidth(m.dataType)
+ val dataWidth = bitWidth(m.dataType).toInt
val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth)
@@ -391,7 +392,7 @@ private class ModuleScanner(
if (isClock(p.tpe)) {
clocks.add(p.name)
} else {
- inputs.append(BVSymbol(p.name, getWidth(p.tpe)))
+ inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt))
}
case ir.Output =>
if (!isClock(p.tpe)) { // we ignore clock outputs
@@ -406,7 +407,7 @@ private class ModuleScanner(
assert(!isClock(tpe), "rand should never be a clock!")
// we model random sources as inputs and ignore the enable signal
infos.append(name -> info)
- inputs.append(BVSymbol(name, getWidth(tpe)))
+ inputs.append(BVSymbol(name, bitWidth(tpe).toInt))
case ir.DefWire(info, name, tpe) =>
namespace.newName(name)
if (!isClock(tpe)) {
@@ -427,7 +428,7 @@ private class ModuleScanner(
insertDummyAssignsForUnusedOutputs(reset)
insertDummyAssignsForUnusedOutputs(init)
infos.append(name -> info)
- val width = getWidth(tpe)
+ val width = bitWidth(tpe).toInt
val resetExpr = onExpression(reset, 1)
val initExpr = onExpression(init, width)
registers.append((name, width, resetExpr, initExpr))
@@ -436,7 +437,7 @@ private class ModuleScanner(
infos.append(m.name -> m.info)
val outputs = getMemOutputs(m)
(getMemInputs(m) ++ outputs).foreach(memSignals.append(_))
- val dataWidth = getWidth(m.dataType)
+ val dataWidth = bitWidth(m.dataType).toInt
outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth))
memories.append(m)
case ir.Connect(info, loc, expr) =>
@@ -445,7 +446,7 @@ private class ModuleScanner(
val name = loc.serialize
insertDummyAssignsForUnusedOutputs(expr)
infos.append(name -> info)
- connects.append((name, onExpression(expr, getWidth(loc.tpe))))
+ connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt)))
}
case i @ ir.IsInvalid(info, loc) =>
if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
@@ -507,7 +508,7 @@ private class ModuleScanner(
if (isClock(p.tpe)) {
clocks.add(pName)
} else {
- inputs.append(BVSymbol(pName, getWidth(p.tpe)))
+ inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt))
}
} else {
if (!isClock(p.tpe)) { // we ignore clock outputs
@@ -524,8 +525,8 @@ private class ModuleScanner(
// sanity checks for ports were done already using the UninterpretedModule.checkModule function
val ports = tpe.asInstanceOf[ir.BundleType].fields
- val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe)))
- val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe)))
+ val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt))
+ val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt))
assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!")