aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
blob: 869cac3a7b96d2e1288e6a0298840508e6efcf72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
// SPDX-License-Identifier: Apache-2.0

package firrtl.passes

import scala.collection.mutable
import firrtl.PrimOps._
import firrtl.ir._
import firrtl._
import firrtl.Mappers._
import firrtl.Utils.{field_type, max, module_type, sub_type, throwInternalError}
import firrtl.options.Dependency

/** Replaces [[firrtl.ir.FixedType]] with [[firrtl.ir.SIntType]], aligning binary points first.
  *
  * After this pass a fixed-point value is carried as an SInt of the same width, so its binary
  * point is no longer visible in the type. Every place where fixed-point values meet must
  * therefore be brought to a common binary point with `shl`/`shr` before the types are erased:
  *  - operands of a primitive op (except `mul`) and the two branches of a `mux` are aligned to
  *    the largest binary point among them;
  *  - the source of a connect / partial connect is aligned to the binary point of its sink, and
  *    a register's reset value is aligned to the binary point of the register.
  *
  * Fixed-point-specific primops are lowered directly: `asFixedPoint` becomes `asSInt`, `incp`
  * becomes `shl`, `decp` becomes `shr`, and `setp` becomes a shift to the result binary point.
  * Rewritten nodes are built with [[firrtl.ir.UnknownType]] (or keep a stale type), and the whole
  * circuit is re-typed with `InferTypes` at the end of [[run]].
  */
object ConvertFixedToSInt extends Pass {

  override def prerequisites =
    Seq(
      Dependency(PullMuxes),
      Dependency(ReplaceAccesses),
      Dependency(ExpandConnects),
      Dependency(RemoveAccesses),
      Dependency[ExpandWhensAndCheck],
      Dependency[RemoveIntervals]
    ) ++ firrtl.stage.Forms.Deduped

  // Declares that running this pass invalidates no other transform.
  override def invalidates(a: Transform) = false

  /** Shifts `e` so that its binary point becomes `point`.
    *
    * A ground fixed-point expression with binary point `p` is wrapped in `shl(e, point - p)` when
    * `point > p`, or in `shr(e, p - point)` when `point < p`. It is returned unchanged when the
    * two are equal. Any other expression, including aggregates, is returned unchanged. A newly
    * created shift node has [[UnknownType]].
    *
    * Fails via `throwInternalError` if `e` is fixed-point with an unknown width or binary point.
    *
    * @param e     expression whose (not yet converted) type is inspected
    * @param point target binary point
    * @return `e`, possibly wrapped in a single shift
    */
  def alignArg(e: Expression, point: BigInt): Expression = e.tpe match {
    case FixedType(IntWidth(_), IntWidth(p)) =>
      val shift = point - p
      if (shift > 0) DoPrim(Shl, Seq(e), Seq(shift), UnknownType)
      else if (shift < 0) DoPrim(Shr, Seq(e), Seq(-shift), UnknownType)
      else e
    case FixedType(_, _) => throwInternalError(s"alignArg: shouldn't be here - $e")
    case _               => e
  }

  /** Returns the largest binary point among `es`.
    *
    * An expression that is not ground fixed-point with a known width and binary point counts as
    * binary point 0. An empty sequence also yields 0, meaning no shift is needed.
    */
  def calcPoint(es: Seq[Expression]): BigInt =
    es.map(_.tpe match {
      case FixedType(IntWidth(_), IntWidth(p)) => p
      case _                                   => BigInt(0)
    }).reduceOption(max(_, _)).getOrElse(BigInt(0))

  /** Converts `t` by replacing every [[FixedType]], at any nesting depth, with an [[SIntType]] of
    * the same width. Other types are rebuilt with their sub-types converted.
    *
    * Fails via `throwInternalError` if a fixed-point width or binary point is unknown.
    */
  def toSIntType(t: Type): Type = t match {
    case FixedType(IntWidth(w), IntWidth(_)) => SIntType(IntWidth(w))
    case FixedType(_, _)                     => throwInternalError(s"toSIntType: shouldn't be here - $t")
    case _                                   => t.map(toSIntType)
  }

  /** Converts all fixed-point types and operations in `c` to SInt equivalents. */
  def run(c: Circuit): Circuit = {
    // Converted port-bundle type of every module, keyed by module name; used to type instances.
    val moduleTypes = mutable.HashMap[String, Type]()

    def onModule(m: DefModule): DefModule = {
      // Converted type of every port/declaration seen so far in this module, keyed by name;
      // used to re-type references.
      val types = mutable.HashMap[String, Type]()

      // Aligns binary points and lowers fixed-point primops in an expression tree.
      // Alignment inspects each child's original (still fixed-point) type, so it is applied
      // before the children themselves are rewritten.
      def updateExpType(e: Expression): Expression = e match {
        // mul operands are deliberately left unaligned (a product's binary point is the sum of
        // its operands' points, so no common point is required); only recurse.
        case DoPrim(Mul, _, _, _)               => e.map(updateExpType)
        case DoPrim(AsFixedPoint, args, _, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe).map(updateExpType)
        case DoPrim(IncP, args, consts, tpe)    => DoPrim(Shl, args, consts, tpe).map(updateExpType)
        case DoPrim(DecP, args, consts, tpe)    => DoPrim(Shr, args, consts, tpe).map(updateExpType)
        case DoPrim(SetP, args, _, FixedType(_, IntWidth(p))) =>
          // alignArg returns the operand itself when no shift is needed, so the result must be
          // rewritten as a whole. Mapping only its children would leave the operand's own
          // fixed-point op unlowered (or its operands unaligned).
          updateExpType(alignArg(args.head, p))
        case DoPrim(op, args, consts, _) =>
          // Bring all operands to the largest binary point among them.
          val point = calcPoint(args)
          DoPrim(op, args.map(alignArg(_, point)), consts, UnknownType).map(updateExpType)
        case Mux(cond, tval, fval, _) =>
          val point = calcPoint(Seq(tval, fval))
          Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType).map(updateExpType)
        case lit: UIntLiteral => lit
        case lit: SIntLiteral => lit
        case _ =>
          // Rewrite children first, then re-type references from the converted declarations.
          e.map(updateExpType) match {
            case ValidIf(cond, value, _)      => ValidIf(cond, value, value.tpe)
            case WRef(name, _, k, g)          => WRef(name, types(name), k, g)
            case WSubField(exp, name, _, g)   => WSubField(exp, name, field_type(exp.tpe, name), g)
            case WSubIndex(exp, value, _, g)  => WSubIndex(exp, value, sub_type(exp.tpe), g)
            case WSubAccess(exp, index, _, g) => WSubAccess(exp, index, sub_type(exp.tpe), g)
            case other                        => throwInternalError(s"updateExpType: shouldn't be here - $other")
          }
      }

      // Converts declaration types (recording them in `types`) and aligns the source of every
      // assignment to its sink's binary point before rewriting the expressions involved.
      def updateStmtType(s: Statement): Statement = s match {
        case DefRegister(info, name, tpe, clock, reset, init) =>
          val newType = toSIntType(tpe)
          // Recorded before rewriting so a self-referencing init resolves to the new type.
          types(name) = newType
          // The reset value is assigned to the register, so align it to the register's binary
          // point, exactly as a connect's source is aligned to its sink.
          val alignedInit = tpe match {
            case FixedType(_, IntWidth(p)) => alignArg(init, p)
            case _                         => init
          }
          DefRegister(info, name, newType, clock, reset, alignedInit).map(updateExpType)
        case DefWire(info, name, tpe) =>
          val newType = toSIntType(tpe)
          types(name) = newType
          DefWire(info, name, newType)
        case DefNode(info, name, value) =>
          val newValue = updateExpType(value)
          types(name) = toSIntType(newValue.tpe)
          DefNode(info, name, newValue)
        case DefMemory(info, name, dt, depth, wL, rL, rs, ws, rws, ruw) =>
          val newStmt = DefMemory(info, name, toSIntType(dt), depth, wL, rL, rs, ws, rws, ruw)
          types(name) = MemPortUtils.memType(newStmt)
          newStmt
        case WDefInstance(info, name, module, _) =>
          val newType = moduleTypes(module)
          types(name) = newType
          WDefInstance(info, name, module, newType)
        case Connect(info, loc, exp) =>
          Connect(info, loc, alignArg(exp, calcPoint(Seq(loc)))).map(updateExpType)
        case PartialConnect(info, loc, exp) =>
          PartialConnect(info, loc, alignArg(exp, calcPoint(Seq(loc)))).map(updateExpType)
        case other => other.map(updateStmtType).map(updateExpType)
      }

      // Ports were already converted in `run`; seed the table with them.
      m.ports.foreach(p => types(p.name) = p.tpe)
      m match {
        case Module(info, name, ports, body) => Module(info, name, ports, updateStmtType(body))
        case ext: ExtModule                  => ext
      }
    }

    // Convert every module's ports up front so that instance types are available in
    // `moduleTypes` regardless of module order when the bodies are processed.
    val newModules = c.modules.map { m =>
      val newPorts = m.ports.map(p => Port(p.info, p.name, p.direction, toSIntType(p.tpe)))
      m match {
        case Module(info, name, _, body) => Module(info, name, newPorts, body)
        case ext: ExtModule              => ext.copy(ports = newPorts)
      }
    }
    newModules.foreach(m => moduleTypes(m.name) = module_type(m))

    /* @todo This should be moved outside */
    // Rewritten nodes carry UnknownType or stale fixed-point types; re-infer every type.
    (firrtl.passes.InferTypes).run(Circuit(c.info, newModules.map(onModule(_)), c.main))
  }
}

// vim: set ts=4 sw=4 et: