aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/ir
diff options
context:
space:
mode:
authorAdam Izraelevitz2019-10-18 19:01:19 -0700
committerGitHub2019-10-18 19:01:19 -0700
commitfd981848c7d2a800a15f9acfbf33b57dd1c6225b (patch)
tree3609a301cb0ec867deefea4a0d08425810b00418 /src/main/scala/firrtl/ir
parent973ecf516c0ef2b222f2eb68dc8b514767db59af (diff)
Upstream intervals (#870)
Major features: - Added Interval type, as well as PrimOps asInterval, clip, wrap, and sqz. - Changed PrimOp names: bpset -> setp, bpshl -> incp, bpshr -> decp - Refactored width/bound inferencer into a separate constraint solver - Added transforms to infer, trim, and remove interval bounds - Tests for said features Plan to be released with 1.3
Diffstat (limited to 'src/main/scala/firrtl/ir')
-rw-r--r--src/main/scala/firrtl/ir/IR.scala177
1 files changed, 175 insertions, 2 deletions
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index e721363c..63c620d1 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -3,7 +3,11 @@
package firrtl
package ir
-import Utils.indent
+import Utils.{dec2string, indent, trim}
+import firrtl.constraint.{Constraint, IsKnown, IsVar}
+
+import scala.math.BigDecimal.RoundingMode._
+import scala.collection.mutable
/** Intermediate Representation */
abstract class FirrtlNode {
@@ -103,6 +107,29 @@ object StringLit {
*/
abstract class PrimOp extends FirrtlNode {
def serialize: String = this.toString
+ def propagateType(e: DoPrim): Type = UnknownType
+ def apply(args: Any*): DoPrim = {
+ val groups = args.groupBy {
+ case x: Expression => "exp"
+ case x: BigInt => "int"
+ case x: Int => "int"
+ case other => "other"
+ }
+ val exprs = groups.getOrElse("exp", Nil).collect {
+ case e: Expression => e
+ }
+ val consts = groups.getOrElse("int", Nil).map {
+ _ match {
+ case i: BigInt => i
+ case i: Int => BigInt(i)
+ }
+ }
+ groups.get("other") match {
+ case None =>
+ case Some(x) => sys.error(s"Shouldn't be here: $x")
+ }
+ DoPrim(this, exprs, consts, UnknownType)
+ }
}
abstract class Expression extends FirrtlNode {
@@ -151,7 +178,7 @@ case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Exp
def foreachType(f: Type => Unit): Unit = f(tpe)
def foreachWidth(f: Width => Unit): Unit = Unit
}
-case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression {
+case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type = UnknownType) extends Expression {
def serialize: String = s"mux(${cond.serialize}, ${tval.serialize}, ${fval.serialize})"
def mapExpr(f: Expression => Expression): Expression = Mux(f(cond), f(tval), f(fval), tpe)
def mapType(f: Type => Type): Expression = this.copy(tpe = f(tpe))
@@ -535,6 +562,12 @@ class IntWidth(val width: BigInt) extends Width with Product {
case object UnknownWidth extends Width {
def serialize: String = ""
}
+case class CalcWidth(arg: Constraint) extends Width {
+ def serialize: String = s"calcw(${arg.serialize})"
+}
+case class VarWidth(name: String) extends Width with IsVar {
+ override def serialize: String = s"<$name>"
+}
/** Orientation of [[Field]] */
abstract class Orientation extends FirrtlNode
@@ -550,6 +583,67 @@ case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode
def serialize: String = flip.serialize + name + " : " + tpe.serialize
}
+
+/** Bounds of [[IntervalType]] */
+
+trait Bound extends Constraint {
+ def serialize: String
+}
+case object UnknownBound extends Bound {
+ def serialize: String = "?"
+ def map(f: Constraint=>Constraint): Constraint = this
+ override def reduce(): Constraint = this
+ val children = Vector()
+}
+case class CalcBound(arg: Constraint) extends Bound {
+ def serialize: String = s"calcb(${arg.serialize})"
+ def map(f: Constraint=>Constraint): Constraint = f(arg)
+ override def reduce(): Constraint = arg
+ val children = Vector(arg)
+}
+case class VarBound(name: String) extends IsVar with Bound
+object KnownBound {
+ def unapply(b: Constraint): Option[BigDecimal] = b match {
+ case k: IsKnown => Some(k.value)
+ case _ => None
+ }
+ def unapply(b: Bound): Option[BigDecimal] = b match {
+ case k: IsKnown => Some(k.value)
+ case _ => None
+ }
+}
+case class Open(value: BigDecimal) extends IsKnown with Bound {
+ def serialize = s"o($value)"
+ def +(that: IsKnown): IsKnown = Open(value + that.value)
+ def *(that: IsKnown): IsKnown = that match {
+ case Closed(x) if x == 0 => Closed(x)
+ case _ => Open(value * that.value)
+ }
+ def min(that: IsKnown): IsKnown = if(value < that.value) this else that
+ def max(that: IsKnown): IsKnown = if(value > that.value) this else that
+ def neg: IsKnown = Open(-value)
+ def floor: IsKnown = Open(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
+ def pow: IsKnown = if(value.isBinaryDouble) Open(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
+}
+case class Closed(value: BigDecimal) extends IsKnown with Bound {
+ def serialize = s"c($value)"
+ def +(that: IsKnown): IsKnown = that match {
+ case Open(x) => Open(value + x)
+ case Closed(x) => Closed(value + x)
+ }
+ def *(that: IsKnown): IsKnown = that match {
+ case IsKnown(x) if value == BigInt(0) => Closed(0)
+ case Open(x) => Open(value * x)
+ case Closed(x) => Closed(value * x)
+ }
+ def min(that: IsKnown): IsKnown = if(value <= that.value) this else that
+ def max(that: IsKnown): IsKnown = if(value >= that.value) this else that
+ def neg: IsKnown = Closed(-value)
+ def floor: IsKnown = Closed(value.setScale(0, BigDecimal.RoundingMode.FLOOR))
+ def pow: IsKnown = if(value.isBinaryDouble) Closed(BigDecimal(BigInt(1) << value.toInt)) else sys.error("Shouldn't be here")
+}
+
+/** Types of [[FirrtlNode]] */
abstract class Type extends FirrtlNode {
def mapType(f: Type => Type): Type
def mapWidth(f: Width => Width): Type
@@ -586,6 +680,84 @@ case class FixedType(width: Width, point: Width) extends GroundType {
def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point))
def foreachWidth(f: Width => Unit): Unit = { f(width); f(point) }
}
+case class IntervalType(lower: Bound, upper: Bound, point: Width) extends GroundType {
+ override def serialize: String = {
+ val lowerString = lower match {
+ case Open(l) => s"(${dec2string(l)}, "
+ case Closed(l) => s"[${dec2string(l)}, "
+ case UnknownBound => s"[?, "
+ case _ => s"[?, "
+ }
+ val upperString = upper match {
+ case Open(u) => s"${dec2string(u)})"
+ case Closed(u) => s"${dec2string(u)}]"
+ case UnknownBound => s"?]"
+ case _ => s"?]"
+ }
+ val bounds = (lower, upper) match {
+ case (k1: IsKnown, k2: IsKnown) => lowerString + upperString
+ case _ => ""
+ }
+ val pointString = point match {
+ case IntWidth(i) => "." + i.toString
+ case _ => ""
+ }
+ "Interval" + bounds + pointString
+ }
+
+ private lazy val bp = point.asInstanceOf[IntWidth].width.toInt
+ private def precision: Option[BigDecimal] = point match {
+ case IntWidth(width) =>
+ val bp = width.toInt
+ if(bp >= 0) Some(BigDecimal(1) / BigDecimal(BigInt(1) << bp)) else Some(BigDecimal(BigInt(1) << -bp))
+ case other => None
+ }
+
+ def min: Option[BigDecimal] = (lower, precision) match {
+ case (Open(a), Some(prec)) => a / prec match {
+ case x if trim(x).isWhole => Some(a + prec) // add precision for open lower bound i.e. (-4 -> [3 for bp = 0
+ case x => Some(x.setScale(0, CEILING) * prec) // Deal with unrepresentable bound representations (finite BP) -- new closed form l > original l
+ }
+ case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, CEILING) * prec)
+ case other => None
+ }
+
+ def max: Option[BigDecimal] = (upper, precision) match {
+ case (Open(a), Some(prec)) => a / prec match {
+ case x if trim(x).isWhole => Some(a - prec) // subtract precision for open upper bound
+ case x => Some(x.setScale(0, FLOOR) * prec)
+ }
+ case (Closed(a), Some(prec)) => Some((a / prec).setScale(0, FLOOR) * prec)
+ }
+
+ def minAdjusted: Option[BigInt] = min.map(_ * BigDecimal(BigInt(1) << bp) match {
+ case x if trim(x).isWhole | x.doubleValue == 0.0 => x.toBigInt
+ case x => sys.error(s"MinAdjusted should be a whole number: $x. Min is $min. BP is $bp. Precision is $precision. Lower is ${lower}.")
+ })
+
+ def maxAdjusted: Option[BigInt] = max.map(_ * BigDecimal(BigInt(1) << bp) match {
+ case x if trim(x).isWhole => x.toBigInt
+ case x => sys.error(s"MaxAdjusted should be a whole number: $x")
+ })
+
+ /** If bounds are known, calculates the width, otherwise returns UnknownWidth */
+ lazy val width: Width = (point, lower, upper) match {
+ case (IntWidth(i), l: IsKnown, u: IsKnown) =>
+ IntWidth(Math.max(Utils.getSIntWidth(minAdjusted.get), Utils.getSIntWidth(maxAdjusted.get)))
+ case _ => UnknownWidth
+ }
+
+ /** If bounds are known, returns a sequence of all possible values inside this interval */
+ lazy val range: Option[Seq[BigDecimal]] = (lower, upper, point) match {
+ case (l: IsKnown, u: IsKnown, p: IntWidth) =>
+ if(min.get > max.get) Some(Nil) else Some(Range.BigDecimal(min.get, max.get, precision.get))
+ case _ => None
+ }
+
+ override def mapWidth(f: Width => Width): Type = this.copy(point = f(point))
+ override def foreachWidth(f: Width => Unit): Unit = f(point)
+}
+
case class BundleType(fields: Seq[Field]) extends AggregateType {
def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}"
def mapType(f: Type => Type): Type =
@@ -645,6 +817,7 @@ case class Port(
direction: Direction,
tpe: Type) extends FirrtlNode with IsDeclaration {
def serialize: String = s"${direction.serialize} $name : ${tpe.serialize}" + info.serialize
+ def mapType(f: Type => Type): Port = Port(info, name, direction, f(tpe))
}
/** Parameters for external modules */