// Source path: root/core/src/main/scala/chisel3/SeqUtils.scala
// Blob: 460954bef02167f013af62d125d65c52e5158673
// SPDX-License-Identifier: Apache-2.0

package chisel3

import chisel3.internal.{prefix, throwException}
import chisel3.internal.plugin.autoNameRecursively

private[chisel3] object SeqUtils {

  /** Packs the elements of `in` into a single `UInt` by concatenation.
    *
    * Element 0 ends up in the least significant bits and the final element in the most
    * significant bits, i.e. the result is r(n-1) ## ... ## r(1) ## r(0).
    *
    * The sequence is halved recursively, so the concatenation is assembled as a balanced tree
    * rather than a linear chain. The two intermediate halves are labelled "lo" and "hi".
    *
    * @param in elements to concatenate, least significant first
    * @return `0.U` for an empty sequence, `in.head.asUInt` for a single element, otherwise the
    *         concatenation of every element
    */
  def asUInt[T <: Bits](in: Seq[T]): UInt = in.length match {
    case 0 => 0.U
    case 1 => in.head.asUInt
    case n =>
      // The lower-indexed half provides the low-order bits and is elaborated first.
      val (lowerHalf, upperHalf) = in.splitAt(n / 2)
      val lo = autoNameRecursively("lo")(prefix("lo") {
        asUInt(lowerHalf)
      })
      val hi = autoNameRecursively("hi")(prefix("hi") {
        asUInt(upperHalf)
      })
      hi ## lo
  }

  /** Counts how many elements of `in` are true.
    *
    * The input is split in half recursively and the two partial counts are added with `+&`.
    * Each partial sum is then trimmed to `BigInt(n).bitLength` bits, which is exactly the
    * width needed to represent the largest possible count `n` for that sub-sequence.
    *
    * @param in the predicates to count
    * @return `0.U` for an empty sequence, the lone element itself for a single-element
    *         sequence, otherwise the number of true elements
    */
  def count(in: Seq[Bool]): UInt = {
    val n = in.size
    if (n == 0) {
      0.U
    } else if (n == 1) {
      in.head
    } else {
      val (front, back) = in.splitAt(n / 2)
      val total = count(front) +& count(back)
      // Keep only as many bits as are needed to hold the value n.
      total(BigInt(n).bitLength - 1, 0)
    }
  }

  /** Selects the data value paired with the first true predicate.
    *
    * The mux chain is built starting from the last entry and working towards the first, so
    * earlier entries wrap later ones and therefore take precedence. The value of the last
    * entry seeds the chain: it is the result when no earlier predicate is true, and its own
    * predicate is never examined. A single-entry input simply yields that entry's value.
    *
    * @param in (predicate, value) pairs in descending priority order; must not be empty
    * @return the value of the highest-priority entry whose predicate is true, or the last
    *         entry's value if none of the preceding predicates is true
    */
  def priorityMux[T <: Data](
    in: Seq[(Bool, T)]
  ): T = in.size match {
    case 1 => in.head._2
    case _ =>
      val lowestFirst = in.view.reverse
      // Start with the lowest-priority value and wrap it in successively higher-priority muxes.
      lowestFirst.tail.foldLeft(lowestFirst.head._2) { (fallback, choice) =>
        Mux(choice._1, choice._2, fallback)
      }
  }

  /** Selects the data value paired with the lone true predicate.
    *
    * This is elaborated to firrtl using a structure that should be optimized into an and/or
    * tree: every candidate is gated by its predicate (replaced by zero when the predicate is
    * false) and the gated values are OR-ed together.
    *
    * How the gating is done depends on the common type computed for all candidates:
    *  - `SInt`: each candidate is first converted to the common type so that sign extension
    *    works, then gated against `0.S`.
    *  - `Aggregate`: a wire of the common type is created and each of its elements is driven
    *    by a recursive `oneHotMux` over the matching elements of the candidates. This requires
    *    every candidate to have a known width.
    *  - anything else: candidates are gated as `UInt` and the OR-ed result is converted back
    *    to the common type.
    *
    * A single-entry input yields that entry's value without consulting its predicate.
    *
    * @param in (predicate, value) pairs; must not be empty
    * @return the value belonging to the true predicate
    */
  def oneHotMux[T <: Data](
    in: Iterable[(Bool, T)]
  ): T = {
    if (in.tail.nonEmpty) {
      // Common type that all candidate values are reconciled to.
      val output = cloneSupertype(in.toSeq.map(_._2), "oneHotMux")

      output match {
        case _: SInt =>
          // SInt's have to be managed carefully so sign extension works:
          // widen every candidate to the common type first, gate afterwards.
          val signedChoices: Iterable[(Bool, SInt)] = in.collect {
            case (sel: Bool, value: SInt) =>
              (sel, value.asTypeOf(output).asInstanceOf[SInt])
          }

          val gated = signedChoices.map { case (sel, value) => Mux(sel, value, 0.S) }
          gated.reduceLeft(_ | _).asTypeOf(output)

        case agg: Aggregate =>
          val widthsKnown = in.forall { case (_, element) => element.widthOption.isDefined }
          if (!widthsKnown) {
            throwException(s"Cannot Mux1H with aggregates with inferred widths")
          }

          val result = Wire(agg)
          val (selects, payloads) = in.unzip
          val elementRows = payloads.map(_.asInstanceOf[Aggregate].getElements)
          // Transpose so that each step sees the same element position across all candidates.
          result.getElements.zip(elementRows.transpose).foreach {
            case (target, column) =>
              target := oneHotMux(selects.zip(column))
          }
          result.asInstanceOf[T]

        case _ =>
          // Gate each candidate as raw bits, OR them, and reinterpret as the common type.
          val gated = in.map { case (sel, value) => Mux(sel, value.asUInt, 0.U) }
          gated.reduceLeft(_ | _).asTypeOf(output)
      }
    } else {
      in.head._2
    }
  }
}