summaryrefslogtreecommitdiff
path: root/core/src/main/scala/chisel3/SeqUtils.scala
blob: 9d975349e783e21ffdba3873d6332091a1f83e64 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
// SPDX-License-Identifier: Apache-2.0

package chisel3

import chisel3.experimental.FixedPoint
import chisel3.internal.{prefix, throwException}

import scala.language.experimental.macros
import chisel3.internal.sourceinfo._
import chisel3.internal.plugin.autoNameRecursively

private[chisel3] object SeqUtils {

  /** Concatenates the data elements of the input sequence, in sequence order, together.
    * The first element of the sequence forms the least significant bits, while the last element
    * in the sequence forms the most significant bits.
    *
    * Equivalent to in(n-1) ## ... ## in(1) ## in(0).
    * @note This returns a `0.U` if applied to a zero-element `Vec`.
    */
  def asUInt[T <: Bits](in: Seq[T]): UInt = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_asUInt[T <: Bits](in: Seq[T])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = {
    if (in.isEmpty) {
      0.U
    } else if (in.tail.isEmpty) {
      in.head.asUInt
    } else {
      // Split in half and recurse so the concatenations form a balanced tree rather than a linear
      // chain. The "lo"/"hi" prefixes are applied to the names of hardware generated in each half.
      val lo = autoNameRecursively("lo")(prefix("lo") {
        asUInt(in.slice(0, in.length / 2))
      })
      val hi = autoNameRecursively("hi")(prefix("hi") {
        asUInt(in.slice(in.length / 2, in.length))
      })
      hi ## lo
    }
  }

  /** Outputs the number of elements of `in` that are `true.B`.
    *
    * An empty input yields `0.U`. For a non-empty input the result is `BigInt(in.size).bitLength`
    * bits wide, which is exactly wide enough to hold `in.size`.
    */
  def count(in: Seq[Bool]): UInt = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_count(in: Seq[Bool])(implicit sourceInfo: SourceInfo, compileOptions: CompileOptions): UInt = in.size match {
    case 0 => 0.U
    case 1 => in.head // a Bool is already a 1-bit UInt
    case n =>
      // Add the two halves with width expansion, then drop the bits that can never be set:
      // the count is at most n, which fits in n.bitLength bits.
      val sum = count(in.take(n / 2)) +& count(in.drop(n / 2))
      sum(BigInt(n).bitLength - 1, 0)
  }

  /** Returns the data value corresponding to the first true predicate.
    *
    * Elaborates to a chain of `Mux`es. If no predicate is true, the value of the last element is
    * returned; the last element's predicate is never examined.
    *
    * @note `in` must be non-empty; an empty input is reported via `throwException`.
    */
  def priorityMux[T <: Data](in: Seq[(Bool, T)]): T = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_priorityMux[T <: Data](
    in: Seq[(Bool, T)]
  )(
    implicit sourceInfo: SourceInfo,
    compileOptions:      CompileOptions
  ): T = {
    if (in.isEmpty) {
      // Without this check `r.head` below fails with an uninformative collection exception.
      throwException("PriorityMux requires at least one input")
    } else if (in.size == 1) {
      in.head._2
    } else {
      // Build the chain from the lowest-priority (last) element outwards, so the first element
      // ends up in the outermost Mux and therefore wins.
      val r = in.view.reverse
      r.tail.foldLeft(r.head._2) {
        case (alt, (sel, elt)) => Mux(sel, elt, alt)
      }
    }
  }

  /** Returns the data value corresponding to the lone true predicate.
    * This is elaborated to firrtl using a structure that should be optimized into an and/or tree.
    *
    * @note assumes exactly one true predicate, results undefined otherwise
    *       FixedPoint values or aggregates containing FixedPoint values cause this optimized structure to be lost
    * @note `in` must be non-empty; an empty input is reported via `throwException`.
    */
  def oneHotMux[T <: Data](in: Iterable[(Bool, T)]): T = macro SourceInfoTransform.inArg

  /** @group SourceInfoTransformMacros */
  def do_oneHotMux[T <: Data](
    in: Iterable[(Bool, T)]
  )(
    implicit sourceInfo: SourceInfo,
    compileOptions:      CompileOptions
  ): T = {
    if (in.isEmpty) {
      // Without this check `in.tail` below fails with an uninformative collection exception.
      throwException("Mux1H requires at least one input")
    } else if (in.tail.isEmpty) {
      in.head._2
    } else {
      val output = cloneSupertype(in.toSeq.map { _._2 }, "oneHotMux")

      // Generic AND/OR mux: zero every unselected input, OR the results together and reinterpret
      // the bits as the common output type.
      def buildAndOrMultiplexor[TT <: Data](inputs: Iterable[(Bool, TT)]): T = {
        val masked = for ((s, i) <- inputs) yield Mux(s, i.asUInt, 0.U)
        masked.reduceLeft(_ | _).asTypeOf(output)
      }

      output match {
        case _: SInt =>
          // SInt's have to be managed carefully so sign extension works: each input is first
          // converted to the (possibly wider) output type, instead of being flattened with asUInt.

          val sInts: Iterable[(Bool, SInt)] = in.collect {
            case (s: Bool, f: SInt) =>
              (s, f.asTypeOf(output).asInstanceOf[SInt])
          }

          val masked = for ((s, i) <- sInts) yield Mux(s, i, 0.S)
          masked.reduceLeft(_ | _).asTypeOf(output)

        case _: FixedPoint =>
          val (sels, possibleOuts) = in.toSeq.unzip

          val (intWidths, binaryPoints) = possibleOuts.map { o =>
            val fo = o.asInstanceOf[FixedPoint]
            require(fo.binaryPoint.known, "Mux1H requires width/binary points to be defined")
            (fo.getWidth - fo.binaryPoint.get, fo.binaryPoint.get)
          }.unzip

          if (intWidths.distinct.length == 1 && binaryPoints.distinct.length == 1) {
            buildAndOrMultiplexor(in)
          } else {
            // Align every input to the widest integer part and the largest binary point so that
            // the bit-level OR in buildAndOrMultiplexor combines values with the same layout.
            val maxIntWidth = intWidths.max
            val maxBP = binaryPoints.max
            val inWidthMatched = Seq.fill(intWidths.length)(Wire(FixedPoint((maxIntWidth + maxBP).W, maxBP.BP)))
            // zip (rather than indexing possibleOuts) avoids O(n^2) traversal when it is a List;
            // the connection order is unchanged.
            inWidthMatched.zip(possibleOuts).foreach {
              case (e, o) => e := o.asInstanceOf[FixedPoint]
            }
            buildAndOrMultiplexor(sels.zip(inWidthMatched))
          }

        case agg: Aggregate =>
          val allDefineWidth = in.forall { case (_, element) => element.widthOption.isDefined }
          if (allDefineWidth) {
            val out = Wire(agg)
            val (sel, inData) = in.unzip
            val inElts = inData.map(_.asInstanceOf[Aggregate].getElements)
            // We want to iterate on the columns of inElts, so we transpose; each output element
            // is then muxed recursively from the corresponding element of every input.
            out.getElements.zip(inElts.transpose).foreach {
              case (outElt, elts) =>
                outElt := oneHotMux(sel.zip(elts))
            }
            out.asInstanceOf[T]
          } else {
            throwException(s"Cannot Mux1H with aggregates with inferred widths")
          }

        case _ =>
          buildAndOrMultiplexor(in)
      }
    }
  }
}