1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
|
// SPDX-License-Identifier: Apache-2.0
package firrtl.passes
import firrtl.PrimOps._
import firrtl.ir._
import firrtl.Mappers._
import firrtl.constraint.{IsFloor, IsKnown, IsMul}
import firrtl.options.Dependency
import firrtl.Transform
/** Closes interval bounds and aligns binary points of interval-typed expressions.
  *
  * The pass works in three steps:
  *  1. Close bounds: every `IntervalType` whose lower bound, upper bound and binary point are all
  *     known is rewritten as `IntervalType(Closed(min), Closed(max), bp)`, using the interval's own
  *     `min`/`max`. Intervals with unknown bounds are left as they are. Only types held directly by
  *     statements and ports are rewritten here; aggregate types are walked recursively.
  *  2. Run [[InferTypes]] so expression types are recomputed from the closed declarations.
  *  3. Align binary points (this affects range inference, see `fixBP`):
  *     - `setp(x, n)` is replaced by an explicit `incp`/`decp` of `x` to binary point `n`
  *     - arguments of add/sub/comparison primops, and both legs of an interval-typed mux, are
  *       shifted to the largest binary point among them
  *     - the right-hand side of a connect / partial connect to an interval-typed sink is shifted
  *       to the sink's binary point
  *
  * Declaration types are not changed by step 3, and no conversion to SIntType happens in this pass.
  */
class TrimIntervals extends Pass {
  override def prerequisites =
    Seq(Dependency(ResolveKinds), Dependency(InferTypes), Dependency(ResolveFlows), Dependency[InferBinaryPoints])

  override def optionalPrerequisiteOf = Seq.empty

  override def invalidates(a: Transform) = false

  def run(c: Circuit): Circuit = {
    // Open -> closed bounds, then re-infer expression types from the rewritten declarations
    val closed = InferTypes.run(c.map(replaceModuleInterval))
    // Align binary points and adjust range accordingly (loss of precision changes range)
    closed.map(alignModuleBP)
  }

  /* ---- Step 1: replace interval types with closed-bound equivalents ---- */

  private def replaceModuleInterval(m: DefModule): DefModule =
    m.map(replaceStmtInterval).map(replacePortInterval)

  private def replaceStmtInterval(s: Statement): Statement =
    s.map(replaceTypeInterval).map(replaceStmtInterval)

  private def replacePortInterval(p: Port): Port = p.map(replaceTypeInterval)

  /** Closes the bounds of a fully-known interval; recurses into aggregate types. */
  private def replaceTypeInterval(t: Type): Type = t match {
    case i @ IntervalType(_: IsKnown, _: IsKnown, IntWidth(p)) =>
      // NOTE(review): assumes IntervalType.min/max are always defined when both bounds and the
      // binary point are known; `.get` throws otherwise.
      IntervalType(Closed(i.min.get), Closed(i.max.get), IntWidth(p))
    case i: IntervalType => i
    case v => v.map(replaceTypeInterval)
  }

  /* ---- Step 3: align interval binary points -- BINARY POINT ALIGNMENT AFFECTS RANGE INFERENCE! ---- */

  private def alignModuleBP(m: DefModule): DefModule = m.map(alignStmtBP)

  /** Aligns expressions in `s`, and the right-hand side of (partial) connects to interval sinks. */
  private def alignStmtBP(s: Statement): Statement = s.map(alignExpBP) match {
    case c @ Connect(info, loc, expr) =>
      loc.tpe match {
        case IntervalType(_, _, p) => Connect(info, loc, fixBP(p)(expr))
        case _ => c
      }
    case c @ PartialConnect(info, loc, expr) =>
      loc.tpe match {
        case IntervalType(_, _, p) => PartialConnect(info, loc, fixBP(p)(expr))
        case _ => c
      }
    case other => other.map(alignStmtBP)
  }

  // Note - wrap/clip/squeeze ignore the binary point of the second argument, thus not needed to be aligned
  // Note - Mul does not need its binary points aligned, because multiplication is cool like that
  private val opsToFix = Seq(Add, Sub, Lt, Leq, Gt, Geq, Eq, Neq /*, Wrap, Clip, Squeeze*/ )

  /** Bottom-up rewrite that makes binary points explicit/aligned where the operation requires it. */
  private def alignExpBP(e: Expression): Expression = e.map(alignExpBP) match {
    case DoPrim(SetP, Seq(arg), Seq(const), _: IntervalType) =>
      fixBP(IntWidth(const))(arg)
    case DoPrim(o, args, consts, t) if opsToFix.contains(o) && args.forall(isInterval) =>
      val maxBP = maxBinaryPoint(args)
      DoPrim(o, args.map(a => fixBP(maxBP)(a)), consts, t)
    case Mux(cond, tval, fval, t: IntervalType) =>
      val maxBP = maxBinaryPoint(Seq(tval, fval))
      Mux(cond, fixBP(maxBP)(tval), fixBP(maxBP)(fval), t)
    case other => other
  }

  /** True when `e` is interval-typed. */
  private def isInterval(e: Expression): Boolean = e.tpe match {
    case _: IntervalType => true
    case _ => false
  }

  /** Largest binary point among the interval-typed expressions in `es`; throws if there are none. */
  private def maxBinaryPoint(es: Seq[Expression]): Width =
    es.map(_.tpe).collect { case IntervalType(_, _, p) => p }.reduce(_ max _)

  /** Moves interval-typed `e` to binary point `p`.
    *
    * Returns `e` unchanged when the binary points already match. Raising the binary point wraps
    * `e` in `incp` and keeps its bounds. Lowering it wraps `e` in `decp`; because precision is
    * discarded, each bound `b` becomes `floor(b * 2^desired) / 2^desired`.
    *
    * Fails with `sys.error` if `p` or the binary point of `e` is not a known `IntWidth`, or if `e`
    * is not interval-typed. The binary points are expected to be inferred before this pass runs.
    */
  private def fixBP(p: Width)(e: Expression): Expression = (p, e.tpe) match {
    case (IntWidth(desired), IntervalType(_, _, IntWidth(current))) if desired == current => e
    case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired > current =>
      DoPrim(IncP, Seq(e), Seq(desired - current), IntervalType(l, u, IntWidth(desired)))
    case (IntWidth(desired), IntervalType(l, u, IntWidth(current))) if desired < current =>
      val shiftAmt = current - desired
      val shiftGain = BigDecimal(BigInt(1) << shiftAmt.toInt)
      val shiftMul = Closed(BigDecimal(1) / shiftGain)
      val bpGain = BigDecimal(BigInt(1) << current.toInt)
      // BP is inferred at this point
      // y = floor(x * 2^(-amt + bp)) gets rid of precision --> y * 2^(-bp + amt)
      // i.e. each new bound is floor(b * 2^desired) * 2^(-desired)
      val newBPRes = Closed(shiftGain / bpGain)
      val bpResInv = Closed(bpGain)
      val newL = IsMul(IsFloor(IsMul(IsMul(l, shiftMul), bpResInv)), newBPRes)
      val newU = IsMul(IsFloor(IsMul(IsMul(u, shiftMul), bpResInv)), newBPRes)
      DoPrim(DecP, Seq(e), Seq(current - desired), IntervalType(CalcBound(newL), CalcBound(newU), IntWidth(desired)))
    case x => sys.error(s"Shouldn't be here: $x")
  }
}
// vim: set ts=4 sw=4 et:
|