diff options
| author | Adam Izraelevitz | 2019-10-18 19:01:19 -0700 |
|---|---|---|
| committer | GitHub | 2019-10-18 19:01:19 -0700 |
| commit | fd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch) | |
| tree | 3609a301cb0ec867deefea4a0d08425810b00418 /src/main/scala/firrtl/ir | |
| parent | 973ecf516c0ef2b222f2eb68dc8b514767db59af (diff) | |
Upstream intervals (#870)
Major features:
- Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz.
- Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp
- Refactored width/bound inferencer into a separate constraint solver
- Added transforms to infer, trim, and remove interval bounds
- Tests for said features
Plan to be released with 1.3
Diffstat (limited to 'src/main/scala/firrtl/ir')
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 177 |
1 files changed, 175 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index e721363c..63c620d1 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -3,7 +3,11 @@ package firrtl package ir -import Utils.indent +import Utils.{dec2string, indent, trim} +import firrtl.constraint.{Constraint, IsKnown, IsVar} + +import scala.math.BigDecimal.RoundingMode._ +import scala.collection.mutable /** Intermediate Representation */ abstract class FirrtlNode { @@ -103,6 +107,29 @@ object StringLit { */ abstract class PrimOp extends FirrtlNode { def serialize: String = this.toString + def propagateType(e: DoPrim): Type = UnknownType + def apply(args: Any*): DoPrim = { + val groups = args.groupBy { + case x: Expression => "exp" + case x: BigInt => "int" + case x: Int => "int" + case other => "other" + } + val exprs = groups.getOrElse("exp", Nil).collect { + case e: Expression => e + } + val consts = groups.getOrElse("int", Nil).map { + _ match { + case i: BigInt => i + case i: Int => BigInt(i) + } + } + groups.get("other") match { + case None => + case Some(x) => sys.error(s"Shouldn't be here: $x") + } + DoPrim(this, exprs, consts, UnknownType) + } } abstract class Expression extends FirrtlNode { @@ -151,7 +178,7 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp def foreachType(f: Type => Unit): Unit = f(tpe) def foreachWidth(f: Width => Unit): Unit = Unit } -case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression { +case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type = UnknownType) extends Expression { def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})" def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe) def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe)) @@ -535,6 +562,12 @@ class IntWidth(val width: BigInt) extends Width with Product { case object UnknownWidth extends Width { def serialize: String = "" } +case class CalcWidth(arg: Constraint) extends Width { + def serialize: String = s"calcw(${arg.serialize})" +} +case class VarWidth(name: String) extends Width with IsVar { + override def serialize: String = s"<$name>" +} /** Orientation of [[Field]] */ abstract class Orientation extends FirrtlNode @@ -550,6 +583,67 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode def serialize: String = flip.serialize + name + " : " + tpe.serialize } + +/** Bounds of [[IntervalType]] */ + +trait Bound extends Constraint { + def serialize: String +} +case object UnknownBound extends Bound { + def serialize: String = "?" + def map(f: Constraint=>Constraint): Constraint = this + override def reduce(): Constraint = this + val children = Vector() +} +case class CalcBound(arg: Constraint) extends Bound { + def serialize: String = s"calcb(${arg.serialize})" + def map(f: Constraint=>Constraint): Constraint = f(arg) + override def reduce(): Constraint = arg + val children = Vector(arg) +} +case class VarBound(name: String) extends IsVar with Bound +object KnownBound { + def unapply(b: Constraint): Option[BigDecimal] = b match { + case k: IsKnown => Some(k.value) + case _ => None + } + def unapply(b: Bound): Option[BigDecimal] = b match { + case k: IsKnown => Some(k.value) + case _ => None + } +} +case class Open(value: BigDecimal) extends IsKnown with Bound { + def serialize = s"o($value)" + def +(that: IsKnown): IsKnown = Open(value + that.value) + def *(that: IsKnown): IsKnown = that match { + case Closed(x) if x == 0 => Closed(x) + case _ => Open(value * that.value) + } + def min(that: IsKnown): IsKnown = if(value < that.value) this else that + def max(that: IsKnown): IsKnown = if(value > that.value) this else that + def neg: IsKnown = Open(-value) + def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) + def pow: IsKnown = if(value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") +} +case class Closed(value: BigDecimal) extends IsKnown with Bound { + def serialize = s"c($value)" + def +(that: IsKnown): IsKnown = that match { + case Open(x) => Open(value + x) + case Closed(x) => Closed(value + x) + } + def *(that: IsKnown): IsKnown = that match { + case IsKnown(x) if value == BigInt(0) => Closed(0) + case Open(x) => Open(value * x) + case Closed(x) => Closed(value * x) + } + def min(that: IsKnown): IsKnown = if(value <= that.value) this else that + def max(that: IsKnown): IsKnown = if(value >= that.value) this else that + def neg: IsKnown = Closed(-value) + def floor: IsKnown = Closed(value.setScale(0, BigDecimal.RoundingMode.FLOOR)) + def pow: IsKnown = if(value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here") +} + +/** Types of [[FirrtlNode]] */ abstract class Type extends FirrtlNode { def mapType(f: Type => Type): Type def mapWidth(f: Width => Width): Type @@ -586,6 +680,84 @@ case class FixedType(width: Width, point: Width) extends GroundType { def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point)) def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) } } +case class IntervalType(lower: Bound, upper: Bound, point: Width) extends GroundType { + override def serialize: String = { + val lowerString = lower match { + case Open(l) => s"(${dec2string(l)}, " + case Closed(l) => s"[${dec2string(l)}, " + case UnknownBound => s"[?, " + case _ => s"[?, " + } + val upperString = upper match { + case Open(u) => s"${dec2string(u)})" + case Closed(u) => s"${dec2string(u)}]" + case UnknownBound => s"?]" + case _ => s"?]" + } + val bounds = (lower, upper) match { + case (k1: IsKnown, k2: IsKnown) => lowerString + upperString + case _ => "" + } + val pointString = point match { + case IntWidth(i) => "." + i.toString + case _ => "" + } + "Interval" + bounds + pointString + } + + private lazy val bp = point.asInstanceOf[IntWidth].width.toInt + private def precision: Option[BigDecimal] = point match { + case IntWidth(width) => + val bp = width.toInt + if(bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp)) + case other => None + } + + def min: Option[BigDecimal] = (lower, precision) match { + case (Open(a), Some(prec)) => a / prec match { + case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0 + case x => Some(x.setScale(0, CEILING) * prec) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l + } + case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, CEILING) * prec) + case other => None + } + + def max: Option[BigDecimal] = (upper, precision) match { + case (Open(a), Some(prec)) => a / prec match { + case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound + case x => Some(x.setScale(0, FLOOR) * prec) + } + case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec) + } + + def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match { + case x if trim(x).isWhole | x.doubleValue == 0.0 => x.toBigInt + case x => sys.error(s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}.") + }) + + def maxAdjusted: Option[BigInt] = max.map(_ * BigDecimal(BigInt(1) << bp) match { + case x if trim(x).isWhole => x.toBigInt + case x => sys.error(s"MaxAdjusted should be a whole number: $x") + }) + + /** If bounds are known, calculates the width, otherwise returns UnknownWidth */ + lazy val width: Width = (point, lower, upper) match { + case (IntWidth(i), l: IsKnown, u: IsKnown) => + IntWidth(Math.max(Utils.getSIntWidth(minAdjusted.get), Utils.getSIntWidth(maxAdjusted.get))) + case _ => UnknownWidth + } + + /** If bounds are known, returns a sequence of all possible values inside this interval */ + lazy val range: Option[Seq[BigDecimal]] = (lower, upper, point) match { + case (l: IsKnown, u: IsKnown, p: IntWidth) => + if(min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get)) + case _ => None + } + + override def mapWidth(f: Width => Width): Type = this.copy(point = f(point)) + override def foreachWidth(f: Width => Unit): Unit = f(point) +} + case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" def mapType(f: Type => Type): Type = @@ -645,6 +817,7 @@ case class Port( direction: Direction, tpe: Type) extends FirrtlNode with IsDeclaration { def serialize: String = s"${direction.serialize} $name : ${tpe.serialize}" + info.serialize + def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe)) } /** Parameters for external modules */ |
