diff options
| author | edwardcwang | 2021-04-11 12:13:16 -0400 |
|---|---|---|
| committer | GitHub | 2021-04-11 16:13:16 +0000 |
| commit | 344083ba3bdc30a25d8f2ecdf490749db9c36e4a (patch) | |
| tree | fd353e2927934ccdd680066c76243acae41258aa /src | |
| parent | 9a3dcf761e40b7ac36f9c867d0a36692d4d74c0c (diff) | |
smt: use existing bitWidth API (#2175)
* bitWidth: add scaladoc
* smt: use existing bitWidth API
Diffstat (limited to 'src')
3 files changed, 25 insertions, 20 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 3d0f19b8..a58b6997 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -51,7 +51,18 @@ object getWidth { def apply(e: Expression): Width = apply(e.tpe) } +/** + * Helper object for computing the width of a firrtl type. + */ object bitWidth { + + /** + * Compute the width of a firrtl type. + * For example, a Vec of 4 UInts of width 8 should have a width of 32. + * + * @param dt firrtl type + * @return Width of the given type + */ def apply(dt: Type): BigInt = widthOf(dt) private def widthOf(dt: Type): BigInt = dt match { case t: VectorType => t.size * bitWidth(t.tpe) diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala index 13e0c312..099b6712 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala @@ -8,19 +8,10 @@ import firrtl.PrimOps import firrtl.passes.CheckWidths.WidthTooBig private trait TranslationContext { - def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe)) + def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, firrtl.bitWidth(tpe).toInt) } private object FirrtlExpressionSemantics { - def getWidth(tpe: ir.Type): Int = tpe match { - case ir.UIntType(ir.IntWidth(w)) => w.toInt - case ir.SIntType(ir.IntWidth(w)) => w.toInt - case ir.ClockType => 1 - case ir.ResetType => 1 - case ir.AnalogType(ir.IntWidth(w)) => w.toInt - case other => throw new RuntimeException(s"Cannot handle type $other") - } - def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = { val eSMT = e match { case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts) @@ -183,5 +174,7 @@ private object FirrtlExpressionSemantics { case _: ir.SIntType => true case _ => false } - private def getWidth(e: ir.Expression): Int = getWidth(e.tpe) + + // Helper function + private def getWidth(e: ir.Expression): Int = firrtl.bitWidth(e.tpe).toInt } diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala index fea92c75..cfab61b9 100644 --- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala +++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala @@ -4,6 +4,7 @@ package firrtl.backends.experimental.smt import firrtl.annotations.{MemoryInitAnnotation, NoTargetAnnotation, PresetRegAnnotation} +import firrtl.bitWidth import FirrtlExpressionSemantics.getWidth import firrtl.backends.experimental.smt.random._ import firrtl.graph.MutableDiGraph @@ -258,7 +259,7 @@ private class ModuleToTransitionSystem extends LazyLogging { val inputs = connects.filter(_._1.startsWith(m.name)).toMap // derive the type of the memory from the dataType and depth - val dataWidth = getWidth(m.dataType) + val dataWidth = bitWidth(m.dataType).toInt val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1) val memSymbol = ArraySymbol(m.name, indexWidth, dataWidth) @@ -391,7 +392,7 @@ private class ModuleScanner( if (isClock(p.tpe)) { clocks.add(p.name) } else { - inputs.append(BVSymbol(p.name, getWidth(p.tpe))) + inputs.append(BVSymbol(p.name, bitWidth(p.tpe).toInt)) } case ir.Output => if (!isClock(p.tpe)) { // we ignore clock outputs @@ -406,7 +407,7 @@ private class ModuleScanner( assert(!isClock(tpe), "rand should never be a clock!") // we model random sources as inputs and ignore the enable signal infos.append(name -> info) - inputs.append(BVSymbol(name, getWidth(tpe))) + inputs.append(BVSymbol(name, bitWidth(tpe).toInt)) case ir.DefWire(info, name, tpe) => namespace.newName(name) if (!isClock(tpe)) { @@ -427,7 +428,7 @@ private class ModuleScanner( insertDummyAssignsForUnusedOutputs(reset) insertDummyAssignsForUnusedOutputs(init) infos.append(name -> info) - val width = getWidth(tpe) + val width = bitWidth(tpe).toInt val resetExpr = onExpression(reset, 1) val initExpr = onExpression(init, width) registers.append((name, width, resetExpr, initExpr)) @@ -436,7 +437,7 @@ private class ModuleScanner( infos.append(m.name -> m.info) val outputs = getMemOutputs(m) (getMemInputs(m) ++ outputs).foreach(memSignals.append(_)) - val dataWidth = getWidth(m.dataType) + val dataWidth = bitWidth(m.dataType).toInt outputs.foreach(name => unusedOutputs(name) = BVSymbol(name, dataWidth)) memories.append(m) case ir.Connect(info, loc, expr) => @@ -445,7 +446,7 @@ private class ModuleScanner( val name = loc.serialize insertDummyAssignsForUnusedOutputs(expr) infos.append(name -> info) - connects.append((name, onExpression(expr, getWidth(loc.tpe)))) + connects.append((name, onExpression(expr, bitWidth(loc.tpe).toInt))) } case i @ ir.IsInvalid(info, loc) => if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!") @@ -507,7 +508,7 @@ private class ModuleScanner( if (isClock(p.tpe)) { clocks.add(pName) } else { - inputs.append(BVSymbol(pName, getWidth(p.tpe))) + inputs.append(BVSymbol(pName, bitWidth(p.tpe).toInt)) } } else { if (!isClock(p.tpe)) { // we ignore clock outputs @@ -524,8 +525,8 @@ private class ModuleScanner( // sanity checks for ports were done already using the UninterpretedModule.checkModule function val ports = tpe.asInstanceOf[ir.BundleType].fields - val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe))) - val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, getWidth(p.tpe))) + val outputs = ports.filter(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt)) + val inputs = ports.filterNot(_.flip == ir.Default).map(p => BVSymbol(p.name, bitWidth(p.tpe).toInt)) assert(anno.stateBits == 0, "TODO: implement support for uninterpreted stateful modules!") |
