aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/SplitExpressions.scala
blob: 973e1be930a8a104a5d3c040655613ccceef4891 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
package firrtl
package passes

import firrtl.Mappers.{ExpMap, StmtMap}
import firrtl.Utils.{tpe, kind, gender, info}
import firrtl.ir._
import scala.collection.mutable


// Splits compound expressions into simple expressions
//  and named intermediate nodes
object SplitExpressions extends Pass {
   def name = "Split Expressions"

   // Rewrites each statement of `m` so that no DoPrim keeps a DoPrim, Mux,
   // or ValidIf as a direct argument. Such arguments are replaced by WRefs
   // to freshly named DefNodes emitted just before the statement. Names come
   // from a Namespace built from `m` (presumably collision-free with the
   // module's existing names — see Namespace.newTemp).
   private def onModule(m: Module): Module = {
      val namespace = Namespace(m)

      // Rewrites one statement. Returns it unchanged in shape when nothing
      // was hoisted; otherwise returns a Begin of the hoisted DefNodes
      // followed by the rewritten statement. Begins are recursed into.
      def onStmt(s: Statement): Statement = {
         // DefNodes hoisted out of `s`, in creation order. The rewritten `s`
         // is appended last so every temporary precedes its use.
         val hoisted = mutable.ArrayBuffer[Statement]()

         // Replaces a compound expression with a reference to a new node
         // holding it; any other expression is returned untouched.
         def split(e: Expression): Expression = e match {
            case _: DoPrim | _: Mux | _: ValidIf =>
               val name = namespace.newTemp
               hoisted += DefNode(info(s), name, e)
               WRef(name, tpe(e), kind(e), gender(e))
            case _ => e
         }

         // Bottom-up: children are rewritten before their parent is split,
         // so inner temporaries are created (and emitted) before outer ones.
         // Only the arguments of a DoPrim are split; Mux/ValidIf operands
         // are recursed into but left inline.
         def onExp(e: Expression): Expression = (e map onExp) match {
            case prim: DoPrim => prim map split
            case other => other
         }

         (s map onExp) match {
            case block: Begin => block map onStmt
            case EmptyStmt => EmptyStmt
            case stmt =>
               hoisted += stmt
               if (hoisted.size > 1) Begin(hoisted.toVector) else hoisted(0)
         }
      }

      Module(m.info, m.name, m.ports, onStmt(m.body))
   }

   // Applies the split to every internal Module; ExtModules have no body
   // and are passed through unchanged.
   def run(c: Circuit): Circuit = {
      val modulesx = c.modules map {
         case m: Module => onModule(m)
         case m: ExtModule => m
      }
      Circuit(c.info, modulesx, c.main)
   }
}