1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
|
// SPDX-License-Identifier: Apache-2.0
package chisel3
import chisel3.experimental.FixedPoint
import chisel3.internal.{prefix, throwException}
import scala.language.experimental.macros
import chisel3.internal.sourceinfo._
import chisel3.internal.plugin.autoNameRecursively
/** Elaboration-time helpers that build hardware out of Scala sequences of Chisel data.
  *
  * Each user-facing method is declared as a `SourceInfoTransform.inArg` macro. The matching `do_`
  * method (tagged `@group SourceInfoTransformMacros`) holds the implementation and takes the
  * implicit `SourceInfo` and `CompileOptions`.
  */
private[chisel3] object SeqUtils {

  /** Concatenates the data elements of the input sequence, in sequence order, together.
    * The first element of the sequence forms the least significant bits, while the last element
    * in the sequence forms the most significant bits.
    *
    * Equivalent to r(n-1) ## ... ## r(1) ## r(0).
    *
    * @tparam T element type
    * @param in the elements to concatenate, least significant first
    * @return the concatenated bits
    * @note This returns a `0.U` if applied to an empty sequence (e.g. a zero-element `Vec`).
    */
  def asUInt[T <: Bits](in: Seq[T]): UInt = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_asUInt[T <: Bits](in: Seq[T])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = {
    if (in.isEmpty) {
      0.U
    } else if (in.tail.isEmpty) {
      in.head.asUInt
    } else {
      // Split at the midpoint and recurse, so the concatenation is a balanced tree of `##`
      // (depth ~log2(n)) rather than a linear chain. The "lo"/"hi" naming/prefix wrappers apply
      // to whatever the recursive calls create, presumably to keep generated names readable.
      val lo = autoNameRecursively("lo")(prefix("lo") {
        asUInt(in.slice(0, in.length / 2))
      })
      val hi = autoNameRecursively("hi")(prefix("hi") {
        asUInt(in.slice(in.length / 2, in.length))
      })
      hi ## lo
    }
  }

  /** Outputs the number of elements that === true.B.
    *
    * Built as a balanced adder tree: the input is split in half, each half is counted
    * recursively, and the two partial counts are summed with `+&`.
    *
    * @param in the predicates to count
    * @return the population count of `in`; `0.U` for an empty sequence
    */
  def count(in: Seq[Bool]): UInt = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_count(in: Seq[Bool])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = in.size match {
    case 0 => 0.U
    case 1 => in.head
    case n =>
      val sum = count(in.take(n / 2)) +& count(in.drop(n / 2))
      // A count over n inputs is at most n, which fits in bitLength(n) bits, so this slice only
      // discards bits that are always zero.
      sum(BigInt(n).bitLength - 1, 0)
  }

  /** Returns the data value corresponding to the first true predicate.
    *
    * Elaborated as a chain of 2:1 `Mux`es, so earlier entries in `in` take priority over later
    * ones. The last entry's value is used whenever no earlier predicate is true; its own
    * predicate is never examined.
    *
    * @tparam T value type
    * @param in non-empty sequence of (predicate, value) pairs, highest priority first
    * @return the selected value
    * @throws IllegalArgumentException if `in` is empty
    */
  def priorityMux[T <: Data](in: Seq[(Bool, T)]): T = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_priorityMux[T <: Data](
    in: Seq[(Bool, T)]
  )(
    implicit sourceInfo: SourceInfo,
    compileOptions: CompileOptions
  ): T = {
    require(in.nonEmpty, "priorityMux requires at least one input")
    // `tail.isEmpty` (as in asUInt/oneHotMux) instead of `size == 1`: on a List, `size` walks the
    // whole remaining sequence at every recursion level, making elaboration quadratic.
    if (in.tail.isEmpty) {
      in.head._2
    } else {
      Mux(in.head._1, in.head._2, priorityMux(in.tail))
    }
  }

  /** Returns the data value corresponding to the lone true predicate.
    * This is elaborated to firrtl using a structure that should be optimized into an AND/OR tree.
    *
    * @tparam T value type
    * @param in non-empty collection of (predicate, value) pairs
    * @return the selected value
    * @throws IllegalArgumentException if `in` is empty
    * @note assumes exactly one true predicate, results undefined otherwise
    * FixedPoint values or aggregates containing FixedPoint values cause this optimized structure to be lost
    */
  def oneHotMux[T <: Data](in: Iterable[(Bool, T)]): T = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_oneHotMux[T <: Data](
    in: Iterable[(Bool, T)]
  )(
    implicit sourceInfo: SourceInfo,
    compileOptions: CompileOptions
  ): T = {
    require(in.nonEmpty, "Mux1H requires at least one input")
    if (in.tail.isEmpty) {
      // A single candidate is returned as-is; its predicate is not examined.
      in.head._2
    } else {
      val output = cloneSupertype(in.toSeq.map { _._2 }, "oneHotMux")

      // Zero every unselected candidate, OR all of them together, and reinterpret the resulting
      // bits as the common output type.
      def buildAndOrMultiplexor[TT <: Data](inputs: Iterable[(Bool, TT)]): T = {
        val masked = for ((s, i) <- inputs) yield Mux(s, i.asUInt, 0.U)
        masked.reduceLeft(_ | _).asTypeOf(output)
      }

      output match {
        case _: SInt =>
          // SInt's have to be managed carefully so sign extension works
          val sInts: Iterable[(Bool, SInt)] = in.collect {
            case (s: Bool, f: SInt) =>
              (s, f.asTypeOf(output).asInstanceOf[SInt])
          }
          val masked = for ((s, i) <- sInts) yield Mux(s, i, 0.S)
          masked.reduceLeft(_ | _).asTypeOf(output)

        case _: FixedPoint =>
          val (sels, possibleOuts) = in.toSeq.unzip

          // Split each candidate's format into (integer bits, fractional bits). Both the total
          // width and the binary point must be known, as the error message states.
          val (intWidths, binaryPoints) = in.toSeq.map {
            case (_, o) =>
              val fo = o.asInstanceOf[FixedPoint]
              require(
                fo.widthOption.isDefined && fo.binaryPoint.known,
                "Mux1H requires width/binary points to be defined"
              )
              (fo.getWidth - fo.binaryPoint.get, fo.binaryPoint.get)
          }.unzip

          if (intWidths.distinct.length == 1 && binaryPoints.distinct.length == 1) {
            // All candidates already share one format, so the plain AND/OR structure is used.
            buildAndOrMultiplexor(in)
          } else {
            // Widen every candidate into a wire with the largest integer part and the largest
            // binary point, so the bitwise AND/OR operates on identically aligned values.
            val maxIntWidth = intWidths.max
            val maxBP = binaryPoints.max
            val inWidthMatched = Seq.fill(intWidths.length)(Wire(FixedPoint((maxIntWidth + maxBP).W, maxBP.BP)))
            // zip keeps the original connection order without indexing into a possibly linear Seq.
            inWidthMatched.zip(possibleOuts).foreach {
              case (e, o) => e := o.asInstanceOf[FixedPoint]
            }
            buildAndOrMultiplexor(sels.zip(inWidthMatched))
          }

        case agg: Aggregate =>
          // Aggregates are muxed element-position by element-position, recursing into oneHotMux.
          val allDefineWidth = in.forall { case (_, element) => element.widthOption.isDefined }
          if (allDefineWidth) {
            val out = Wire(agg)
            val (sel, inData) = in.unzip
            val inElts = inData.map(_.asInstanceOf[Aggregate].getElements)
            // We want to iterate on the columns of inElts, so we transpose
            out.getElements.zip(inElts.transpose).foreach {
              case (outElt, elts) =>
                outElt := oneHotMux(sel.zip(elts))
            }
            out.asInstanceOf[T]
          } else {
            throwException(s"Cannot Mux1H with aggregates with inferred widths")
          }

        case _ =>
          buildAndOrMultiplexor(in)
      }
    }
  }
}
|