1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
|
// SPDX-License-Identifier: Apache-2.0
package firrtl.constraint
import firrtl._
import firrtl.ir._
import firrtl.Utils.throwInternalError
import scala.collection.mutable
/** Forwards-Backwards Constraint Solver
*
* Used for computing [[firrtl.ir.Width Width]] and [[firrtl.ir.Bound Bound]] constraints
*
* Note - this is an O(N) algorithm, but requires exponential memory. We rely on aggressive early optimization
* of constraint expressions to (usually) get around this.
*/
class ConstraintSolver {

  /** Inequalities recorded via addGeq/addLeq, in insertion order; consumed by [[solve]] */
  private val constraints = mutable.ArrayBuffer[Inequality]()

  /** Maps a variable name to its constraint and whether the relation is >= (true) or <= (false) */
  type ConstraintMap = mutable.HashMap[String, (Constraint, Boolean)]

  /** Solutions written by [[solve]]; entries persist until [[clear]] is called */
  private val solvedConstraintMap = new ConstraintMap()

  /** Clear all previously recorded/solved constraints */
  def clear(): Unit = {
    constraints.clear()
    solvedConstraintMap.clear()
  }

  /** Records the inequality `big >= small`.
    *
    * The inequality is only recorded when `big` is a variable; every other combination is silently ignored
    * (constraints on widths should never error, e.g. attach adds lots of unnecessary constraints).
    *
    * @param big   the larger side; must be a variable for anything to be recorded
    * @param small the smaller side
    * @param r1    not used by this implementation
    * @param r2    not used by this implementation
    */
  def addGeq(big: Constraint, small: Constraint, r1: String, r2: String): Unit = (big, small) match {
    case (IsVar(name), other: Constraint) => add(GreaterOrEqual(name, other))
    case _                                => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
  }

  /** Records the inequality `big >= small` between widths.
    *
    * Only recorded when `big` is a variable and `small` is a [[CalcWidth]] (its argument is used), a variable,
    * or an [[IntWidth]]; every other combination is silently ignored.
    *
    * @param big   the larger width; must be a variable for anything to be recorded
    * @param small the smaller width
    * @param r1    not used by this implementation
    * @param r2    not used by this implementation
    */
  def addGeq(big: Width, small: Width, r1: String, r2: String): Unit = (big, small) match {
    case (IsVar(name), other: CalcWidth) => add(GreaterOrEqual(name, other.arg))
    case (IsVar(name), other: IsVar)     => add(GreaterOrEqual(name, other))
    case (IsVar(name), other: IntWidth)  => add(GreaterOrEqual(name, Implicits.width2constraint(other)))
    case _                               => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
  }

  /** Records the inequality `small <= big`.
    *
    * The inequality is only recorded when `small` is a variable; every other combination is silently ignored.
    *
    * @param small the smaller side; must be a variable for anything to be recorded
    * @param big   the larger side
    * @param r1    not used by this implementation
    * @param r2    not used by this implementation
    */
  def addLeq(small: Constraint, big: Constraint, r1: String, r2: String): Unit = (small, big) match {
    case (IsVar(name), other: Constraint) => add(LesserOrEqual(name, other))
    case _                                => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
  }

  /** Records the inequality `small <= big` between widths.
    *
    * Only recorded when `small` is a variable and `big` is a [[CalcWidth]] (its argument is used), a variable,
    * or an [[IntWidth]]; every other combination is silently ignored.
    *
    * @param small the smaller width; must be a variable for anything to be recorded
    * @param big   the larger width
    * @param r1    not used by this implementation
    * @param r2    not used by this implementation
    */
  def addLeq(small: Width, big: Width, r1: String, r2: String): Unit = (small, big) match {
    case (IsVar(name), other: CalcWidth) => add(LesserOrEqual(name, other.arg))
    case (IsVar(name), other: IsVar)     => add(LesserOrEqual(name, other))
    case (IsVar(name), other: IntWidth)  => add(LesserOrEqual(name, Implicits.width2constraint(other)))
    case _                               => // Constraints on widths should never error, e.g. attach adds lots of unnecessary constraints
  }

  /** Returns the solved value of a constraint variable, if it has been solved to a known value.
    *
    * @param b constraint to look up
    * @return `Some(value)` when `b` is a variable solved to an [[IsKnown]]; `None` for non-variables,
    *         unsolved variables, and variables whose solution is not a known value
    */
  def get(b: Constraint): Option[IsKnown] = b match {
    case IsVar(name) => solvedValue(name)
    case _           => None
  }

  /** Returns the solved value of a width variable, if it has been solved to a known value.
    *
    * @param b width to look up
    * @return `Some(value)` when `b` is a variable solved to an [[IsKnown]]; `None` for non-variables,
    *         unsolved variables, and variables whose solution is not a known value
    */
  def get(b: Width): Option[IsKnown] = b match {
    case IsVar(name) => solvedValue(name)
    case _           => None
  }

  /** Looks up `name` among the solutions, keeping the result only if it is a known value */
  private def solvedValue(name: String): Option[IsKnown] =
    solvedConstraintMap.get(name).collect { case (k: IsKnown, _) => k }

  private def add(c: Inequality) = constraints += c

  /** Creates an Inequality for variable `left`: `left >= right` when `geq`, otherwise `left <= right`
    *
    * @param left  variable name
    * @param right constraint expression
    * @param geq   true for >=, false for <=
    * @return the new inequality
    */
  private def genConst(left: String, right: Constraint, geq: Boolean): Inequality =
    if (geq) GreaterOrEqual(left, right) else LesserOrEqual(left, right)

  /** For debugging, can serialize the initial constraints */
  def serializeConstraints: String = constraints.mkString("\n")

  /** For debugging, can serialize the solved constraints */
  def serializeSolutions: String = solvedConstraintMap.map {
    case (k, (v, true))  => s"$k >= ${v.serialize}"
    case (k, (v, false)) => s"$k <= ${v.serialize}"
  }.mkString("\n")

  /** *********** Constraint Solver Engine ***************
    */

  /** Merges constraints on the same variable
    *
    * Returns a new list of Inequalities with a single Inequality per variable. Multiple >= constraints are
    * combined with a max, multiple <= constraints with a min.
    *
    * For example, given:
    * a >= 1 + b
    * a >= 3
    *
    * Will return:
    * a >= max(3, 1 + b)
    *
    * The merged inequality takes the direction of the last-seen inequality for that variable; [[solve]] rejects
    * mixed directions beforehand via [[check]].
    *
    * @param constraints inequalities to merge
    * @return one inequality per variable, in the map's iteration order
    */
  private def mergeConstraints(constraints: Seq[Inequality]): Seq[Inequality] = {
    val mergedMap = mutable.HashMap[String, Inequality]()
    constraints.foreach { c =>
      mergedMap(c.left) = mergedMap.get(c.left) match {
        case Some(prev) if c.geq => genConst(c.left, IsMax(prev.right, c.right), true)
        case Some(prev)          => genConst(c.left, IsMin(prev.right, c.right), false)
        case None                => c
      }
    }
    mergedMap.values.toList
  }

  /** Attempts to substitute variables with their corresponding forward-solved constraints
    * If no corresponding constraint has been visited yet, keep variable as is
    *
    * Side effect: each substituted entry of `forwardSolved` is overwritten with its own fully substituted form,
    * so later lookups of the same variable do not redo the work.
    *
    * @param forwardSolved ConstraintMap containing earlier forward-solved constraints
    * @param constraint Constraint to forward solve
    * @return Forward solved constraint
    */
  private def forwardSubstitution(forwardSolved: ConstraintMap)(constraint: Constraint): Constraint =
    constraint.map(forwardSubstitution(forwardSolved)) match {
      case isVar: IsVar =>
        forwardSolved.get(isVar.name) match {
          case None => isVar
          case Some((p, geq)) =>
            val newT = forwardSubstitution(forwardSolved)(p)
            forwardSolved(isVar.name) = (newT, geq)
            newT
        }
      case other => other
    }

  /** Attempts to substitute variables with their corresponding backwards-solved constraints
    * If no corresponding constraint is solved, keep variable as is (as an unsolved constraint,
    * which will be reported later)
    *
    * @param backwardSolved ConstraintMap containing earlier backward-solved constraints
    * @param constraint Constraint to backward solve
    * @return Backward solved constraint
    */
  private def backwardSubstitution(backwardSolved: ConstraintMap)(constraint: Constraint): Constraint =
    constraint match {
      case isVar: IsVar =>
        backwardSolved.get(isVar.name) match {
          case Some((p, _)) => p
          case None         => isVar
        }
      case other => other.map(backwardSubstitution(backwardSolved))
    }

  /** Remove solvable cycles in an inequality
    *
    * For example:
    * a >= max(1, a)
    *
    * Can be simplified to:
    * a >= 1
    *
    * @param name Name of the variable on left side of inequality
    * @param geq Whether inequality is >= or <=
    * @param constraint Constraint expression
    * @return the simplified right-hand side
    */
  private def removeCycle(name: String, geq: Boolean)(constraint: Constraint): Constraint =
    if (geq) removeGeqCycle(name)(constraint) else removeLeqCycle(name)(constraint)

  /** Removes solvable cycles of `name <= constraint` inequalities
    *
    * If the whole right side is provably >= `name`, the inequality holds trivially and the bare variable is
    * returned (so the caller sees it still references `name`). Otherwise, min children that are provably
    * >= `name` cannot tighten the bound and are dropped.
    *
    * @param name Name of the variable on left side of inequality
    * @param constraint Constraint expression
    * @return the simplified right-hand side
    */
  private def removeLeqCycle(name: String)(constraint: Constraint): Constraint = constraint match {
    case x if greaterEqThan(name)(x) => VarCon(name)
    case isMin: IsMin                => IsMin(isMin.children.filter { c => !greaterEqThan(name)(c) })
    case x                           => x
  }

  /** Removes solvable cycles of `name >= constraint` inequalities
    *
    * If the whole right side is provably <= `name`, the inequality holds trivially and the bare variable is
    * returned. Otherwise, max children that are provably <= `name` cannot tighten the bound and are dropped.
    *
    * @param name Name of the variable on left side of inequality
    * @param constraint Constraint expression
    * @return the simplified right-hand side
    */
  private def removeGeqCycle(name: String)(constraint: Constraint): Constraint = constraint match {
    case x if lessEqThan(name)(x) => VarCon(name)
    case isMax: IsMax             => IsMax(isMax.children.filter { c => !lessEqThan(name)(c) })
    case x                        => x
  }

  /** Conservatively decides whether `constraint` is provably >= the variable `name`.
    *
    * Recognized forms: `name` itself; `name + k` with k >= 0; `name * k` with k >= 1 (valid for a
    * non-negative `name`); and a min whose children are all provably >= `name`. Anything else is `false`.
    */
  private def greaterEqThan(name: String)(constraint: Constraint): Boolean = constraint match {
    // forall (rather than reduce) tolerates an empty child list and stops at the first failing child
    case isMin: IsMin => isMin.children.forall(greaterEqThan(name))
    case isAdd: IsAdd =>
      isAdd.children match {
        case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 0) => true
        case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 0) => true
        case _                                                                                => false
      }
    // name * k >= name needs k >= 1: with 0 <= k < 1 (e.g. 0 * name, 0.5 * name) the product is smaller,
    // and treating it as >= would unsoundly drop a real upper bound
    case isMul: IsMul =>
      isMul.children match {
        case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value >= 1) => true
        case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value >= 1) => true
        case _                                                                                => false
      }
    case isVar: IsVar if isVar.name == name => true
    case _                                  => false
  }

  /** Conservatively decides whether `constraint` is provably <= the variable `name`.
    *
    * Recognized forms: `name` itself; `-name`; `name + k` with k <= 0; `name * k` with k <= 0; and a max whose
    * children are all provably <= `name`. Anything else is `false`.
    * NOTE(review): the `-name` and `name * k` cases only hold for a non-negative `name`.
    */
  private def lessEqThan(name: String)(constraint: Constraint): Boolean = constraint match {
    // forall (rather than reduce) tolerates an empty child list and stops at the first failing child
    case isMax: IsMax => isMax.children.forall(lessEqThan(name))
    case isAdd: IsAdd =>
      isAdd.children match {
        case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
        case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
        case _                                                                                => false
      }
    case isMul: IsMul =>
      isMul.children match {
        case Seq(isVar: IsVar, isVal: IsKnown) if (isVar.name == name) && (isVal.value <= 0) => true
        case Seq(isVal: IsKnown, isVar: IsVar) if (isVar.name == name) && (isVal.value <= 0) => true
        case _                                                                                => false
      }
    case isVar: IsVar if isVar.name == name => true
    case isNeg: IsNeg =>
      isNeg.child match {
        case isVar: IsVar if isVar.name == name => true
        case _                                  => false
      }
    case _ => false
  }

  /** Whether a constraint contains the named variable anywhere in its expression tree
    *
    * @param name Name of variable
    * @param constraint Constraint to check
    * @return true if any node is a variable called `name`
    */
  private def hasVar(name: String)(constraint: Constraint): Boolean = {
    var found = false
    // `map` is used purely as a tree traversal; its rebuilt result is discarded
    def visit(node: Constraint): Constraint = {
      node match {
        case isVar: IsVar if isVar.name == name => found = true
        case _                                  =>
      }
      node.map(visit)
    }
    visit(constraint)
    found
  }

  /** Returns illegal constraints, where both a >= and <= inequality are used on the same variable
    *
    * Each inequality whose direction disagrees with the first inequality recorded for the same variable is
    * reported together with that first inequality (so the first one may appear several times).
    *
    * @return the conflicting pairs, flattened in recording order; empty if there are none
    */
  def check(): Seq[Inequality] = {
    val firstSeen = mutable.HashMap[String, Inequality]()
    val illegal = mutable.ListBuffer[Inequality]()
    constraints.foreach { c =>
      firstSeen.get(c.left) match {
        case None                      => firstSeen(c.left) = c
        case Some(x) if x.geq != c.geq => illegal ++= Seq(x, c)
        case Some(_)                   =>
      }
    }
    illegal.toList
  }

  /** Solves constraints present in collected inequalities
    *
    * Constraint solving steps:
    * 1) Assert no variable has both >= and <= inequalities (it can have multiple of the same kind of inequality)
    * 2) Merge constraints of variables having multiple inequalities
    * 3) Forward solve inequalities
    * a. Iterate through inequalities top-to-bottom, replacing previously seen variables with corresponding
    * constraint
    * b. For each forward-solved inequality, attempt to remove circular constraints
    * c. Forward-solved inequalities without circular constraints are recorded
    * 4) Backwards solve inequalities
    * a. Iterate through successful forward-solved inequalities bottom-to-top, replacing previously seen variables
    * with corresponding constraint
    * b. Record solved constraints
    *
    * Results are written into the solved-constraint map and can be read with `get`.
    */
  def solve(): Unit = {
    // 1) Check if any variable has both >= and <= inequalities (which is illegal)
    val illegals = check()
    if (illegals.nonEmpty) throwInternalError(s"Constraints cannot have both >= and <= inequalities: $illegals")

    // 2) Merge constraints
    val uniqueConstraints = mergeConstraints(constraints.toSeq)

    // 3) Forward Solve
    val forwardConstraintMap = new ConstraintMap
    // Variables whose forward solution is cycle-free, in the order they were recorded
    val orderedVars = mutable.ArrayBuffer[String]()
    for (constraint <- uniqueConstraints) {
      //TODO: Risky if used improperly... need to check whether substitution from a leq to a geq is negated (always).
      val subbedRight = forwardSubstitution(forwardConstraintMap)(constraint.right)
      val name = constraint.left
      val finishedRight = removeCycle(name, constraint.geq)(subbedRight)
      // A right side still mentioning its own variable is an unsolvable cycle; it is left unsolved
      if (!hasVar(name)(finishedRight)) {
        forwardConstraintMap(name) = (finishedRight, constraint.geq)
        orderedVars += name
      }
    }

    // 4) Backwards Solve: visit the recorded variables from last to first
    orderedVars.reverseIterator.foreach { name =>
      val (forwardRight, forwardGeq) = forwardConstraintMap(name)
      val solvedRight = backwardSubstitution(solvedConstraintMap)(forwardRight)
      solvedConstraintMap(name) = (solvedRight, forwardGeq)
    }
  }
}
|