aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/backends
diff options
context:
space:
mode:
authorchick2020-08-14 19:47:53 -0700
committerJack Koenig2020-08-14 19:47:53 -0700
commit6fc742bfaf5ee508a34189400a1a7dbffe3f1cac (patch)
tree2ed103ee80b0fba613c88a66af854ae9952610ce /src/main/scala/firrtl/backends
parentb516293f703c4de86397862fee1897aded2ae140 (diff)
All of src/ formatted with scalafmt
Diffstat (limited to 'src/main/scala/firrtl/backends')
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala82
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala86
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala316
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala29
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala89
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala49
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala131
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala20
-rw-r--r--src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala83
9 files changed, 502 insertions, 383 deletions
diff --git a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
index f7ab9927..66690f56 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/Btor2Serializer.scala
@@ -26,7 +26,8 @@ private class Btor2Serializer private () {
private def comment(c: String): Unit = { lines += s"; $c" }
private def trailingComment(c: String): Unit = {
val lastLine = lines.last
- val newLine = if(lastLine.contains(';')) { lastLine + " " + c} else { lastLine + " ; " + c }
+ val newLine = if (lastLine.contains(';')) { lastLine + " " + c }
+ else { lastLine + " ; " + c }
lines(lines.size - 1) = newLine
}
@@ -38,54 +39,55 @@ private class Btor2Serializer private () {
// bit vector expression serialization
private def s(expr: BVExpr): Int = expr match {
case BVLiteral(value, width) => lit(value, width)
- case BVSymbol(name, _) => symbols(name)
- case BVExtend(e, 0, _) => s(e)
- case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
- case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
+ case BVSymbol(name, _) => symbols(name)
+ case BVExtend(e, 0, _) => s(e)
+ case BVExtend(e, by, true) => line(s"sext ${t(expr.width)} ${s(e)} $by")
+ case BVExtend(e, by, false) => line(s"uext ${t(expr.width)} ${s(e)} $by")
case BVSlice(e, hi, lo) =>
- if (lo == 0 && hi == e.width - 1) { s(e) } else {
+ if (lo == 0 && hi == e.width - 1) { s(e) }
+ else {
line(s"slice ${t(expr.width)} ${s(e)} $hi $lo")
}
- case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b)
- case BVNot(BVNot(e)) => s(e)
- case BVNot(e) => unary("not", expr.width, e)
- case BVNegate(e) => unary("neg", expr.width, e)
- case BVReduceAnd(e) => unary("redand", expr.width, e)
- case BVReduceOr(e) => unary("redor", expr.width, e)
- case BVReduceXor(e) => unary("redxor", expr.width, e)
- case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b)
- case BVImplies(a, b) => binary("implies", expr.width, a, b)
- case BVEqual(a, b) => binary("eq", expr.width, a, b)
- case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}")
- case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b)
+ case BVNot(BVEqual(a, b)) => binary("neq", expr.width, a, b)
+ case BVNot(BVNot(e)) => s(e)
+ case BVNot(e) => unary("not", expr.width, e)
+ case BVNegate(e) => unary("neg", expr.width, e)
+ case BVReduceAnd(e) => unary("redand", expr.width, e)
+ case BVReduceOr(e) => unary("redor", expr.width, e)
+ case BVReduceXor(e) => unary("redxor", expr.width, e)
+ case BVImplies(BVLiteral(v, 1), b) if v == 1 => s(b)
+ case BVImplies(a, b) => binary("implies", expr.width, a, b)
+ case BVEqual(a, b) => binary("eq", expr.width, a, b)
+ case ArrayEqual(a, b) => line(s"eq ${t(expr.width)} ${s(a)} ${s(b)}")
+ case BVComparison(Compare.Greater, a, b, false) => binary("ugt", expr.width, a, b)
case BVComparison(Compare.GreaterEqual, a, b, false) => binary("ugte", expr.width, a, b)
- case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b)
- case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
- case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
- case BVConcat(a, b) => binary("concat", expr.width, a, b)
+ case BVComparison(Compare.Greater, a, b, true) => binary("sgt", expr.width, a, b)
+ case BVComparison(Compare.GreaterEqual, a, b, true) => binary("sgte", expr.width, a, b)
+ case BVOp(op, a, b) => binary(s(op), expr.width, a, b)
+ case BVConcat(a, b) => binary("concat", expr.width, a, b)
case ArrayRead(array, index) =>
line(s"read ${t(expr.width)} ${s(array)} ${s(index)}")
case BVIte(cond, tru, fals) =>
line(s"ite ${t(expr.width)} ${s(cond)} ${s(tru)} ${s(fals)}")
- case r : BVRawExpr =>
+ case r: BVRawExpr =>
throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
}
private def s(op: Op.Value): String = op match {
- case Op.And => "and"
- case Op.Or => "or"
- case Op.Xor => "xor"
+ case Op.And => "and"
+ case Op.Or => "or"
+ case Op.Xor => "xor"
case Op.ArithmeticShiftRight => "sra"
- case Op.ShiftRight => "srl"
- case Op.ShiftLeft => "sll"
- case Op.Add => "add"
- case Op.Mul => "mul"
- case Op.Sub => "sub"
- case Op.SignedDiv => "sdiv"
- case Op.UnsignedDiv => "udiv"
- case Op.SignedMod => "smod"
- case Op.SignedRem => "srem"
- case Op.UnsignedRem => "urem"
+ case Op.ShiftRight => "srl"
+ case Op.ShiftLeft => "sll"
+ case Op.Add => "add"
+ case Op.Mul => "mul"
+ case Op.Sub => "sub"
+ case Op.SignedDiv => "sdiv"
+ case Op.UnsignedDiv => "udiv"
+ case Op.SignedMod => "smod"
+ case Op.SignedRem => "srem"
+ case Op.UnsignedRem => "urem"
}
private def unary(op: String, width: Int, e: BVExpr): Int = line(s"$op ${t(width)} ${s(e)}")
@@ -123,18 +125,18 @@ private class Btor2Serializer private () {
// It is essential to model memories, so any support in the wild should be fairly well tested.
line(s"ite ${t(expr.indexWidth, expr.dataWidth)} ${s(cond)} ${s(tru)} ${s(fals)}")
case ArrayConstant(e, _) => s(e)
- case r : ArrayRawExpr =>
+ case r: ArrayRawExpr =>
throw new RuntimeException(s"Raw expressions should never reach the btor2 encoder!: ${r.serialized}")
}
private def s(expr: SMTExpr): Int = expr match {
- case b: BVExpr => s(b)
+ case b: BVExpr => s(b)
case a: ArrayExpr => s(a)
}
// serialize the type of the expression
private def t(expr: SMTExpr): Int = expr match {
- case b: BVExpr => t(b.width)
+ case b: BVExpr => t(b.width)
case a: ArrayExpr => t(a.indexWidth, a.dataWidth)
}
@@ -145,7 +147,7 @@ private class Btor2Serializer private () {
symbols(name) = id
if (!skipOutput && sys.outputs.contains(name)) line(s"output $id ; $name")
if (sys.assumes.contains(name)) line(s"constraint $id ; $name")
- if (sys.asserts.contains(name)){
+ if (sys.asserts.contains(name)) {
val invertedId = line(s"not ${t(1)} $id")
line(s"bad $invertedId ; $name")
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
index 0a223840..efa89687 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlExpressionSemantics.scala
@@ -9,26 +9,26 @@ import firrtl.passes.CheckWidths.WidthTooBig
private trait TranslationContext {
def getReference(name: String, tpe: ir.Type): BVExpr = BVSymbol(name, FirrtlExpressionSemantics.getWidth(tpe))
- def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe))
- def getRandom(width: Int): BVExpr
+ def getRandom(tpe: ir.Type): BVExpr = getRandom(FirrtlExpressionSemantics.getWidth(tpe))
+ def getRandom(width: Int): BVExpr
}
private object FirrtlExpressionSemantics {
def getWidth(tpe: ir.Type): Int = tpe match {
- case ir.UIntType(ir.IntWidth(w)) => w.toInt
- case ir.SIntType(ir.IntWidth(w)) => w.toInt
- case ir.ClockType => 1
- case ir.ResetType => 1
+ case ir.UIntType(ir.IntWidth(w)) => w.toInt
+ case ir.SIntType(ir.IntWidth(w)) => w.toInt
+ case ir.ClockType => 1
+ case ir.ResetType => 1
case ir.AnalogType(ir.IntWidth(w)) => w.toInt
- case other => throw new RuntimeException(s"Cannot handle type $other")
+ case other => throw new RuntimeException(s"Cannot handle type $other")
}
def toSMT(e: ir.Expression)(implicit ctx: TranslationContext): BVExpr = {
val eSMT = e match {
case ir.DoPrim(op, args, consts, _) => onPrim(op, args, consts)
- case r : ir.Reference => ctx.getReference(r.serialize, r.tpe)
- case r : ir.SubField => ctx.getReference(r.serialize, r.tpe)
- case r : ir.SubIndex => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.Reference => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.SubField => ctx.getReference(r.serialize, r.tpe)
+ case r: ir.SubIndex => ctx.getReference(r.serialize, r.tpe)
case ir.UIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
case ir.SIntLiteral(value, ir.IntWidth(width)) => BVLiteral(value, width.toInt)
case ir.Mux(cond, tval, fval, _) =>
@@ -38,7 +38,10 @@ private object FirrtlExpressionSemantics {
val tru = toSMT(value)
BVIte(toSMT(cond), tru, ctx.getRandom(tpe))
}
- assert(eSMT.width == getWidth(e), "We aim to always produce a SMT expression of the same width as the firrtl expression.")
+ assert(
+ eSMT.width == getWidth(e),
+ "We aim to always produce a SMT expression of the same width as the firrtl expression."
+ )
eSMT
}
@@ -47,8 +50,8 @@ private object FirrtlExpressionSemantics {
forceWidth(toSMT(e), isSigned(e), width, allowNarrow)
private def forceWidth(eSMT: BVExpr, eSigned: Boolean, width: Int, allowNarrow: Boolean = false): BVExpr = {
- if(eSMT.width == width) { eSMT }
- else if(width < eSMT.width) {
+ if (eSMT.width == width) { eSMT }
+ else if (width < eSMT.width) {
assert(allowNarrow, s"Narrowing from ${eSMT.width} bits to $width bits is not allowed!")
BVSlice(eSMT, width - 1, 0)
} else {
@@ -57,8 +60,13 @@ private object FirrtlExpressionSemantics {
}
// see "Primitive Operations" section in the Firrtl Specification
- private def onPrim(op: ir.PrimOp, args: Seq[ir.Expression], consts: Seq[BigInt])(implicit ctx: TranslationContext):
- BVExpr = {
+ private def onPrim(
+ op: ir.PrimOp,
+ args: Seq[ir.Expression],
+ consts: Seq[BigInt]
+ )(
+ implicit ctx: TranslationContext
+ ): BVExpr = {
(op, args, consts) match {
case (PrimOps.Add, Seq(e1, e2), _) =>
val width = args.map(getWidth).max + 1
@@ -70,7 +78,7 @@ private object FirrtlExpressionSemantics {
val width = args.map(getWidth).sum
BVOp(Op.Mul, toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Div, Seq(num, den), _) =>
- val (width, op) = if(isSigned(num)) {
+ val (width, op) = if (isSigned(num)) {
(getWidth(num) + 1, Op.SignedDiv)
} else { (getWidth(num), Op.UnsignedDiv) }
// "The result of a division where den is zero is undefined."
@@ -83,11 +91,12 @@ private object FirrtlExpressionSemantics {
val width = getWidth(num) + 1
BVOp(Op.SignedDiv, toSMT(num, width), toSMT(den, width))
case (PrimOps.Rem, Seq(num, den), _) =>
- val op = if(isSigned(num)) Op.SignedRem else Op.UnsignedRem
+ val op = if (isSigned(num)) Op.SignedRem else Op.UnsignedRem
val width = args.map(getWidth).max
val resWidth = args.map(getWidth).min
val res = BVOp(op, toSMT(num, width), toSMT(den, width))
- if(res.width > resWidth) { BVSlice(res, resWidth - 1, 0) } else { res }
+ if (res.width > resWidth) { BVSlice(res, resWidth - 1, 0) }
+ else { res }
case (PrimOps.Lt, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
BVNot(BVComparison(Compare.GreaterEqual, toSMT(e1, width), toSMT(e2, width), isSigned(e1)))
@@ -108,25 +117,29 @@ private object FirrtlExpressionSemantics {
BVNot(BVEqual(toSMT(e1, width), toSMT(e2, width)))
case (PrimOps.Pad, Seq(e), Seq(n)) =>
val width = getWidth(e)
- if(n <= width) { toSMT(e) } else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) }
- case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e) ; toSMT(e)
- case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e) ; toSMT(e)
+ if (n <= width) { toSMT(e) }
+ else { BVExtend(toSMT(e), n.toInt - width, isSigned(e)) }
+ case (PrimOps.AsUInt, Seq(e), _) => checkForClockInCast(PrimOps.AsUInt, e); toSMT(e)
+ case (PrimOps.AsSInt, Seq(e), _) => checkForClockInCast(PrimOps.AsSInt, e); toSMT(e)
case (PrimOps.AsFixedPoint, Seq(e), _) => throw new AssertionError("Fixed-Point numbers need to be lowered!")
- case (PrimOps.AsClock, Seq(e), _) => toSMT(e)
+ case (PrimOps.AsClock, Seq(e), _) => toSMT(e)
case (PrimOps.AsAsyncReset, Seq(e), _) =>
checkForClockInCast(PrimOps.AsAsyncReset, e)
throw new AssertionError(s"Asynchronous resets are not supported! Cannot cast ${e.serialize}.")
- case (PrimOps.Shl, Seq(e), Seq(n)) => if(n == 0) { toSMT(e) } else {
- val zeros = BVLiteral(0, n.toInt)
- BVConcat(toSMT(e), zeros)
- }
+ case (PrimOps.Shl, Seq(e), Seq(n)) =>
+ if (n == 0) { toSMT(e) }
+ else {
+ val zeros = BVLiteral(0, n.toInt)
+ BVConcat(toSMT(e), zeros)
+ }
case (PrimOps.Shr, Seq(e), Seq(n)) =>
val width = getWidth(e)
// "If n is greater than or equal to the bit-width of e,
// the resulting value will be zero for unsigned types
// and the sign bit for signed types"
- if(n >= width) {
- if(isSigned(e)) { BV1BitZero } else { BVSlice(toSMT(e), width - 1, width - 1) }
+ if (n >= width) {
+ if (isSigned(e)) { BV1BitZero }
+ else { BVSlice(toSMT(e), width - 1, width - 1) }
} else {
BVSlice(toSMT(e), width - 1, n.toInt)
}
@@ -135,9 +148,11 @@ private object FirrtlExpressionSemantics {
BVOp(Op.ShiftLeft, toSMT(e1, width), toSMT(e2, width))
case (PrimOps.Dshr, Seq(e1, e2), _) =>
val width = getWidth(e1)
- val o = if(isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight
+ val o = if (isSigned(e1)) Op.ArithmeticShiftRight else Op.ShiftRight
BVOp(o, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Cvt, Seq(e), _) => if(isSigned(e)) { toSMT(e) } else { BVConcat(BV1BitZero, toSMT(e)) }
+ case (PrimOps.Cvt, Seq(e), _) =>
+ if (isSigned(e)) { toSMT(e) }
+ else { BVConcat(BV1BitZero, toSMT(e)) }
case (PrimOps.Neg, Seq(e), _) => BVNegate(BVExtend(toSMT(e), 1, isSigned(e)))
case (PrimOps.Not, Seq(e), _) => BVNot(toSMT(e))
case (PrimOps.And, Seq(e1, e2), _) =>
@@ -149,10 +164,10 @@ private object FirrtlExpressionSemantics {
case (PrimOps.Xor, Seq(e1, e2), _) =>
val width = args.map(getWidth).max
BVOp(Op.Xor, toSMT(e1, width), toSMT(e2, width))
- case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e))
- case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e))
- case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e))
- case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2))
+ case (PrimOps.Andr, Seq(e), _) => BVReduceAnd(toSMT(e))
+ case (PrimOps.Orr, Seq(e), _) => BVReduceOr(toSMT(e))
+ case (PrimOps.Xorr, Seq(e), _) => BVReduceXor(toSMT(e))
+ case (PrimOps.Cat, Seq(e1, e2), _) => BVConcat(toSMT(e1), toSMT(e2))
case (PrimOps.Bits, Seq(e), Seq(hi, lo)) => BVSlice(toSMT(e), hi.toInt, lo.toInt)
case (PrimOps.Head, Seq(e), Seq(n)) =>
val width = getWidth(e)
@@ -167,7 +182,8 @@ private object FirrtlExpressionSemantics {
}
/** For now we strictly forbid casting clocks to anything else.
- * Eventually this should be replaced by a more sophisticated clock analysis pass. */
+ * Eventually this should be replaced by a more sophisticated clock analysis pass.
+ */
private def checkForClockInCast(cast: ir.PrimOp, signal: ir.Expression): Unit = {
assert(signal.tpe != ir.ClockType, s"Cannot cast (${cast.serialize}) clock expression ${signal.serialize}!")
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
index b3a2ff17..0888b062 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/FirrtlToTransitionSystem.scala
@@ -11,7 +11,16 @@ import firrtl.passes.PassException
import firrtl.stage.Forms
import firrtl.stage.TransformManager.TransformDependency
import firrtl.transforms.PropagatePresetAnnotations
-import firrtl.{CircuitState, DependencyAPIMigration, MemoryArrayInit, MemoryInitValue, MemoryScalarInit, Transform, Utils, ir}
+import firrtl.{
+ ir,
+ CircuitState,
+ DependencyAPIMigration,
+ MemoryArrayInit,
+ MemoryInitValue,
+ MemoryScalarInit,
+ Transform,
+ Utils
+}
import logger.LazyLogging
import scala.collection.mutable
@@ -22,15 +31,21 @@ import scala.collection.mutable
private case class State(sym: SMTSymbol, init: Option[SMTExpr], next: Option[SMTExpr])
private case class Signal(name: String, e: BVExpr) { def toSymbol: BVSymbol = BVSymbol(name, e.width) }
private case class TransitionSystem(
- name: String, inputs: Array[BVSymbol], states: Array[State], signals: Array[Signal],
- outputs: Set[String], assumes: Set[String], asserts: Set[String], fair: Set[String],
- comments: Map[String, String] = Map(), header: Array[String] = Array()) {
+ name: String,
+ inputs: Array[BVSymbol],
+ states: Array[State],
+ signals: Array[Signal],
+ outputs: Set[String],
+ assumes: Set[String],
+ asserts: Set[String],
+ fair: Set[String],
+ comments: Map[String, String] = Map(),
+ header: Array[String] = Array()) {
def serialize: String = {
(Iterator(name) ++
inputs.map(i => s"input ${i.name} : ${SMTExpr.serializeType(i)}") ++
signals.map(s => s"${s.name} : ${SMTExpr.serializeType(s.e)} = ${s.e}") ++
- states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")
- ).mkString("\n")
+ states.map(s => s"state ${s.sym} = [init] ${s.init} [next] ${s.next}")).mkString("\n")
}
}
@@ -53,26 +68,30 @@ object FirrtlToTransitionSystem extends Transform with DependencyAPIMigration {
// run the preset pass to extract all preset registers and remove preset reset signals
val afterPreset = presetPass.execute(state)
val circuit = afterPreset.circuit
- val presetRegs = afterPreset.annotations
- .collect { case PresetRegAnnotation(target) if target.module == circuit.main => target.ref }.toSet
+ val presetRegs = afterPreset.annotations.collect {
+ case PresetRegAnnotation(target) if target.module == circuit.main => target.ref
+ }.toSet
// collect all non-random memory initialization
val memInit = afterPreset.annotations.collect { case a: MemoryInitAnnotation if !a.isRandomInit => a }
- .filter(_.target.module == circuit.main).map(a => a.target.ref -> a.initValue).toMap
+ .filter(_.target.module == circuit.main)
+ .map(a => a.target.ref -> a.initValue)
+ .toMap
// convert the main module
val main = circuit.modules.find(_.name == circuit.main).get
val sys = main match {
case x: ir.ExtModule =>
throw new ExtModuleException(
- "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog.")
+ "External modules are not supported by the SMT backend. Use yosys if you need to convert Verilog."
+ )
case m: ir.Module =>
- new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit=memInit)
+ new ModuleToTransitionSystem().run(m, presetRegs = presetRegs, memInit = memInit)
}
val sortedSys = TopologicalSort.run(sys)
val anno = TransitionSystemAnnotation(sortedSys)
- state.copy(circuit=circuit, annotations = afterPreset.annotations :+ anno )
+ state.copy(circuit = circuit, annotations = afterPreset.annotations :+ anno)
}
}
@@ -94,18 +113,23 @@ private object UnsupportedException {
}
private class ExtModuleException(s: String) extends PassException(s)
-private class AsyncResetException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering)
-private class MultiClockException(s: String) extends PassException(s+UnsupportedException.HowToRunStuttering)
-private class MissingFeatureException(s: String) extends PassException("Unfortunately the SMT backend does not yet support: " + s)
+private class AsyncResetException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering)
+private class MultiClockException(s: String) extends PassException(s + UnsupportedException.HowToRunStuttering)
+private class MissingFeatureException(s: String)
+ extends PassException("Unfortunately the SMT backend does not yet support: " + s)
private class ModuleToTransitionSystem extends LazyLogging {
- def run(m: ir.Module, presetRegs: Set[String] = Set(), memInit: Map[String, MemoryInitValue] = Map()): TransitionSystem = {
+ def run(
+ m: ir.Module,
+ presetRegs: Set[String] = Set(),
+ memInit: Map[String, MemoryInitValue] = Map()
+ ): TransitionSystem = {
// first pass over the module to convert expressions; discover state and I/O
val scan = new ModuleScanner(makeRandom)
m.foreachPort(scan.onPort)
// multi-clock support requires the StutteringClock transform to be run
- if(scan.clocks.size > 1) {
+ if (scan.clocks.size > 1) {
throw new MultiClockException(s"The module ${m.name} has more than one clock: ${scan.clocks.mkString(", ")}")
}
m.foreachStmt(scan.onStatement)
@@ -115,14 +139,16 @@ private class ModuleToTransitionSystem extends LazyLogging {
val constraints = scan.assumes.toSet
val bad = scan.asserts.toSet
val isSignal = (scan.wires ++ scan.nodes ++ scan.memSignals).toSet ++ outputs ++ constraints ++ bad
- val signals = scan.connects.filter{ case(name, _) => isSignal.contains(name) }
- .map { case (name, expr) => Signal(name, expr) }
+ val signals = scan.connects.filter { case (name, _) => isSignal.contains(name) }.map {
+ case (name, expr) => Signal(name, expr)
+ }
// turn registers and memories into states
val registers = scan.registers.map(r => r._1 -> r).toMap
- val regStates = scan.connects.filter(s => registers.contains(s._1)).map { case (name, nextExpr) =>
- val (_, width, resetExpr, initExpr) = registers(name)
- onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
+ val regStates = scan.connects.filter(s => registers.contains(s._1)).map {
+ case (name, nextExpr) =>
+ val (_, width, resetExpr, initExpr) = registers(name)
+ onRegister(name, width, resetExpr, initExpr, nextExpr, presetRegs)
}
// turn memories into state
val memoryEncoding = new MemoryEncoding(makeRandom)
@@ -135,16 +161,22 @@ private class ModuleToTransitionSystem extends LazyLogging {
} else { s }
}
// filter out any left-over self assignments (this happens when we have a registered read port)
- .filter(s => s match { case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false case _ => true })
+ .filter(s =>
+ s match {
+ case Signal(n0, BVSymbol(n1, _)) if n0 == n1 => false
+ case _ => true
+ }
+ )
val states = regStates.toArray ++ memoryStatesAndOutputs.flatMap(_._1)
// generate comments from infos
val comments = mutable.HashMap[String, String]()
- scan.infos.foreach { case (name, info) =>
- serializeInfo(info).foreach { infoString =>
- if(comments.contains(name)) { comments(name) += InfoSeparator + infoString }
- else { comments(name) = InfoPrefix + infoString }
- }
+ scan.infos.foreach {
+ case (name, info) =>
+ serializeInfo(info).foreach { infoString =>
+ if (comments.contains(name)) { comments(name) += InfoSeparator + infoString }
+ else { comments(name) = InfoPrefix + infoString }
+ }
}
// inputs are original module inputs and any "random" signal we need for modelling
@@ -154,11 +186,28 @@ private class ModuleToTransitionSystem extends LazyLogging {
val header = serializeInfo(m.info).map(InfoPrefix + _).toArray
val fair = Set[String]() // as of firrtl 1.4 we do not support fairness constraints
- TransitionSystem(m.name, inputs.toArray, states, signalsWithMem.toArray, outputs, constraints, bad, fair, comments.toMap, header)
+ TransitionSystem(
+ m.name,
+ inputs.toArray,
+ states,
+ signalsWithMem.toArray,
+ outputs,
+ constraints,
+ bad,
+ fair,
+ comments.toMap,
+ header
+ )
}
- private def onRegister(name: String, width: Int, resetExpr: BVExpr, initExpr: BVExpr,
- nextExpr: BVExpr, presetRegs: Set[String]): State = {
+ private def onRegister(
+ name: String,
+ width: Int,
+ resetExpr: BVExpr,
+ initExpr: BVExpr,
+ nextExpr: BVExpr,
+ presetRegs: Set[String]
+ ): State = {
assert(initExpr.width == width)
assert(nextExpr.width == width)
assert(resetExpr.width == 1)
@@ -166,9 +215,9 @@ private class ModuleToTransitionSystem extends LazyLogging {
val hasReset = initExpr != sym
val isPreset = presetRegs.contains(name)
assert(!isPreset || hasReset, s"Expected preset register $name to have a reset value, not just $initExpr!")
- if(hasReset) {
- val init = if(isPreset) Some(initExpr) else None
- val next = if(isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr)
+ if (hasReset) {
+ val init = if (isPreset) Some(initExpr) else None
+ val next = if (isPreset) nextExpr else BVIte(resetExpr, initExpr, nextExpr)
State(sym, next = Some(next), init = init)
} else {
State(sym, next = Some(nextExpr), init = None)
@@ -179,10 +228,11 @@ private class ModuleToTransitionSystem extends LazyLogging {
private val InfoPrefix = "@ "
private def serializeInfo(info: ir.Info): Option[String] = info match {
case ir.NoInfo => None
- case f : ir.FileInfo => Some(f.escaped)
- case m : ir.MultiInfo =>
+ case f: ir.FileInfo => Some(f.escaped)
+ case m: ir.MultiInfo =>
val infos = m.flatten
- if(infos.isEmpty) { None } else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
+ if (infos.isEmpty) { None }
+ else { Some(infos.map(_.escaped).mkString(InfoSeparator)) }
}
private[firrtl] val randoms = mutable.LinkedHashMap[String, BVSymbol]()
@@ -190,7 +240,7 @@ private class ModuleToTransitionSystem extends LazyLogging {
// TODO: actually ensure that there cannot be any name clashes with other identifiers
val suffixes = Iterator(baseName) ++ (0 until 200).map(ii => baseName + "_" + ii)
val name = suffixes.map(s => "RANDOM." + s).find(!randoms.contains(_)).get
- val sym = BVSymbol(name, width)
+ val sym = BVSymbol(name, width)
randoms(name) = sym
sym
}
@@ -198,10 +248,16 @@ private class ModuleToTransitionSystem extends LazyLogging {
private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLogging {
type Connects = Iterable[(String, BVExpr)]
- def onMemory(defMem: ir.DefMemory, connects: Connects, initValue: Option[MemoryInitValue]): (Iterable[State], Connects) = {
+ def onMemory(
+ defMem: ir.DefMemory,
+ connects: Connects,
+ initValue: Option[MemoryInitValue]
+ ): (Iterable[State], Connects) = {
// we can only work on appropriately lowered memories
- assert(defMem.dataType.isInstanceOf[ir.GroundType],
- s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!")
+ assert(
+ defMem.dataType.isInstanceOf[ir.GroundType],
+ s"Memory $defMem is of type ${defMem.dataType} which is not a ground type!"
+ )
assert(defMem.readwriters.isEmpty, "Combined read/write ports are not supported! Please split them up.")
// collect all memory meta-data in a custom class
@@ -214,17 +270,19 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val init = initValue.map(getInit(m, _))
// parse and check read and write ports
- val writers = defMem.writers.map( w => new WritePort(m, w, inputs))
- val readers = defMem.readers.map( r => new ReadPort(m, r, inputs))
+ val writers = defMem.writers.map(w => new WritePort(m, w, inputs))
+ val readers = defMem.readers.map(r => new ReadPort(m, r, inputs))
// derive next state from all write ports
assert(defMem.writeLatency == 1, "Only memories with write-latency of one are supported.")
- val next: ArrayExpr = if(writers.isEmpty) { m.sym } else {
- if(writers.length > 2) {
+ val next: ArrayExpr = if (writers.isEmpty) { m.sym }
+ else {
+ if (writers.length > 2) {
throw new UnsupportedFeatureException(s"memories with 3+ write ports (${m.name})")
}
val validData = writers.foldLeft[ArrayExpr](m.sym) { case (sym, w) => w.writeTo(sym) }
- if(writers.length == 1) { validData } else {
+ if (writers.length == 1) { validData }
+ else {
assert(writers.length == 2)
val conflict = writers.head.doesConflict(writers.last)
val conflictData = writers.head.makeRandomData("_write_write_collision")
@@ -236,13 +294,13 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
// derive data signals from all read ports
assert(defMem.readLatency >= 0)
- if(defMem.readLatency > 1) {
+ if (defMem.readLatency > 1) {
throw new UnsupportedFeatureException(s"memories with read latency 2+ (${m.name})")
}
- val readPortSignals = if(defMem.readLatency == 0) {
+ val readPortSignals = if (defMem.readLatency == 0) {
readers.map { r =>
// combinatorial read
- if(defMem.readUnderWrite != ir.ReadUnderWrite.New) {
+ if (defMem.readUnderWrite != ir.ReadUnderWrite.New) {
//logger.warn(s"WARN: Memory ${m.name} with combinatorial read port will always return the most recently written entry." +
// s" The read-under-write => ${defMem.readUnderWrite} setting will be ignored.")
}
@@ -251,22 +309,25 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
r.data.name -> data
}
} else { Seq() }
- val readPortStates = if(defMem.readLatency == 1) {
+ val readPortStates = if (defMem.readLatency == 1) {
readers.map { r =>
// we create a register for the read port data
val next = defMem.readUnderWrite match {
case ir.ReadUnderWrite.New =>
- throw new UnsupportedFeatureException(s"registered read ports that return the new value (${m.name}.${r.name})")
- // the thing that makes this hard is to properly handle write conflicts
+ throw new UnsupportedFeatureException(
+ s"registered read ports that return the new value (${m.name}.${r.name})"
+ )
+ // the thing that makes this hard is to properly handle write conflicts
case ir.ReadUnderWrite.Undefined =>
val anyWriteToTheSameAddress = any(writers.map(_.doesConflict(r)))
- if(anyWriteToTheSameAddress == False) { r.readOld() } else {
+ if (anyWriteToTheSameAddress == False) { r.readOld() }
+ else {
val readUnderWriteData = r.makeRandomData("_read_under_write_undefined")
BVIte(anyWriteToTheSameAddress, readUnderWriteData, r.readOld())
}
case ir.ReadUnderWrite.Old => r.readOld()
}
- State(r.data, init=None, next=Some(next))
+ State(r.data, init = None, next = Some(next))
}
} else { Seq() }
@@ -276,16 +337,20 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
private def getInit(m: MemInfo, initValue: MemoryInitValue): ArrayExpr = initValue match {
case MemoryScalarInit(value) => ArrayConstant(BVLiteral(value, m.dataWidth), m.indexWidth)
case MemoryArrayInit(values) =>
- assert(values.length == m.depth,
- s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!")
+ assert(
+ values.length == m.depth,
+ s"Memory ${m.name} of depth ${m.depth} cannot be initialized with an array of length ${values.length}!"
+ )
// in order to get a more compact encoding try to find the most common values
val histogram = mutable.LinkedHashMap[BigInt, Int]()
values.foreach(v => histogram(v) = 1 + histogram.getOrElse(v, 0))
val baseValue = histogram.maxBy(_._2)._1
val base = ArrayConstant(BVLiteral(baseValue, m.dataWidth), m.indexWidth)
- values.zipWithIndex.filterNot(_._1 == baseValue)
- .foldLeft[ArrayExpr](base) { case (array, (value, index)) =>
- ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
+ values.zipWithIndex
+ .filterNot(_._1 == baseValue)
+ .foldLeft[ArrayExpr](base) {
+ case (array, (value, index)) =>
+ ArrayStore(array, BVLiteral(index, m.indexWidth), BVLiteral(value, m.dataWidth))
}
case other => throw new RuntimeException(s"Unsupported memory init option: $other")
}
@@ -295,19 +360,20 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val depth = m.depth
// derrive the type of the memory from the dataType and depth
val dataWidth = getWidth(m.dataType)
- val indexWidth = Utils.getUIntWidth(m.depth - 1) max 1
+ val indexWidth = Utils.getUIntWidth(m.depth - 1).max(1)
val sym = ArraySymbol(m.name, indexWidth, dataWidth)
val prefix = m.name + "."
val fullAddressRange = (BigInt(1) << indexWidth) == m.depth
lazy val depthBV = BVLiteral(m.depth, indexWidth)
def isValidAddress(addr: BVExpr): BVExpr = {
- if(fullAddressRange) { True } else {
+ if (fullAddressRange) { True }
+ else {
BVComparison(Compare.Greater, depthBV, addr, signed = false)
}
}
}
private abstract class MemPort(memory: MemInfo, val name: String, inputs: String => BVExpr) {
- val en: BVSymbol = makeField("en", 1)
+ val en: BVSymbol = makeField("en", 1)
val data: BVSymbol = makeField("data", memory.dataWidth)
val addr: BVSymbol = makeField("addr", memory.indexWidth)
protected def makeField(field: String, width: Int): BVSymbol = BVSymbol(memory.prefix + name + "." + field, width)
@@ -321,11 +387,11 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
val canBeOutOfRange = !memory.fullAddressRange
val canBeDisabled = !enIsTrue
val data = ArrayRead(memory.sym, addr)
- val dataWithRangeCheck = if(canBeOutOfRange) {
+ val dataWithRangeCheck = if (canBeOutOfRange) {
val outOfRangeData = makeRandomData("_addr_out_of_range")
BVIte(memory.isValidAddress(addr), data, outOfRangeData)
} else { data }
- val dataWithEnabledCheck = if(canBeDisabled) {
+ val dataWithEnabledCheck = if (canBeDisabled) {
val disabledData = makeRandomData("_not_enabled")
BVIte(en, dataWithRangeCheck, disabledData)
} else { dataWithRangeCheck }
@@ -333,48 +399,49 @@ private class MemoryEncoding(makeRandom: (String, Int) => BVExpr) extends LazyLo
}
}
private class WritePort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {
+ extends MemPort(memory, name, inputs) {
assert(inputs(data.name).width == data.width)
val mask: BVSymbol = makeField("mask", 1)
assert(inputs(mask.name).width == mask.width)
val maskIsTrue: Boolean = inputs(mask.name) == True
val doWrite: BVExpr = (enIsTrue, maskIsTrue) match {
- case (true, true) => True
- case (true, false) => mask
- case (false, true) => en
+ case (true, true) => True
+ case (true, false) => mask
+ case (false, true) => en
case (false, false) => and(en, mask)
}
def doesConflict(r: ReadPort): BVExpr = {
val sameAddress = BVEqual(r.addr, addr)
- if(doWrite == True) { sameAddress } else { and(doWrite, sameAddress) }
+ if (doWrite == True) { sameAddress }
+ else { and(doWrite, sameAddress) }
}
def doesConflict(w: WritePort): BVExpr = {
val bothWrite = and(doWrite, w.doWrite)
val sameAddress = BVEqual(addr, w.addr)
- if(bothWrite == True) { sameAddress } else { and(doWrite, sameAddress) }
+ if (bothWrite == True) { sameAddress }
+ else { and(doWrite, sameAddress) }
}
def writeTo(array: ArrayExpr): ArrayExpr = {
- val doUpdate = if(memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
- val update = ArrayStore(array, index=addr, data=data)
- if(doUpdate == True) update else ArrayIte(doUpdate, update, array)
+ val doUpdate = if (memory.fullAddressRange) doWrite else and(doWrite, memory.isValidAddress(addr))
+ val update = ArrayStore(array, index = addr, data = data)
+ if (doUpdate == True) update else ArrayIte(doUpdate, update, array)
}
}
private class ReadPort(memory: MemInfo, name: String, inputs: String => BVExpr)
- extends MemPort(memory, name, inputs) {
- }
+ extends MemPort(memory, name, inputs) {}
- private def and(a: BVExpr, b: BVExpr): BVExpr = (a,b) match {
+ private def and(a: BVExpr, b: BVExpr): BVExpr = (a, b) match {
case (True, True) => True
- case (True, x) => x
- case (x, True) => x
- case _ => BVOp(Op.And, a, b)
+ case (True, x) => x
+ case (x, True) => x
+ case _ => BVOp(Op.And, a, b)
}
private def or(a: BVExpr, b: BVExpr): BVExpr = BVOp(Op.Or, a, b)
private val True = BVLiteral(1, 1)
private val False = BVLiteral(0, 1)
- private def all(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) False else b.reduce((a,b) => and(a,b))
- private def any(b: Iterable[BVExpr]): BVExpr = if(b.isEmpty) True else b.reduce((a,b) => or(a,b))
+ private def all(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) False else b.reduce((a, b) => and(a, b))
+ private def any(b: Iterable[BVExpr]): BVExpr = if (b.isEmpty) True else b.reduce((a, b) => or(a, b))
}
// performas a first pass over the module collecting all connections, wires, registers, input and outputs
@@ -399,13 +466,13 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
private val unusedMemOutputs = mutable.LinkedHashMap[String, Int]()
private[firrtl] def onPort(p: ir.Port): Unit = {
- if(isAsyncReset(p.tpe)) {
+ if (isAsyncReset(p.tpe)) {
throw new AsyncResetException(s"Found AsyncReset ${p.name}.")
}
infos.append(p.name -> p.info)
p.direction match {
case ir.Input =>
- if(isClock(p.tpe)) {
+ if (isClock(p.tpe)) {
clocks.add(p.name)
} else {
inputs.append(BVSymbol(p.name, getWidth(p.tpe)))
@@ -416,12 +483,12 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
private[firrtl] def onStatement(s: ir.Statement): Unit = s match {
case ir.DefWire(info, name, tpe) =>
- if(!isClock(tpe)) {
+ if (!isClock(tpe)) {
infos.append(name -> info)
wires.append(name)
}
case ir.DefNode(info, name, expr) =>
- if(!isClock(expr.tpe)) {
+ if (!isClock(expr.tpe)) {
insertDummyAssignsForMemoryOutputs(expr)
infos.append(name -> info)
val e = onExpression(expr, name)
@@ -436,7 +503,7 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val resetExpr = onExpression(reset, 1, name + "_reset")
val initExpr = onExpression(init, width, name + "_init")
registers.append((name, width, resetExpr, initExpr))
- case m : ir.DefMemory =>
+ case m: ir.DefMemory =>
infos.append(m.name -> m.info)
val outputs = getMemOutputs(m)
(getMemInputs(m) ++ outputs).foreach(memSignals.append(_))
@@ -444,37 +511,39 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
outputs.foreach(name => unusedMemOutputs(name) = dataWidth)
memories.append(m)
case ir.Connect(info, loc, expr) =>
- if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
val name = loc.serialize
insertDummyAssignsForMemoryOutputs(expr)
infos.append(name -> info)
connects.append((name, onExpression(expr, getWidth(loc.tpe), name)))
case ir.IsInvalid(info, loc) =>
- if(!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
+ if (!isGroundType(loc.tpe)) error("All connects should have been lowered to ground type!")
val name = loc.serialize
infos.append(name -> info)
connects.append((name, makeRandom(name + "_INVALID", getWidth(loc.tpe))))
case ir.DefInstance(info, name, module, tpe) =>
- if(!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
+ if (!tpe.isInstanceOf[ir.BundleType]) error(s"Instance $name of $module has an invalid type: ${tpe.serialize}")
// we treat all instances as blackboxes
- logger.warn(s"WARN: treating instance $name of $module as blackbox. " +
- "Please flatten your hierarchy if you want to include submodules in the formal model.")
+ logger.warn(
+ s"WARN: treating instance $name of $module as blackbox. " +
+ "Please flatten your hierarchy if you want to include submodules in the formal model."
+ )
val ports = tpe.asInstanceOf[ir.BundleType].fields
// skip clock and async reset ports
- ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe) ).foreach { p =>
- if(!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
+ ports.filterNot(p => isClock(p.tpe) || isAsyncReset(p.tpe)).foreach { p =>
+ if (!p.tpe.isInstanceOf[ir.GroundType]) error(s"Instance $name of $module has an invalid port type: $p")
val isOutput = p.flip == ir.Default
val pName = name + "." + p.name
infos.append(pName -> info)
// outputs of the submodule become inputs to our module
- if(isOutput) {
+ if (isOutput) {
inputs.append(BVSymbol(pName, getWidth(p.tpe)))
} else {
outputs.append(pName)
}
}
case s @ ir.Verification(op, info, _, pred, en, msg) =>
- if(op == ir.Formal.Cover) {
+ if (op == ir.Formal.Cover) {
logger.warn(s"WARN: Cover statement was ignored: ${s.serialize}")
} else {
val name = msgToName(op.toString, msg.string)
@@ -483,22 +552,22 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
val e = BVImplies(enabled, predicate)
infos.append(name -> info)
connects.append(name -> e)
- if(op == ir.Formal.Assert) {
+ if (op == ir.Formal.Assert) {
asserts.append(name)
} else {
assumes.append(name)
}
}
- case s : ir.Conditionally =>
+ case s: ir.Conditionally =>
error(s"When conditions are not supported. Please run ExpandWhens: ${s.serialize}")
- case s : ir.PartialConnect =>
+ case s: ir.PartialConnect =>
error(s"PartialConnects are not supported. Please run ExpandConnects: ${s.serialize}")
- case s : ir.Attach =>
+ case s: ir.Attach =>
error(s"Analog wires are not supported in the SMT backend: ${s.serialize}")
- case s : ir.Stop =>
+ case s: ir.Stop =>
// we could wire up the stop condition as output for debug reasons
logger.warn(s"WARN: Stop statements are currently not supported. Ignoring: ${s.serialize}")
- case s : ir.Print =>
+ case s: ir.Print =>
logger.warn(s"WARN: Print statements are not supported. Ignoring: ${s.serialize}")
case other => other.foreachStmt(onStatement)
}
@@ -520,21 +589,22 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
// example:
// m.r.data <= m.r.data ; this is the dummy assign
// test <= m.r.data ; this is the first use of m.r.data
- private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if(unusedMemOutputs.nonEmpty) {
+ private def insertDummyAssignsForMemoryOutputs(next: ir.Expression): Unit = if (unusedMemOutputs.nonEmpty) {
implicit val uses = mutable.ArrayBuffer[String]()
findUnusedMemoryOutputUse(next)
- if(uses.nonEmpty) {
+ if (uses.nonEmpty) {
val useSet = uses.toSet
- unusedMemOutputs.foreach { case (name, width) =>
- if(useSet.contains(name)) connects.append(name -> BVSymbol(name, width))
+ unusedMemOutputs.foreach {
+ case (name, width) =>
+ if (useSet.contains(name)) connects.append(name -> BVSymbol(name, width))
}
useSet.foreach(name => unusedMemOutputs.remove(name))
}
}
private def findUnusedMemoryOutputUse(e: ir.Expression)(implicit uses: mutable.ArrayBuffer[String]): Unit = e match {
- case s : ir.SubField =>
+ case s: ir.SubField =>
val name = s.serialize
- if(unusedMemOutputs.contains(name)) uses.append(name)
+ if (unusedMemOutputs.contains(name)) uses.append(name)
case other => other.foreachExpr(findUnusedMemoryOutputUse)
}
@@ -555,17 +625,18 @@ private class ModuleScanner(makeRandom: (String, Int) => BVExpr) extends LazyLog
// TODO: ensure that we can generate unique names
prefix + "_" + msg.replace(" ", "_").replace("|", "")
}
- private def error(msg: String): Unit = throw new RuntimeException(msg)
+ private def error(msg: String): Unit = throw new RuntimeException(msg)
private def isGroundType(tpe: ir.Type): Boolean = tpe.isInstanceOf[ir.GroundType]
- private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
+ private def isClock(tpe: ir.Type): Boolean = tpe == ir.ClockType
private def isAsyncReset(tpe: ir.Type): Boolean = tpe == ir.AsyncResetType
}
private object TopologicalSort {
+
/** Ensures that all signals in the resulting system are topologically sorted.
* This is necessary because [[firrtl.transforms.RemoveWires]] does
* not sort assignments to outputs, submodule inputs nor memory ports.
- * */
+ */
def run(sys: TransitionSystem): TransitionSystem = {
val inputsAndStates = sys.inputs.map(_.name) ++ sys.states.map(_.sym.name)
val signalOrder = sort(sys.signals.map(s => s.name -> s.e), inputsAndStates)
@@ -583,23 +654,24 @@ private object TopologicalSort {
val known = new mutable.HashSet[String]() ++ globalSignals
var needsReordering = false
val digraph = new MutableDiGraph[String]
- signals.foreach { case (name, expr) =>
- digraph.addVertex(name)
- val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
- uniqueDependencies.foreach { d =>
- if(!known.contains(d)) { needsReordering = true }
- digraph.addPairWithEdge(name, d)
- }
- known.add(name)
+ signals.foreach {
+ case (name, expr) =>
+ digraph.addVertex(name)
+ val uniqueDependencies = mutable.LinkedHashSet[String]() ++ findDependencies(expr)
+ uniqueDependencies.foreach { d =>
+ if (!known.contains(d)) { needsReordering = true }
+ digraph.addPairWithEdge(name, d)
+ }
+ known.add(name)
}
- if(needsReordering) {
+ if (needsReordering) {
Some(digraph.linearize.reverse)
} else { None }
}
private def findDependencies(expr: SMTExpr): List[String] = expr match {
- case BVSymbol(name, _) => List(name)
+ case BVSymbol(name, _) => List(name)
case ArraySymbol(name, _, _) => List(name)
- case other => other.children.flatMap(findDependencies)
+ case other => other.children.flatMap(findDependencies)
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
index 322b8961..1c7ea42f 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTEmitter.scala
@@ -11,8 +11,10 @@ import firrtl.options.Viewer.view
import firrtl.options.{CustomFileEmission, Dependency}
import firrtl.stage.FirrtlOptions
-
-private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform with Emitter with DependencyAPIMigration {
+private[firrtl] abstract class SMTEmitter private[firrtl] ()
+ extends Transform
+ with Emitter
+ with DependencyAPIMigration {
override def prerequisites: Seq[Dependency[Transform]] = Seq(Dependency(FirrtlToTransitionSystem))
override def invalidates(a: Transform): Boolean = false
@@ -30,16 +32,16 @@ private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform w
override protected def execute(state: CircuitState): CircuitState = {
val emitCircuit = state.annotations.exists {
- case EmitCircuitAnnotation(a) if this.getClass == a => true
+ case EmitCircuitAnnotation(a) if this.getClass == a => true
case EmitAllModulesAnnotation(a) if this.getClass == a => error("EmitAllModulesAnnotation not supported!")
- case _ => false
+ case _ => false
}
- if(!emitCircuit) { return state }
+ if (!emitCircuit) { return state }
logger.warn(BleedingEdgeWarning)
- val sys = state.annotations.collectFirst{ case TransitionSystemAnnotation(sys) => sys }.getOrElse {
+ val sys = state.annotations.collectFirst { case TransitionSystemAnnotation(sys) => sys }.getOrElse {
error("Could not find the transition system!")
}
state.copy(annotations = state.annotations :+ serialize(sys))
@@ -52,11 +54,12 @@ private[firrtl] abstract class SMTEmitter private[firrtl] () extends Transform w
}
case class EmittedSMTModelAnnotation(name: String, src: String, outputSuffix: String)
- extends NoTargetAnnotation with CustomFileEmission {
+ extends NoTargetAnnotation
+ with CustomFileEmission {
override protected def baseFileName(annotations: AnnotationSeq): String =
view[FirrtlOptions](annotations).outputFileName.getOrElse(name)
override protected def suffix: Option[String] = Some(outputSuffix)
- override def getBytes: Iterable[Byte] = src.getBytes
+ override def getBytes: Iterable[Byte] = src.getBytes
}
private[firrtl] class Btor2Emitter extends SMTEmitter {
@@ -72,14 +75,14 @@ private[firrtl] class SMTLibEmitter extends SMTEmitter {
override protected def serialize(sys: TransitionSystem): Annotation = {
val hasMemory = sys.states.exists(_.sym.isInstanceOf[ArrayExpr])
val logic = SMTLibSerializer.setLogic(hasMemory) + "\n"
- val header = if(hasMemory) {
+ val header = if (hasMemory) {
"; We have to disable the logic for z3 to accept the non-standard \"as const\"\n" +
- "; see https://github.com/Z3Prover/z3/issues/1803\n" +
- "; for CVC4 you probably want to include the logic\n" +
- ";" + logic
+ "; see https://github.com/Z3Prover/z3/issues/1803\n" +
+ "; for CVC4 you probably want to include the logic\n" +
+ ";" + logic
} else { logic }
val smt = generatedHeader("SMT-LIBv2", sys.name) + header +
SMTTransitionSystemEncoder.encode(sys).map(SMTLibSerializer.serialize).mkString("\n") + "\n"
EmittedSMTModelAnnotation(sys.name, smt, outputSuffix)
}
-} \ No newline at end of file
+}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
index 10a89e8d..ebb9e309 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExpr.scala
@@ -9,7 +9,7 @@ private sealed trait SMTExpr { def children: List[SMTExpr] }
private sealed trait SMTSymbol extends SMTExpr with SMTNullaryExpr { val name: String }
private object SMTSymbol {
def fromExpr(name: String, e: SMTExpr): SMTSymbol = e match {
- case b: BVExpr => BVSymbol(name, b.width)
+ case b: BVExpr => BVSymbol(name, b.width)
case a: ArrayExpr => ArraySymbol(name, a.indexWidth, a.dataWidth)
}
}
@@ -19,19 +19,19 @@ private sealed trait SMTNullaryExpr extends SMTExpr {
private sealed trait BVExpr extends SMTExpr { def width: Int }
private case class BVLiteral(value: BigInt, width: Int) extends BVExpr with SMTNullaryExpr {
- private def minWidth = value.bitLength + (if(value <= 0) 1 else 0)
+ private def minWidth = value.bitLength + (if (value <= 0) 1 else 0)
assert(width > 0, "Zero or negative width literals are not allowed!")
assert(width >= minWidth, "Value (" + value.toString + ") too big for BitVector of width " + width + " bits.")
- override def toString: String = if(width <= 8) {
+ override def toString: String = if (width <= 8) {
width.toString + "'b" + value.toString(2)
} else { width.toString + "'x" + value.toString(16) }
}
private case class BVSymbol(name: String, width: Int) extends BVExpr with SMTSymbol {
- assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
+ assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
assert(width > 0, "Zero width bit vectors are not supported!")
override def toString: String = name
- def toStringWithType: String = name + " : " + SMTExpr.serializeType(this)
+ def toStringWithType: String = name + " : " + SMTExpr.serializeType(this)
}
private sealed trait BVUnaryExpr extends BVExpr {
@@ -41,34 +41,35 @@ private sealed trait BVUnaryExpr extends BVExpr {
private case class BVExtend(e: BVExpr, by: Int, signed: Boolean) extends BVUnaryExpr {
assert(by >= 0, "Extension must be non-negative!")
override val width: Int = e.width + by
- override def toString: String = if(signed) { s"sext($e, $by)" } else { s"zext($e, $by)" }
+ override def toString: String = if (signed) { s"sext($e, $by)" }
+ else { s"zext($e, $by)" }
}
// also known as bit extract operation
private case class BVSlice(e: BVExpr, hi: Int, lo: Int) extends BVUnaryExpr {
assert(lo >= 0, s"lo (lsb) must be non-negative!")
assert(hi >= lo, s"hi (msb) must not be smaller than lo (lsb): msb: $hi lsb: $lo")
assert(e.width > hi, s"Out off bounds hi (msb) access: width: ${e.width} msb: $hi")
- override def width: Int = hi - lo + 1
- override def toString: String = if(hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]"
+ override def width: Int = hi - lo + 1
+ override def toString: String = if (hi == lo) s"$e[$hi]" else s"$e[$hi:$lo]"
}
private case class BVNot(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
+ override val width: Int = e.width
override def toString: String = s"not($e)"
}
private case class BVNegate(e: BVExpr) extends BVUnaryExpr {
- override val width: Int = e.width
+ override val width: Int = e.width
override def toString: String = s"neg($e)"
}
private case class BVReduceOr(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"redor($e)"
}
private case class BVReduceAnd(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"redand($e)"
}
private case class BVReduceXor(e: BVExpr) extends BVUnaryExpr {
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"redxor($e)"
}
@@ -79,12 +80,12 @@ private sealed trait BVBinaryExpr extends BVExpr {
}
private case class BVImplies(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == 1 && b.width == 1, s"Both arguments need to be 1-bit!")
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"impl($a, $b)"
}
private case class BVEqual(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"eq($a, $b)"
}
private object Compare extends Enumeration {
@@ -94,8 +95,8 @@ private case class BVComparison(op: Compare.Value, a: BVExpr, b: BVExpr, signed:
assert(a.width == b.width, s"Both argument need to be the same width!")
override def width: Int = 1
override def toString: String = op match {
- case Compare.Greater => (if(signed) "sgt" else "ugt") + s"($a, $b)"
- case Compare.GreaterEqual => (if(signed) "sgeq" else "ugeq") + s"($a, $b)"
+ case Compare.Greater => (if (signed) "sgt" else "ugt") + s"($a, $b)"
+ case Compare.GreaterEqual => (if (signed) "sgeq" else "ugeq") + s"($a, $b)"
}
}
private object Op extends Enumeration {
@@ -116,81 +117,87 @@ private object Op extends Enumeration {
}
private case class BVOp(op: Op.Value, a: BVExpr, b: BVExpr) extends BVBinaryExpr {
assert(a.width == b.width, s"Both argument need to be the same width!")
- override val width: Int = a.width
+ override val width: Int = a.width
override def toString: String = s"$op($a, $b)"
}
private case class BVConcat(a: BVExpr, b: BVExpr) extends BVBinaryExpr {
- override val width: Int = a.width + b.width
+ override val width: Int = a.width + b.width
override def toString: String = s"concat($a, $b)"
}
private case class ArrayRead(array: ArrayExpr, index: BVExpr) extends BVExpr {
assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
- override val width: Int = array.dataWidth
+ override val width: Int = array.dataWidth
override def toString: String = s"$array[$index]"
override def children: List[SMTExpr] = List(array, index)
}
private case class BVIte(cond: BVExpr, tru: BVExpr, fals: BVExpr) extends BVExpr {
assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
assert(tru.width == fals.width, s"Both branches need to be of the same width! ${tru.width} vs ${fals.width}")
- override val width: Int = tru.width
+ override val width: Int = tru.width
override def toString: String = s"ite($cond, $tru, $fals)"
override def children: List[BVExpr] = List(cond, tru, fals)
}
private sealed trait ArrayExpr extends SMTExpr { val indexWidth: Int; val dataWidth: Int }
private case class ArraySymbol(name: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTSymbol {
- assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
+ assert(!name.contains("|"), s"Invalid id $name contains escape character `|`")
assert(!name.contains("\\"), s"Invalid id $name contains `\\`")
override def toString: String = name
- def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>"
+ def toStringWithType: String = s"$name : bv<$indexWidth> -> bv<$dataWidth>"
}
private case class ArrayStore(array: ArrayExpr, index: BVExpr, data: BVExpr) extends ArrayExpr {
assert(array.indexWidth == index.width, "Index with does not match expected array index width!")
assert(array.dataWidth == data.width, "Data with does not match expected array data width!")
- override val dataWidth: Int = array.dataWidth
+ override val dataWidth: Int = array.dataWidth
override val indexWidth: Int = array.indexWidth
- override def toString: String = s"$array[$index := $data]"
- override def children: List[SMTExpr] = List(array, index, data)
+ override def toString: String = s"$array[$index := $data]"
+ override def children: List[SMTExpr] = List(array, index, data)
}
private case class ArrayIte(cond: BVExpr, tru: ArrayExpr, fals: ArrayExpr) extends ArrayExpr {
assert(cond.width == 1, s"Condition needs to be a 1-bit value not ${cond.width}-bit!")
- assert(tru.indexWidth == fals.indexWidth,
- s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}")
- assert(tru.dataWidth == fals.dataWidth,
- s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}")
- override val dataWidth: Int = tru.dataWidth
+ assert(
+ tru.indexWidth == fals.indexWidth,
+ s"Both branches need to be of the same type! ${tru.indexWidth} vs ${fals.indexWidth}"
+ )
+ assert(
+ tru.dataWidth == fals.dataWidth,
+ s"Both branches need to be of the same type! ${tru.dataWidth} vs ${fals.dataWidth}"
+ )
+ override val dataWidth: Int = tru.dataWidth
override val indexWidth: Int = tru.indexWidth
- override def toString: String = s"ite($cond, $tru, $fals)"
- override def children: List[SMTExpr] = List(cond, tru, fals)
+ override def toString: String = s"ite($cond, $tru, $fals)"
+ override def children: List[SMTExpr] = List(cond, tru, fals)
}
private case class ArrayEqual(a: ArrayExpr, b: ArrayExpr) extends BVExpr {
assert(a.indexWidth == b.indexWidth, s"Both argument need to be the same index width!")
assert(a.dataWidth == b.dataWidth, s"Both argument need to be the same data width!")
- override def width: Int = 1
+ override def width: Int = 1
override def toString: String = s"eq($a, $b)"
override def children: List[SMTExpr] = List(a, b)
}
private case class ArrayConstant(e: BVExpr, indexWidth: Int) extends ArrayExpr {
override val dataWidth: Int = e.width
- override def toString: String = s"([$e] x ${ (BigInt(1) << indexWidth) })"
- override def children: List[SMTExpr] = List(e)
+ override def toString: String = s"([$e] x ${(BigInt(1) << indexWidth)})"
+ override def children: List[SMTExpr] = List(e)
}
private object SMTEqual {
- def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a,b) match {
- case (ab : BVExpr, bb : BVExpr) => BVEqual(ab, bb)
- case (aa : ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba)
+ def apply(a: SMTExpr, b: SMTExpr): BVExpr = (a, b) match {
+ case (ab: BVExpr, bb: BVExpr) => BVEqual(ab, bb)
+ case (aa: ArrayExpr, ba: ArrayExpr) => ArrayEqual(aa, ba)
case _ => throw new RuntimeException(s"Cannot compare $a and $b")
}
}
private object SMTExpr {
def serializeType(e: SMTExpr): String = e match {
- case b: BVExpr => s"bv<${b.width}>"
+ case b: BVExpr => s"bv<${b.width}>"
case a: ArrayExpr => s"bv<${a.indexWidth}> -> bv<${a.dataWidth}>"
}
}
// Raw SMTLib encoded expressions as an escape hatch used in the [[SMTTransitionSystemEncoder]]
private case class BVRawExpr(serialized: String, width: Int) extends BVExpr with SMTNullaryExpr
-private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int) extends ArrayExpr with SMTNullaryExpr \ No newline at end of file
+private case class ArrayRawExpr(serialized: String, indexWidth: Int, dataWidth: Int)
+ extends ArrayExpr
+ with SMTNullaryExpr
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
index 14e73253..defc787c 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTExprVisitor.scala
@@ -9,7 +9,7 @@ private object SMTExprVisitor {
type BVFun = BVExpr => BVExpr
def map[T <: SMTExpr](bv: BVFun, ar: ArrayFun)(e: T): T = e match {
- case b: BVExpr => map(b, bv, ar).asInstanceOf[T]
+ case b: BVExpr => map(b, bv, ar).asInstanceOf[T]
case a: ArrayExpr => map(a, bv, ar).asInstanceOf[T]
}
def map[T <: SMTExpr](f: SMTExpr => SMTExpr)(e: T): T =
@@ -17,57 +17,56 @@ private object SMTExprVisitor {
private def map(e: BVExpr, bv: BVFun, ar: ArrayFun): BVExpr = e match {
// nullary
- case old : BVLiteral => bv(old)
- case old : BVSymbol => bv(old)
- case old : BVRawExpr => bv(old)
+ case old: BVLiteral => bv(old)
+ case old: BVSymbol => bv(old)
+ case old: BVRawExpr => bv(old)
// unary
- case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVExtend(n, by, signed))
- case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVSlice(n, hi, lo))
- case old @ BVNot(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNot(n))
- case old @ BVNegate(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVNegate(n))
- case old @ BVReduceAnd(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceAnd(n))
- case old @ BVReduceOr(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceOr(n))
- case old @ BVReduceXor(e) => val n = map(e, bv, ar) ; bv(if(n.eq(e)) old else BVReduceXor(n))
+ case old @ BVExtend(e, by, signed) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVExtend(n, by, signed))
+ case old @ BVSlice(e, hi, lo) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVSlice(n, hi, lo))
+ case old @ BVNot(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNot(n))
+ case old @ BVNegate(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVNegate(n))
+ case old @ BVReduceAnd(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceAnd(n))
+ case old @ BVReduceOr(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceOr(n))
+ case old @ BVReduceXor(e) => val n = map(e, bv, ar); bv(if (n.eq(e)) old else BVReduceXor(n))
// binary
case old @ BVImplies(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVImplies(nA, nB))
case old @ BVEqual(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVEqual(nA, nB))
case old @ ArrayEqual(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else ArrayEqual(nA, nB))
case old @ BVComparison(op, a, b, signed) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVComparison(op, nA, nB, signed))
case old @ BVOp(op, a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVOp(op, nA, nB))
case old @ BVConcat(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else BVConcat(nA, nB))
case old @ ArrayRead(a, b) =>
val (nA, nB) = (map(a, bv, ar), map(b, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB))
+ bv(if (nA.eq(a) && nB.eq(b)) old else ArrayRead(nA, nB))
// ternary
case old @ BVIte(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- bv(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
+ bv(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else BVIte(nA, nB, nC))
}
-
private def map(e: ArrayExpr, bv: BVFun, ar: ArrayFun): ArrayExpr = e match {
- case old : ArrayRawExpr => ar(old)
- case old : ArraySymbol => ar(old)
+ case old: ArrayRawExpr => ar(old)
+ case old: ArraySymbol => ar(old)
case old @ ArrayConstant(e, indexWidth) =>
- val n = map(e, bv, ar) ; ar(if(n.eq(e)) old else ArrayConstant(n, indexWidth))
+ val n = map(e, bv, ar); ar(if (n.eq(e)) old else ArrayConstant(n, indexWidth))
case old @ ArrayStore(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC))
+ ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayStore(nA, nB, nC))
case old @ ArrayIte(a, b, c) =>
val (nA, nB, nC) = (map(a, bv, ar), map(b, bv, ar), map(c, bv, ar))
- ar(if(nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC))
+ ar(if (nA.eq(a) && nB.eq(b) && nC.eq(c)) old else ArrayIte(nA, nB, nC))
}
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
index 1993da87..bd5e4d8c 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTLibSerializer.scala
@@ -6,83 +6,87 @@ package firrtl.backends.experimental.smt
import scala.util.matching.Regex
/** Converts STM Expressions to a SMTLib compatible string representation.
- * See http://smtlib.cs.uiowa.edu/
- * Assumes well typed expression, so it is advisable to run the TypeChecker
- * before serializing!
- * Automatically converts 1-bit vectors to bool.
- */
+ * See http://smtlib.cs.uiowa.edu/
+ * Assumes well typed expression, so it is advisable to run the TypeChecker
+ * before serializing!
+ * Automatically converts 1-bit vectors to bool.
+ */
private object SMTLibSerializer {
- def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if(hasMem) "A" else "") + "UFBV)"
+ def setLogic(hasMem: Boolean) = "(set-logic QF_" + (if (hasMem) "A" else "") + "UFBV)"
def serialize(e: SMTExpr): String = e match {
- case b : BVExpr => serialize(b)
- case a : ArrayExpr => serialize(a)
+ case b: BVExpr => serialize(b)
+ case a: ArrayExpr => serialize(a)
}
def serializeType(e: SMTExpr): String = e match {
- case b : BVExpr => serializeBitVectorType(b.width)
- case a : ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
+ case b: BVExpr => serializeBitVectorType(b.width)
+ case a: ArrayExpr => serializeArrayType(a.indexWidth, a.dataWidth)
}
private def serialize(e: BVExpr): String = e match {
case BVLiteral(value, width) =>
- val mask = (BigInt(1) << width) - 1
- val twosComplement = if(value < 0) { ((~(-value)) & mask) + 1 } else value
- if(width == 1) {
- if(twosComplement == 1) "true" else "false"
+ val mask = (BigInt(1) << width) - 1
+ val twosComplement = if (value < 0) { ((~(-value)) & mask) + 1 }
+ else value
+ if (width == 1) {
+ if (twosComplement == 1) "true" else "false"
} else {
s"(_ bv$twosComplement $width)"
}
- case BVSymbol(name, _) => escapeIdentifier(name)
- case BVExtend(e, 0, _) => serialize(e)
+ case BVSymbol(name, _) => escapeIdentifier(name)
+ case BVExtend(e, 0, _) => serialize(e)
case BVExtend(BVLiteral(value, width), by, false) => serialize(BVLiteral(value, width + by))
case BVExtend(e, by, signed) =>
- val foo = if(signed) "sign_extend" else "zero_extend"
+ val foo = if (signed) "sign_extend" else "zero_extend"
s"((_ $foo $by) ${asBitVector(e)})"
case BVSlice(e, hi, lo) =>
- if(lo == 0 && hi == e.width - 1) { serialize(e)
- } else {
+ if (lo == 0 && hi == e.width - 1) { serialize(e) }
+ else {
val bits = s"((_ extract $hi $lo) ${asBitVector(e)})"
// 1-bit extracts need to be turned into a boolean
- if(lo == hi) { toBool(bits) } else { bits }
+ if (lo == hi) { toBool(bits) }
+ else { bits }
}
case BVNot(BVEqual(a, b)) if a.width == 1 => s"(distinct ${serialize(a)} ${serialize(b)})"
- case BVNot(BVNot(e)) => serialize(e)
- case BVNot(e) => if(e.width == 1) { s"(not ${serialize(e)})" } else { s"(bvnot ${serialize(e)})" }
+ case BVNot(BVNot(e)) => serialize(e)
+ case BVNot(e) =>
+ if (e.width == 1) { s"(not ${serialize(e)})" }
+ else { s"(bvnot ${serialize(e)})" }
case BVNegate(e) => s"(bvneg ${asBitVector(e)})"
case r: BVReduceAnd => serialize(Expander.expand(r))
- case r: BVReduceOr => serialize(Expander.expand(r))
+ case r: BVReduceOr => serialize(Expander.expand(r))
case r: BVReduceXor => serialize(Expander.expand(r))
- case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b)
- case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})"
- case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
- case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
- case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})"
+ case BVImplies(BVLiteral(v, 1), b) if v == 1 => serialize(b)
+ case BVImplies(a, b) => s"(=> ${serialize(a)} ${serialize(b)})"
+ case BVEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
+ case ArrayEqual(a, b) => s"(= ${serialize(a)} ${serialize(b)})"
+ case BVComparison(Compare.Greater, a, b, false) => s"(bvugt ${asBitVector(a)} ${asBitVector(b)})"
case BVComparison(Compare.GreaterEqual, a, b, false) => s"(bvuge ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
- case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.Greater, a, b, true) => s"(bvsgt ${asBitVector(a)} ${asBitVector(b)})"
+ case BVComparison(Compare.GreaterEqual, a, b, true) => s"(bvsge ${asBitVector(a)} ${asBitVector(b)})"
// boolean operations get a special treatment for 1-bit vectors aka bools
case BVOp(Op.And, a, b) if a.width == 1 => s"(and ${serialize(a)} ${serialize(b)})"
- case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
+ case BVOp(Op.Or, a, b) if a.width == 1 => s"(or ${serialize(a)} ${serialize(b)})"
case BVOp(Op.Xor, a, b) if a.width == 1 => s"(xor ${serialize(a)} ${serialize(b)})"
- case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
- case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
- case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
- case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
- case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case BVRawExpr(serialized, _) => serialized
+ case BVOp(op, a, b) if a.width == 1 => toBool(s"(${serialize(op)} ${asBitVector(a)} ${asBitVector(b)})")
+ case BVOp(op, a, b) => s"(${serialize(op)} ${serialize(a)} ${serialize(b)})"
+ case BVConcat(a, b) => s"(concat ${asBitVector(a)} ${asBitVector(b)})"
+ case ArrayRead(array, index) => s"(select ${serialize(array)} ${asBitVector(index)})"
+ case BVIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case BVRawExpr(serialized, _) => serialized
}
def serialize(e: ArrayExpr): String = e match {
- case ArraySymbol(name, _, _) => escapeIdentifier(name)
+ case ArraySymbol(name, _, _) => escapeIdentifier(name)
case ArrayStore(array, index, data) => s"(store ${serialize(array)} ${serialize(index)} ${serialize(data)})"
- case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
- case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
+ case ArrayIte(cond, tru, fals) => s"(ite ${serialize(cond)} ${serialize(tru)} ${serialize(fals)})"
+ case c @ ArrayConstant(e, _) => s"((as const ${serializeArrayType(c.indexWidth, c.dataWidth)}) ${serialize(e)})"
case ArrayRawExpr(serialized, _, _) => serialized
}
def serialize(c: SMTCommand): String = c match {
- case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
+ case Comment(msg) => msg.split("\n").map("; " + _).mkString("\n")
case DeclareUninterpretedSort(name) => s"(declare-sort ${escapeIdentifier(name)} 0)"
case DefineFunction(name, args, e) =>
val aa = args.map(a => s"(${escapeIdentifier(a._1)} ${a._2})").mkString(" ")
@@ -95,23 +99,24 @@ private object SMTLibSerializer {
private def serializeArrayType(indexWidth: Int, dataWidth: Int): String =
s"(Array ${serializeBitVectorType(indexWidth)} ${serializeBitVectorType(dataWidth)})"
private def serializeBitVectorType(width: Int): String =
- if(width == 1) { "Bool" } else { assert(width > 1) ; s"(_ BitVec $width)" }
+ if (width == 1) { "Bool" }
+ else { assert(width > 1); s"(_ BitVec $width)" }
private def serialize(op: Op.Value): String = op match {
- case Op.And => "bvand"
- case Op.Or => "bvor"
- case Op.Xor => "bvxor"
+ case Op.And => "bvand"
+ case Op.Or => "bvor"
+ case Op.Xor => "bvxor"
case Op.ArithmeticShiftRight => "bvashr"
- case Op.ShiftRight => "bvlshr"
- case Op.ShiftLeft => "bvshl"
- case Op.Add => "bvadd"
- case Op.Mul => "bvmul"
- case Op.Sub => "bvsub"
- case Op.SignedDiv => "bvsdiv"
- case Op.UnsignedDiv => "bvudiv"
- case Op.SignedMod => "bvsmod"
- case Op.SignedRem => "bvsrem"
- case Op.UnsignedRem => "bvurem"
+ case Op.ShiftRight => "bvlshr"
+ case Op.ShiftLeft => "bvshl"
+ case Op.Add => "bvadd"
+ case Op.Mul => "bvmul"
+ case Op.Sub => "bvsub"
+ case Op.SignedDiv => "bvsdiv"
+ case Op.UnsignedDiv => "bvudiv"
+ case Op.SignedMod => "bvsmod"
+ case Op.SignedRem => "bvsrem"
+ case Op.UnsignedRem => "bvurem"
}
private def toBool(e: String): String = s"(= $e (_ bv1 1))"
@@ -119,33 +124,37 @@ private object SMTLibSerializer {
private val bvZero = "(_ bv0 1)"
private val bvOne = "(_ bv1 1)"
private def asBitVector(e: BVExpr): String =
- if(e.width > 1) { serialize(e) } else { s"(ite ${serialize(e)} $bvOne $bvZero)" }
+ if (e.width > 1) { serialize(e) }
+ else { s"(ite ${serialize(e)} $bvOne $bvZero)" }
// See <simple_symbol> definition in the Concrete Syntax Appendix of the SMTLib Spec
private val simple: Regex = raw"[a-zA-Z\+-/\*\=%\?!\.\$$_~&\^<>@][a-zA-Z0-9\+-/\*\=%\?!\.\$$_~&\^<>@]*".r
def escapeIdentifier(name: String): String = name match {
case simple() => name
- case _ => if(name.startsWith("|") && name.endsWith("|")) name else s"|$name|"
+ case _ => if (name.startsWith("|") && name.endsWith("|")) name else s"|$name|"
}
}
/** Expands expressions that are not natively supported by SMTLib */
private object Expander {
def expand(r: BVReduceAnd): BVExpr = {
- if(r.e.width == 1) { r.e } else {
+ if (r.e.width == 1) { r.e }
+ else {
val allOnes = (BigInt(1) << r.e.width) - 1
BVEqual(r.e, BVLiteral(allOnes, r.e.width))
}
}
def expand(r: BVReduceOr): BVExpr = {
- if(r.e.width == 1) { r.e } else {
+ if (r.e.width == 1) { r.e }
+ else {
BVNot(BVEqual(r.e, BVLiteral(0, r.e.width)))
}
}
def expand(r: BVReduceXor): BVExpr = {
- if(r.e.width == 1) { r.e } else {
+ if (r.e.width == 1) { r.e }
+ else {
val bits = (0 until r.e.width).map(ii => BVSlice(r.e, ii, ii))
- bits.reduce[BVExpr]((a,b) => BVOp(Op.Xor, a, b))
+ bits.reduce[BVExpr]((a, b) => BVOp(Op.Xor, a, b))
}
}
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
index e9acc05b..4c60a1b0 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/SMTTransitionSystemEncoder.scala
@@ -10,7 +10,7 @@ import scala.collection.mutable
* It if fairly compact, but unfortunately, the use of an uninterpreted sort for the state
* prevents this encoding from working with boolector.
* For simplicity reasons, we do not support hierarchical designs (no `_h` function).
- * */
+ */
private object SMTTransitionSystemEncoder {
def encode(sys: TransitionSystem): Iterable[SMTCommand] = {
@@ -38,10 +38,10 @@ private object SMTTransitionSystemEncoder {
cmds += DefineFunction(sym.name + suffix, List((State, stateType)), replaceSymbols(e))
}
sys.signals.foreach { signal =>
- val kind = if(sys.outputs.contains(signal.name)) { "output"
- } else if(sys.assumes.contains(signal.name)) { "assume"
- } else if(sys.asserts.contains(signal.name)) { "assert"
- } else { "wire" }
+ val kind = if (sys.outputs.contains(signal.name)) { "output" }
+ else if (sys.assumes.contains(signal.name)) { "assume" }
+ else if (sys.asserts.contains(signal.name)) { "assert" }
+ else { "wire" }
val sym = SMTSymbol.fromExpr(signal.name, signal.e)
cmds ++= toDescription(sym, kind, sys.comments.get)
define(sym, signal.e)
@@ -105,18 +105,18 @@ private object SMTTransitionSystemEncoder {
}
private def andReduce(e: Iterable[BVExpr]): BVExpr =
- if(e.isEmpty) BVLiteral(1, 1) else e.reduce((a,b) => BVOp(Op.And, a, b))
+ if (e.isEmpty) BVLiteral(1, 1) else e.reduce((a, b) => BVOp(Op.And, a, b))
// All signals are modelled with functions that need to be called with the state as argument,
// this replaces all Symbols with function applications to the state.
private def replaceSymbols(e: SMTExpr): SMTExpr = {
SMTExprVisitor.map(symbolToFunApp(_, SignalSuffix, State))(e)
}
- private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr]
+ private def replaceSymbols(e: BVExpr): BVExpr = replaceSymbols(e.asInstanceOf[SMTExpr]).asInstanceOf[BVExpr]
private def symbolToFunApp(sym: SMTExpr, suffix: String, arg: String): SMTExpr = sym match {
- case BVSymbol(name, width) => BVRawExpr(s"(${id(name+suffix)} $arg)", width)
- case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name+suffix)} $arg)", indexWidth, dataWidth)
- case other => other
+ case BVSymbol(name, width) => BVRawExpr(s"(${id(name + suffix)} $arg)", width)
+ case ArraySymbol(name, indexWidth, dataWidth) => ArrayRawExpr(s"(${id(name + suffix)} $arg)", indexWidth, dataWidth)
+ case other => other
}
}
diff --git a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
index d8e203f8..95db95ef 100644
--- a/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
+++ b/src/main/scala/firrtl/backends/experimental/smt/StutteringClockTransform.scala
@@ -3,7 +3,7 @@
package firrtl.backends.experimental.smt
-import firrtl.{CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils, ir}
+import firrtl.{ir, CircuitState, DependencyAPIMigration, Namespace, PrimOps, RenameMap, Transform, Utils}
import firrtl.annotations.{Annotation, CircuitTarget, PresetAnnotation, ReferenceTarget, SingleTargetAnnotation}
import firrtl.ir.EmptyStmt
import firrtl.options.Dependency
@@ -32,16 +32,17 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// since this pass only runs on the main module, inlining needs to happen before
override def optionalPrerequisites: Seq[TransformDependency] = Seq(Dependency[firrtl.passes.InlineInstances])
-
override protected def execute(state: CircuitState): CircuitState = {
- if(state.circuit.modules.size > 1) {
- logger.warn("WARN: StutteringClockTransform currently only supports running on a single module.\n" +
- s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want.")
+ if (state.circuit.modules.size > 1) {
+ logger.warn(
+ "WARN: StutteringClockTransform currently only supports running on a single module.\n" +
+ s"All submodules of ${state.circuit.main} will be ignored! Please inline all submodules if this is not what you want."
+ )
}
// get main module
val main = state.circuit.modules.find(_.name == state.circuit.main).get match {
- case m: ir.Module => m
+ case m: ir.Module => m
case e: ir.ExtModule => unsupportedError(s"Cannot run on extmodule $e")
}
mainName = main.name
@@ -64,19 +65,21 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// replace all other clocks with enable signals, unless they are the global clock
val clocks = portsWithGlobalClock.filter(p => p.tpe == ir.ClockType && p.name != globalClock).map(_.name)
- val clockToEnable = clocks.map{c =>
+ val clockToEnable = clocks.map { c =>
c -> ir.Reference(namespace.newName(c + "_en"), Bool, firrtl.PortKind, firrtl.SourceFlow)
}.toMap
val portsWithEnableSignals = portsWithGlobalClock.map { p =>
- if(clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) } else { p }
+ if (clockToEnable.contains(p.name)) { p.copy(name = clockToEnable(p.name).name, tpe = Bool) }
+ else { p }
}
// replace async reset with synchronous reset (since everything will we synchronous with the global clock)
// unless it is a preset reset
val asyncResets = portsWithEnableSignals.filter(_.tpe == ir.AsyncResetType).map(_.name)
- val isPresetReset = state.annotations.collect{ case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
+ val isPresetReset = state.annotations.collect { case PresetAnnotation(r) if r.module == main.name => r.ref }.toSet
val resetsToChange = asyncResets.filterNot(isPresetReset).toSet
val portsWithSyncReset = portsWithEnableSignals.map { p =>
- if(resetsToChange.contains(p.name)) { p.copy(tpe = Bool) } else { p }
+ if (resetsToChange.contains(p.name)) { p.copy(tpe = Bool) }
+ else { p }
}
// discover clock and reset connections
@@ -85,8 +88,9 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// rename clocks to clock enable signals
val mRef = CircuitTarget(state.circuit.main).module(main.name)
val renameMap = RenameMap()
- scan.clockToEnable.foreach { case (clk, en) =>
- renameMap.record(mRef.ref(clk), mRef.ref(en.name))
+ scan.clockToEnable.foreach {
+ case (clk, en) =>
+ renameMap.record(mRef.ref(clk), mRef.ref(en.name))
}
// make changes
@@ -103,51 +107,58 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
s match {
// memory field connects
case c @ ir.Connect(_, ir.SubField(ir.SubField(ir.Reference(mem, _, _, _), port, _, _), field, _, _), _)
- if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) =>
+ if ctx.isMem(mem) && ctx.memPortToClockEnable.contains(mem + "." + port) =>
// replace clock with the global clock
- if(field == "clk") {
+ if (field == "clk") {
c.copy(expr = ctx.globalClock)
- } else if(field == "en") {
+ } else if (field == "en") {
val m = ctx.memInfo(mem)
val isWritePort = m.writers.contains(port)
assert(isWritePort || m.readers.contains(port))
// for write ports we guard the write enable with the clock enable signal, similar to registers
- if(isWritePort) {
+ if (isWritePort) {
val clockEn = ctx.memPortToClockEnable(mem + "." + port)
val guardedEnable = and(clockEn, c.expr)
c.copy(expr = guardedEnable)
} else { c }
- } else { c}
+ } else { c }
// register field connects
- case c @ ir.Connect(_, r : ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
+ case c @ ir.Connect(_, r: ir.Reference, next) if ctx.registerToEnable.contains(r.name) =>
val clockEnable = ctx.registerToEnable(r.name)
val guardedNext = mux(clockEnable, next, r)
c.copy(expr = guardedNext)
// remove other clock wires and nodes
case ir.Connect(_, loc, expr) if expr.tpe == ir.ClockType && ctx.isRemovedClock(loc.serialize) => EmptyStmt
- case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
- case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
+ case ir.DefNode(_, name, value) if value.tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
+ case ir.DefWire(_, name, tpe) if tpe == ir.ClockType && ctx.isRemovedClock(name) => EmptyStmt
// change async reset to synchronous reset
- case ir.Connect(info, loc: ir.Reference, expr: ir.Reference) if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
- ir.Connect(info, loc.copy(tpe=Bool), expr.copy(tpe=Bool))
- case d @ ir.DefNode(_, name, value: ir.Reference) if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
- d.copy(value = value.copy(tpe=Bool))
- case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe=Bool)
+ case ir.Connect(info, loc: ir.Reference, expr: ir.Reference)
+ if expr.tpe == ir.AsyncResetType && ctx.isResetToChange(loc.serialize) =>
+ ir.Connect(info, loc.copy(tpe = Bool), expr.copy(tpe = Bool))
+ case d @ ir.DefNode(_, name, value: ir.Reference)
+ if value.tpe == ir.AsyncResetType && ctx.isResetToChange(name) =>
+ d.copy(value = value.copy(tpe = Bool))
+ case d @ ir.DefWire(_, name, tpe) if tpe == ir.AsyncResetType && ctx.isResetToChange(name) => d.copy(tpe = Bool)
// change memory clock and synchronize reset
case ir.DefRegister(info, name, tpe, clock, reset, init) if ctx.registerToEnable.contains(name) =>
val clockEnable = ctx.registerToEnable(name)
val newReset = reset match {
- case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe=Bool)
- case other => other
+ case r @ ir.Reference(name, _, _, _) if ctx.isResetToChange(name) => r.copy(tpe = Bool)
+ case other => other
}
- val synchronizedReset = if(reset.tpe == ir.AsyncResetType) { newReset } else { and(newReset, clockEnable) }
+ val synchronizedReset = if (reset.tpe == ir.AsyncResetType) { newReset }
+ else { and(newReset, clockEnable) }
ir.DefRegister(info, name, tpe, ctx.globalClock, synchronizedReset, init)
case other => other.mapStmt(onStatement)
}
}
- private def scanClocks(m: ir.Module, initialClockToEnable: Map[String, ir.Reference], resetsToChange: Set[String]): ScanCtx = {
+ private def scanClocks(
+ m: ir.Module,
+ initialClockToEnable: Map[String, ir.Reference],
+ resetsToChange: Set[String]
+ ): ScanCtx = {
implicit val ctx: ScanCtx = new ScanCtx(initialClockToEnable, resetsToChange)
m.foreachStmt(scanClocksAndResets)
ctx
@@ -162,9 +173,9 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
ctx.clockToEnable.get(expr.serialize).foreach { clockEn =>
ctx.clockToEnable(locName) = clockEn
// keep track of memory clocks
- if(loc.isInstanceOf[ir.SubField]) {
+ if (loc.isInstanceOf[ir.SubField]) {
val parts = locName.split('.')
- if(ctx.mems.contains(parts.head)) {
+ if (ctx.mems.contains(parts.head)) {
assert(parts.length == 3 && parts.last == "clk")
ctx.memPortToClockEnable.append(parts.dropRight(1).mkString(".") -> clockEn)
}
@@ -182,11 +193,11 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
ctx.clockToEnable.get(clock.serialize).foreach { clockEnable =>
ctx.registerToEnable.append(name -> clockEnable)
}
- case m : ir.DefMemory =>
+ case m: ir.DefMemory =>
assert(m.readwriters.isEmpty, "Combined read/write ports are not supported!")
assert(m.readLatency == 0 || m.readLatency == 1, "Only read-latency 1 and read latency 0 are supported!")
assert(m.writeLatency == 1, "Only write-latency 1 is supported!")
- if(m.readers.nonEmpty && m.readLatency == 1) {
+ if (m.readers.nonEmpty && m.readLatency == 1) {
unsupportedError("Registers memory read ports are not properly implemented yet :(")
}
ctx.mems(m.name) = m
@@ -233,8 +244,8 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
// memory enables which need to be guarded with clock enables
val memPortToClockEnable: Map[String, ir.Reference] = scanResults.memPortToClockEnable.toMap
// keep track of memory names
- val isMem: String => Boolean = scanResults.mems.contains
- val memInfo: String => ir.DefMemory = scanResults.mems
+ val isMem: String => Boolean = scanResults.mems.contains
+ val memInfo: String => ir.DefMemory = scanResults.mems
val isResetToChange: String => Boolean = scanResults.resetsToChange.contains
}
@@ -250,4 +261,4 @@ class StutteringClockTransform extends Transform with DependencyAPIMigration {
private val Bool = ir.UIntType(ir.IntWidth(1))
}
-private class UnsupportedFeatureException(s: String) extends PassException(s) \ No newline at end of file
+private class UnsupportedFeatureException(s: String) extends PassException(s)