diff options
Diffstat (limited to 'src/main/scala/firrtl/constraint')
| -rw-r--r-- | src/main/scala/firrtl/constraint/Constraint.scala | 22 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/ConstraintSolver.scala | 357 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/Inequality.scala | 24 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsAdd.scala | 52 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsFloor.scala | 32 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsKnown.scala | 44 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsMax.scala | 59 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsMin.scala | 57 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsMul.scala | 52 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsNeg.scala | 32 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsPow.scala | 33 | ||||
| -rw-r--r-- | src/main/scala/firrtl/constraint/IsVar.scala | 27 |
12 files changed, 791 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/constraint/Constraint.scala b/src/main/scala/firrtl/constraint/Constraint.scala new file mode 100644 index 00000000..247593ee --- /dev/null +++ b/src/main/scala/firrtl/constraint/Constraint.scala @@ -0,0 +1,22 @@ +// See LICENSE for license details. + +package firrtl.constraint + +/** Trait for all Constraint Solver expressions */ +trait Constraint { + def serialize: String + def map(f: Constraint => Constraint): Constraint + val children: Vector[Constraint] + def reduce(): Constraint +} + +/** Trait for constraints with more than one argument */ +trait MultiAry extends Constraint { + def op(a: IsKnown, b: IsKnown): IsKnown + def merge(b1: Option[IsKnown], b2: Option[IsKnown]): Option[IsKnown] = (b1, b2) match { + case (Some(x), Some(y)) => Some(op(x, y)) + case (_, y: Some[_]) => y + case (x: Some[_], _) => x + case _ => None + } +} diff --git a/src/main/scala/firrtl/constraint/ConstraintSolver.scala b/src/main/scala/firrtl/constraint/ConstraintSolver.scala new file mode 100644 index 00000000..52440b15 --- /dev/null +++ b/src/main/scala/firrtl/constraint/ConstraintSolver.scala @@ -0,0 +1,357 @@ +// See LICENSE for license details. + +package firrtl.constraint + +import firrtl._ +import firrtl.ir._ +import firrtl.Utils.throwInternalError +import firrtl.annotations.ReferenceTarget + +import scala.collection.mutable + +/** Forwards-Backwards Constraint Solver + * + * Used for computing [[Width]] and [[Bound]] constraints + * + * Note - this is an O(N) algorithm, but requires exponential memory. We rely on aggressive early optimization + * of constraint expressions to (usually) get around this. + */ +class ConstraintSolver { + + /** Initial, mutable constraint list, with function to add the constraint */ + private val constraints = mutable.ArrayBuffer[Inequality]() + + /** Solved constraints */ + type ConstraintMap = mutable.HashMap[String, (Constraint, Boolean)] + private val solvedConstraintMap = new ConstraintMap() + + + /** Clear all previously recorded/solved constraints */ + def clear(): Unit = { + constraints.clear() + solvedConstraintMap.clear() + } + + /** Updates internal list of inequalities with a new [[GreaterOrEqual]] + * @param big The larger constraint, must be either known or a variable + * @param small The smaller constraint + */ + def addGeq(big: Constraint, small: Constraint, r1: String, r2: String): Unit = (big, small) match { + case (IsVar(name), other: Constraint) => add(GreaterOrEqual(name, other)) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Updates internal list of inequalities with a new [[GreaterOrEqual]] + * @param big The larger constraint, must be either known or a variable + * @param small The smaller constraint + */ + def addGeq(big: Width, small: Width, r1: String, r2: String): Unit = (big, small) match { + case (IsVar(name), other: CalcWidth) => add(GreaterOrEqual(name, other.arg)) + case (IsVar(name), other: IsVar) => add(GreaterOrEqual(name, other)) + case (IsVar(name), other: IntWidth) => add(GreaterOrEqual(name, Implicits.width2constraint(other))) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Updates internal list of inequalities with a new [[LesserOrEqual]] + * @param small The smaller constraint, must be either known or a variable + * @param big The larger constraint + */ + def addLeq(small: Constraint, big: Constraint, r1: String, r2: String): Unit = (small, big) match { + case (IsVar(name), other: Constraint) => add(LesserOrEqual(name, other)) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Updates internal list of inequalities with a new [[LesserOrEqual]] + * @param small The smaller constraint, must be either known or a variable + * @param big The larger constraint + */ + def addLeq(small: Width, big: Width, r1: String, r2: String): Unit = (small, big) match { + case (IsVar(name), other: CalcWidth) => add(LesserOrEqual(name, other.arg)) + case (IsVar(name), other: IsVar) => add(LesserOrEqual(name, other)) + case (IsVar(name), other: IntWidth) => add(LesserOrEqual(name, Implicits.width2constraint(other))) + case _ => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints + } + + /** Returns a solved constraint, if it exists and is solved + * @param b + * @return + */ + def get(b: Constraint): Option[IsKnown] = { + val name = b match { + case IsVar(name) => name + case x => "" + } + solvedConstraintMap.get(name) match { + case None => None + case Some((k: IsKnown, _)) => Some(k) + case Some(_) => None + } + } + + /** Returns a solved width, if it exists and is solved + * @param b + * @return + */ + def get(b: Width): Option[IsKnown] = { + val name = b match { + case IsVar(name) => name + case x => "" + } + solvedConstraintMap.get(name) match { + case None => None + case Some((k: IsKnown, _)) => Some(k) + case Some(_) => None + } + } + + + private def add(c: Inequality) = constraints += c + + + /** Creates an Inequality given a variable name, constraint, and whether its >= or <= + * @param left + * @param right + * @param geq + * @return + */ + private def genConst(left: String, right: Constraint, geq: Boolean): Inequality = geq match { + case true => GreaterOrEqual(left, right) + case false => LesserOrEqual(left, right) + } + + /** For debugging, can serialize the initial constraints */ + def serializeConstraints: String = constraints.mkString("\n") + + /** For debugging, can serialize the solved constraints */ + def serializeSolutions: String = solvedConstraintMap.map{ + case (k, (v, true)) => s"$k >= ${v.serialize}" + case (k, (v, false)) => s"$k <= ${v.serialize}" + }.mkString("\n") + + + + /************* Constraint Solver Engine ****************/ + + /** Merges constraints on the same variable + * + * Returns a new list of Inequalities with a single Inequality per variable + * + * For example, given: + * a >= 1 + b + * a >= 3 + * + * Will return: + * a >= max(3, 1 + b) + * + * @param constraints + * @return + */ + private def mergeConstraints(constraints: Seq[Inequality]): Seq[Inequality] = { + val mergedMap = mutable.HashMap[String, Inequality]() + constraints.foreach { + case c if c.geq && mergedMap.contains(c.left) => + mergedMap(c.left) = genConst(c.left, IsMax(mergedMap(c.left).right, c.right), true) + case c if !c.geq && mergedMap.contains(c.left) => + mergedMap(c.left) = genConst(c.left, IsMin(mergedMap(c.left).right, c.right), false) + case c => + mergedMap(c.left) = c + } + mergedMap.values.toList + } + + + /** Attempts to substitute variables with their corresponding forward-solved constraints + * If no corresponding constraint has been visited yet, keep variable as is + * + * @param forwardSolved ConstraintMap containing earlier forward-solved constraints + * @param constraint Constraint to forward solve + * @return Forward solved constraint + */ + private def forwardSubstitution(forwardSolved: ConstraintMap)(constraint: Constraint): Constraint = { + val x = constraint map forwardSubstitution(forwardSolved) + x match { + case isVar: IsVar => forwardSolved get isVar.name match { + case None => isVar.asInstanceOf[Constraint] + case Some((p, geq)) => + val newT = forwardSubstitution(forwardSolved)(p) + forwardSolved(isVar.name) = (newT, geq) + newT + } + case other => other + } + } + + /** Attempts to substitute variables with their corresponding backwards-solved constraints + * If no corresponding constraint is solved, keep variable as is (as an unsolved constraint, + * which will be reported later) + * + * @param backwardSolved ConstraintMap containing earlier backward-solved constraints + * @param constraint Constraint to backward solve + * @return Backward solved constraint + */ + private def backwardSubstitution(backwardSolved: ConstraintMap)(constraint: Constraint): Constraint = { + constraint match { + case isVar: IsVar => backwardSolved.get(isVar.name) match { + case Some((p, geq)) => p + case _ => isVar + } + case other => other map backwardSubstitution(backwardSolved) + } + } + + /** Remove solvable cycles in an inequality + * + * For example: + * a >= max(1, a) + * + * Can be simplified to: + * a >= 1 + * @param name Name of the variable on left side of inequality + * @param geq Whether inequality is >= or <= + * @param constraint Constraint expression + * @return + */ + private def removeCycle(name: String, geq: Boolean)(constraint: Constraint): Constraint = + if(geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint) + + /** Removes solvable cycles of <= inequalities + * @param name Name of the variable on left side of inequality + * @param constraint Constraint expression + * @return + */ + private def removeLeqCycle(name: String)(constraint: Constraint): Constraint = constraint match { + case x if greaterEqThan(name)(x) => VarCon(name) + case isMin: IsMin => IsMin(isMin.children.filter{ c => !greaterEqThan(name)(c)}) + case x => x + } + + /** Removes solvable cycles of >= inequalities + * @param name Name of the variable on left side of inequality + * @param constraint Constraint expression + * @return + */ + private def removeGeqCycle(name: String)(constraint: Constraint): Constraint = constraint match { + case x if lessEqThan(name)(x) => VarCon(name) + case isMax: IsMax => IsMax(isMax.children.filter{c => !lessEqThan(name)(c)}) + case x => x + } + + private def greaterEqThan(name: String)(constraint: Constraint): Boolean = constraint match { + case isMin: IsMin => isMin.children.map(greaterEqThan(name)).reduce(_ && _) + case isAdd: IsAdd => isAdd.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true + case _ => false + } + case isMul: IsMul => isMul.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true + case _ => false + } + case isVar: IsVar if isVar.name == name => true + case _ => false + } + + private def lessEqThan(name: String)(constraint: Constraint): Boolean = constraint match { + case isMax: IsMax => isMax.children.map(lessEqThan(name)).reduce(_ && _) + case isAdd: IsAdd => isAdd.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true + case _ => false + } + case isMul: IsMul => isMul.children match { + case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true + case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true + case _ => false + } + case isVar: IsVar if isVar.name == name => true + case isNeg: IsNeg => isNeg.child match { + case isVar: IsVar if isVar.name == name => true + case _ => false + } + case _ => false + } + + /** Whether a constraint contains the named variable + * @param name Name of variable + * @param constraint Constraint to check + * @return + */ + private def hasVar(name: String)(constraint: Constraint): Boolean = { + var has = false + def rec(constraint: Constraint): Constraint = { + constraint match { + case isVar: IsVar if isVar.name == name => has = true + case _ => + } + constraint map rec + } + rec(constraint) + has + } + + /** Returns illegal constraints, where both a >= and <= inequality are used on the same variable + * @return + */ + def check(): Seq[Inequality] = { + val checkMap = new mutable.HashMap[String, Inequality]() + constraints.foldLeft(Seq[Inequality]()) { (seq, c) => + checkMap.get(c.left) match { + case None => + checkMap(c.left) = c + seq ++ Nil + case Some(x) if x.geq != c.geq => seq ++ Seq(x, c) + case Some(x) => seq ++ Nil + } + } + } + + /** Solves constraints present in collected inequalities + * + * Constraint solving steps: + * 1) Assert no variable has both >= and <= inequalities (it can have multiple of the same kind of inequality) + * 2) Merge constraints of variables having multiple inequalities + * 3) Forward solve inequalities + * a. Iterate through inequalities top-to-bottom, replacing previously seen variables with corresponding + * constraint + * b. For each forward-solved inequality, attempt to remove circular constraints + * c. Forward-solved inequalities without circular constraints are recorded + * 4) Backwards solve inequalities + * a. Iterate through successful forward-solved inequalities bottom-to-top, replacing previously seen variables + * with corresponding constraint + * b. Record solved constraints + */ + def solve(): Unit = { + // 1) Check if any variable has both >= and <= inequalities (which is illegal) + val illegals = check() + if (illegals != Nil) throwInternalError(s"Constraints cannot have both >= and <= inequalities: $illegals") + + // 2) Merge constraints + val uniqueConstraints = mergeConstraints(constraints.toSeq) + + // 3) Forward Solve + val forwardConstraintMap = new ConstraintMap + val orderedVars = mutable.HashMap[Int, String]() + + var index = 0 + for (constraint <- uniqueConstraints) { + //TODO: Risky if used improperly... need to check whether substitution from a leq to a geq is negated (always). + val subbedRight = forwardSubstitution(forwardConstraintMap)(constraint.right) + val name = constraint.left + val finishedRight = removeCycle(name, constraint.geq)(subbedRight) + if (!hasVar(name)(finishedRight)) { + forwardConstraintMap(name) = (finishedRight, constraint.geq) + orderedVars(index) = name + index += 1 + } + } + + // 4) Backwards Solve + for (i <- (orderedVars.size - 1) to 0 by -1) { + val name = orderedVars(i) // Should visit `orderedVars` backward + val (forwardRight, forwardGeq) = forwardConstraintMap(name) + val solvedRight = backwardSubstitution(solvedConstraintMap)(forwardRight) + solvedConstraintMap(name) = (solvedRight, forwardGeq) + } + } +} diff --git a/src/main/scala/firrtl/constraint/Inequality.scala b/src/main/scala/firrtl/constraint/Inequality.scala new file mode 100644 index 00000000..0fa1d2eb --- /dev/null +++ b/src/main/scala/firrtl/constraint/Inequality.scala @@ -0,0 +1,24 @@ +// See LICENSE for license details. + +package firrtl.constraint + +/** Represents either greater or equal to or less than or equal to + * Is passed to the constraint solver to resolve + */ +trait Inequality { + def left: String + def right: Constraint + def geq: Boolean +} + +case class GreaterOrEqual(left: String, right: Constraint) extends Inequality { + val geq = true + override def toString: String = s"$left >= ${right.serialize}" +} + +case class LesserOrEqual(left: String, right: Constraint) extends Inequality { + val geq = false + override def toString: String = s"$left <= ${right.serialize}" +} + + diff --git a/src/main/scala/firrtl/constraint/IsAdd.scala b/src/main/scala/firrtl/constraint/IsAdd.scala new file mode 100644 index 00000000..e177a8b9 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsAdd.scala @@ -0,0 +1,52 @@ +// See LICENSE for license details. + + +package firrtl.constraint + +// Is case class because writing tests is easier due to equality is not object equality +case class IsAdd private (known: Option[IsKnown], + maxs: Vector[IsMax], + mins: Vector[IsMin], + others: Vector[Constraint]) extends Constraint with MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 + b2 + + lazy val children: Vector[Constraint] = { + if(known.nonEmpty) known.get +: (maxs ++ mins ++ others) else maxs ++ mins ++ others + } + + def addChild(x: Constraint): IsAdd = x match { + case k: IsKnown => new IsAdd(merge(Some(k), known), maxs, mins, others) + case add: IsAdd => new IsAdd(merge(known, add.known), maxs ++ add.maxs, mins ++ add.mins, others ++ add.others) + case max: IsMax => new IsAdd(known, maxs :+ max, mins, others) + case min: IsMin => new IsAdd(known, maxs, mins :+ min, others) + case other => new IsAdd(known, maxs, mins, others :+ other) + } + + override def serialize: String = "(" + children.map(_.serialize).mkString(" + ") + ")" + + override def map(f: Constraint=>Constraint): Constraint = IsAdd(children.map(f)) + + def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, maxs, mins, others) match { + case (Some(k), _, _, _) if k.value == 0 => new IsAdd(None, maxs, mins, others).reduce() + case (Some(k), Vector(max), Vector(), Vector()) => max.map { o => IsAdd(k, o) }.reduce() + case (Some(k), Vector(), Vector(min), Vector()) => min.map { o => IsAdd(k, o) }.reduce() + case _ => this + } + } + } +} + +object IsAdd { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l + r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + children.foldLeft(new IsAdd(None, Vector(), Vector(), Vector())) { (add, c) => + add.addChild(c) + }.reduce() + } +}
\ No newline at end of file diff --git a/src/main/scala/firrtl/constraint/IsFloor.scala b/src/main/scala/firrtl/constraint/IsFloor.scala new file mode 100644 index 00000000..5de4697e --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsFloor.scala @@ -0,0 +1,32 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsFloor { + def apply(child: Constraint): Constraint = new IsFloor(child, 0).reduce() +} + +case class IsFloor private (child: Constraint, dummyArg: Int) extends Constraint { + + override def reduce(): Constraint = child match { + case k: IsKnown => k.floor + case x: IsAdd => this + case x: IsMul => this + case x: IsNeg => this + case x: IsPow => this + // floor(max(a, b)) -> max(floor(a), floor(b)) + case x: IsMax => IsMax(x.children.map {b => IsFloor(b)}) + case x: IsMin => IsMin(x.children.map {b => IsFloor(b)}) + case x: IsVar => this + // floor(floor(x)) -> floor(x) + case x: IsFloor => x + case _ => this + } + val children = Vector(child) + + override def map(f: Constraint=>Constraint): Constraint = IsFloor(f(child)) + + override def serialize: String = "floor(" + child.serialize + ")" +} + + diff --git a/src/main/scala/firrtl/constraint/IsKnown.scala b/src/main/scala/firrtl/constraint/IsKnown.scala new file mode 100644 index 00000000..5bd25f92 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsKnown.scala @@ -0,0 +1,44 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsKnown { + def unapply(b: Constraint): Option[BigDecimal] = b match { + case k: IsKnown => Some(k.value) + case _ => None + } +} + +/** Constant values must extend this trait see [[firrtl.ir.Closed and firrtl.ir.Open]] */ +trait IsKnown extends Constraint { + val value: BigDecimal + + /** Addition */ + def +(that: IsKnown): IsKnown + + /** Multiplication */ + def *(that: IsKnown): IsKnown + + /** Max */ + def max(that: IsKnown): IsKnown + + /** Min */ + def min(that: IsKnown): IsKnown + + /** Negate */ + def neg: IsKnown + + /** 2 << value */ + def pow: IsKnown + + /** Floor */ + def floor: IsKnown + + override def map(f: Constraint=>Constraint): Constraint = this + + val children: Vector[Constraint] = Vector.empty[Constraint] + + def reduce(): IsKnown = this +} + + diff --git a/src/main/scala/firrtl/constraint/IsMax.scala b/src/main/scala/firrtl/constraint/IsMax.scala new file mode 100644 index 00000000..3f24b7c0 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsMax.scala @@ -0,0 +1,59 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsMax { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l max r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + val x = children.foldLeft(new IsMax(None, Vector(), Vector())) { (add, c) => + add.addChild(c) + } + x.reduce() + } +} + +case class IsMax private[constraint](known: Option[IsKnown], + mins: Vector[IsMin], + others: Vector[Constraint] + ) extends MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 max b2 + + override def serialize: String = "max(" + children.map(_.serialize).mkString(", ") + ")" + + override def map(f: Constraint=>Constraint): Constraint = IsMax(children.map(f)) + + lazy val children: Vector[Constraint] = { + if(known.nonEmpty) known.get +: (mins ++ others) else mins ++ others + } + + def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, mins, others) match { + case (Some(IsKnown(a)), _, _) => + // Eliminate minimums who have a known minimum value which is smaller than known maximum value + val filteredMins = mins.filter { + case IsMin(Some(IsKnown(i)), _, _) if i <= a => false + case other => true + } + // If a successful filter, rerun reduce + val newMax = new IsMax(known, filteredMins, others) + if(filteredMins.size != mins.size) { + newMax.reduce() + } else newMax + case _ => this + } + } + } + + def addChild(x: Constraint): IsMax = x match { + case k: IsKnown => new IsMax(known = merge(Some(k), known), mins, others) + case max: IsMax => new IsMax(known = merge(known, max.known), max.mins ++ mins, others ++ max.others) + case min: IsMin => new IsMax(known, mins :+ min, others) + case other => new IsMax(known, mins, others :+ other) + } +} + diff --git a/src/main/scala/firrtl/constraint/IsMin.scala b/src/main/scala/firrtl/constraint/IsMin.scala new file mode 100644 index 00000000..ee97e298 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsMin.scala @@ -0,0 +1,57 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsMin { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l min r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + children.foldLeft(new IsMin(None, Vector(), Vector())) { (add, c) => + add.addChild(c) + }.reduce() + } +} + +case class IsMin private[constraint](known: Option[IsKnown], + maxs: Vector[IsMax], + others: Vector[Constraint] + ) extends MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 min b2 + + override def serialize: String = "min(" + children.map(_.serialize).mkString(", ") + ")" + + override def map(f: Constraint=>Constraint): Constraint = IsMin(children.map(f)) + + lazy val children: Vector[Constraint] = { + if(known.nonEmpty) known.get +: (maxs ++ others) else maxs ++ others + } + + def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, maxs, others) match { + case (Some(IsKnown(i)), _, _) => + // Eliminate maximums who have a known maximum value which is larger than known minimum value + val filteredMaxs = maxs.filter { + case IsMax(Some(IsKnown(a)), _, _) if a >= i => false + case other => true + } + // If a successful filter, rerun reduce + val newMin = new IsMin(known, filteredMaxs, others) + if(filteredMaxs.size != maxs.size) { + newMin.reduce() + } else newMin + case _ => this + } + } + } + + def addChild(x: Constraint): IsMin = x match { + case k: IsKnown => new IsMin(merge(Some(k), known), maxs, others) + case max: IsMax => new IsMin(known, maxs :+ max, others) + case min: IsMin => new IsMin(merge(min.known, known), maxs ++ min.maxs, others ++ min.others) + case other => new IsMin(known, maxs, others :+ other) + } +} diff --git a/src/main/scala/firrtl/constraint/IsMul.scala b/src/main/scala/firrtl/constraint/IsMul.scala new file mode 100644 index 00000000..3f637d75 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsMul.scala @@ -0,0 +1,52 @@ +// See LICENSE for license details. + +package firrtl.constraint + +import firrtl.ir.Closed + +object IsMul { + def apply(left: Constraint, right: Constraint): Constraint = (left, right) match { + case (l: IsKnown, r: IsKnown) => l * r + case _ => apply(Seq(left, right)) + } + def apply(children: Seq[Constraint]): Constraint = { + children.foldLeft(new IsMul(None, Vector())) { (add, c) => + add.addChild(c) + }.reduce() + } +} + +case class IsMul private (known: Option[IsKnown], others: Vector[Constraint]) extends MultiAry { + + def op(b1: IsKnown, b2: IsKnown): IsKnown = b1 * b2 + + lazy val children: Vector[Constraint] = if(known.nonEmpty) known.get +: others else others + + def addChild(x: Constraint): IsMul = x match { + case k: IsKnown => new IsMul(known = merge(Some(k), known), others) + case mul: IsMul => new IsMul(merge(known, mul.known), others ++ mul.others) + case other => new IsMul(known, others :+ other) + } + + override def reduce(): Constraint = { + if(children.size == 1) children.head else { + (known, others) match { + case (Some(Closed(x)), _) if x == BigDecimal(1) => new IsMul(None, others).reduce() + case (Some(Closed(x)), _) if x == BigDecimal(0) => Closed(0) + case (Some(Closed(x)), Vector(m: IsMax)) if x > 0 => + IsMax(m.children.map { c => IsMul(Closed(x), c) }) + case (Some(Closed(x)), Vector(m: IsMax)) if x < 0 => + IsMin(m.children.map { c => IsMul(Closed(x), c) }) + case (Some(Closed(x)), Vector(m: IsMin)) if x > 0 => + IsMin(m.children.map { c => IsMul(Closed(x), c) }) + case (Some(Closed(x)), Vector(m: IsMin)) if x < 0 => + IsMax(m.children.map { c => IsMul(Closed(x), c) }) + case _ => this + } + } + } + + override def map(f: Constraint=>Constraint): Constraint = IsMul(children.map(f)) + + override def serialize: String = "(" + children.map(_.serialize).mkString(" * ") + ")" +} diff --git a/src/main/scala/firrtl/constraint/IsNeg.scala b/src/main/scala/firrtl/constraint/IsNeg.scala new file mode 100644 index 00000000..46f739c6 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsNeg.scala @@ -0,0 +1,32 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsNeg { + def apply(child: Constraint): Constraint = new IsNeg(child, 0).reduce() +} + +// Dummy arg is to get around weird Scala issue that can't differentiate between a +// private constructor and public apply that share the same arguments +case class IsNeg private (child: Constraint, dummyArg: Int) extends Constraint { + override def reduce(): Constraint = child match { + case k: IsKnown => k.neg + case x: IsAdd => IsAdd(x.children.map { b => IsNeg(b) }) + case x: IsMul => IsMul(Seq(IsNeg(x.children.head)) ++ x.children.tail) + case x: IsNeg => x.child + case x: IsPow => this + // -[max(a, b)] -> min[-a, -b] + case x: IsMax => IsMin(x.children.map { b => IsNeg(b) }) + case x: IsMin => IsMax(x.children.map { b => IsNeg(b) }) + case x: IsVar => this + case _ => this + } + + lazy val children = Vector(child) + + override def map(f: Constraint=>Constraint): Constraint = IsNeg(f(child)) + + override def serialize: String = "(-" + child.serialize + ")" +} + + diff --git a/src/main/scala/firrtl/constraint/IsPow.scala b/src/main/scala/firrtl/constraint/IsPow.scala new file mode 100644 index 00000000..54a06bf8 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsPow.scala @@ -0,0 +1,33 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsPow { + def apply(child: Constraint): Constraint = new IsPow(child, 0).reduce() +} + +// Dummy arg is to get around weird Scala issue that can't differentiate between a +// private constructor and public apply that share the same arguments +case class IsPow private (child: Constraint, dummyArg: Int) extends Constraint { + override def reduce(): Constraint = child match { + case k: IsKnown => k.pow + // 2^(a + b) -> 2^a * 2^b + case x: IsAdd => IsMul(x.children.map { b => IsPow(b)}) + case x: IsMul => this + case x: IsNeg => this + case x: IsPow => this + // 2^(max(a, b)) -> max(2^a, 2^b) since two is always positive, so a, b control magnitude + case x: IsMax => IsMax(x.children.map {b => IsPow(b)}) + case x: IsMin => IsMin(x.children.map {b => IsPow(b)}) + case x: IsVar => this + case _ => this + } + + val children = Vector(child) + + override def map(f: Constraint=>Constraint): Constraint = IsPow(f(child)) + + override def serialize: String = "(2^" + child.serialize + ")" +} + + diff --git a/src/main/scala/firrtl/constraint/IsVar.scala b/src/main/scala/firrtl/constraint/IsVar.scala new file mode 100644 index 00000000..98396fa0 --- /dev/null +++ b/src/main/scala/firrtl/constraint/IsVar.scala @@ -0,0 +1,27 @@ +// See LICENSE for license details. + +package firrtl.constraint + +object IsVar { + def unapply(i: Constraint): Option[String] = i match { + case i: IsVar => Some(i.name) + case _ => None + } +} + +/** Extend to be a constraint variable */ +trait IsVar extends Constraint { + + def name: String + + override def serialize: String = name + + override def map(f: Constraint=>Constraint): Constraint = this + + override def reduce() = this + + val children = Vector() +} + +case class VarCon(name: String) extends IsVar + |
