aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/Passes.scala
diff options
context:
space:
mode:
authorazidar2016-01-31 10:09:18 -0800
committerazidar2016-02-09 18:57:06 -0800
commit2bd423fa061fb3e0973fa83e98f2877fd4616746 (patch)
tree51ca630714df2d011ea86b1d94d85eb6182c6f9d /src/main/scala/firrtl/Passes.scala
parent1613a7127dd74427786baa093b0dde5a76265b78 (diff)
Added remove accesses
Diffstat (limited to 'src/main/scala/firrtl/Passes.scala')
-rw-r--r--src/main/scala/firrtl/Passes.scala165
1 files changed, 162 insertions, 3 deletions
diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala
index b0b93db9..67fefde6 100644
--- a/src/main/scala/firrtl/Passes.scala
+++ b/src/main/scala/firrtl/Passes.scala
@@ -9,9 +9,12 @@ import Utils._
import DebugUtils._
import PrimOps._
+case class Location(base:Expression,guard:Expression)
+
@deprecated("This object will be replaced with package firrtl.passes")
object Passes extends LazyLogging {
+
// TODO Perhaps we should get rid of Logger since this map would be nice
////private val defaultLogger = Logger()
//private def mapNameToPass = Map[String, Circuit => Circuit] (
@@ -43,14 +46,16 @@ object Passes extends LazyLogging {
inferTypes _,
resolveGenders _,
pullMuxes _,
- expandConnects _)
+ expandConnects _,
+ removeAccesses _)
val names = Seq(
"To Working IR",
"Resolve Kinds",
"Infer Types",
"Resolve Genders",
"Pull Muxes",
- "Expand Connects")
+ "Expand Connects",
+ "Remove Accesses")
var c_BANG = c
(names, passes).zipped.foreach {
(n,p) => {
@@ -145,7 +150,7 @@ object Passes extends LazyLogging {
// ------------------ Utils -------------------------
- val width_name_hash = Map[String,Int]()
+ val width_name_hash = HashMap[String,Int]()
def set_type (s:Stmt,t:Type) : Stmt = {
s match {
case s:DefWire => DefWire(s.info,s.name,t)
@@ -548,6 +553,160 @@ object Passes extends LazyLogging {
}
Circuit(c.info,modulesx,c.main)
}
+ // ===============================================
+
+
+
+ // ============ REMOVE ACCESSES ==================
+ // ---------------- UTILS ------------------
+
+ def get_locations (e:Expression) : Seq[Location] = {
+ e match {
+ case (e:WRef) => create_exps(e).map(Location(_,one))
+ case (e:WSubIndex) => {
+ val ls = get_locations(e.exp)
+ val start = get_point(e)
+ val end = start + get_size(tpe(e))
+ val stride = get_size(tpe(e.exp))
+ val lsx = ArrayBuffer[Location]()
+ var c = 0
+ for (i <- 0 until ls.size) {
+ if (((i % stride) >= start) & ((i % stride) < end)) {
+ lsx += ls(i)
+ }
+ }
+ lsx
+ }
+ case (e:WSubField) => {
+ val ls = get_locations(e.exp)
+ val start = get_point(e)
+ val end = start + get_size(tpe(e))
+ val stride = get_size(tpe(e.exp))
+ val lsx = ArrayBuffer[Location]()
+ var c = 0
+ for (i <- 0 until ls.size) {
+ if (((i % stride) >= start) & ((i % stride) < end)) {
+ lsx += ls(i)
+ }
+ }
+ lsx
+ }
+ case (e:WSubAccess) => {
+ val ls = get_locations(e.exp)
+ val stride = get_size(tpe(e))
+ val wrap = tpe(e.exp).asInstanceOf[VectorType].size
+ val lsx = ArrayBuffer[Location]()
+ var c = 0
+ for (i <- 0 until ls.size) {
+ if ((c % wrap) == 0) { c = 0 }
+ val basex = ls(i).base
+ val guardx = AND(ls(i).guard,EQV(uint(c),e.index))
+ lsx += Location(basex,guardx)
+ if ((i + 1) % stride == 0) {
+ c = c + 1
+ }
+ }
+ lsx
+ }
+ }
+ }
+
+ def has_access (e:Expression) : Boolean = {
+ var ret:Boolean = false
+ def rec_has_access (e:Expression) : Expression = {
+ e match {
+ case (e:WSubAccess) => {
+ ret = true
+ e
+ }
+ case (e) => eMap(rec_has_access _,e)
+ }
+ }
+ rec_has_access(e)
+ ret
+ }
+
+ def removeAccesses (c:Circuit) = {
+ def remove_m (m:InModule) : InModule = {
+ val sh = sym_hash
+ mname = m.name
+ def remove_s (s:Stmt) : Stmt = {
+ val stmts = ArrayBuffer[Stmt]()
+ def create_temp (e:Expression) : Expression = {
+ val n = firrtl_gensym("GEN",sh)
+ stmts += DefWire(info(s),n,tpe(e))
+ WRef(n,tpe(e),kind(e),gender(e))
+ }
+ def remove_e (e:Expression) : Expression = { //NOT RECURSIVE (except primops) INTENTIONALLY!
+ e match {
+ case (e:DoPrim) => eMap(remove_e,e)
+ case (e:Mux) => eMap(remove_e,e)
+ case (e:ValidIf) => eMap(remove_e,e)
+ case (e:SIntValue) => e
+ case (e:UIntValue) => e
+ case e => {
+ if (has_access(e)) {
+ val rs = get_locations(e)
+ val foo = rs.find(x => {x.guard != one})
+ foo match {
+ case None => error("Shouldn't be here")
+ case foo:Some[Location] => {
+ val temp = create_temp(e)
+ val temps = create_exps(temp)
+ def get_temp (i:Int) = temps(i % temps.size)
+ (rs,0 until rs.size).zipped.foreach {
+ (x,i) => {
+ if (i < temps.size) {
+ stmts += Connect(info(s),get_temp(i),x.base)
+ } else {
+ stmts += Conditionally(info(s),x.guard,Connect(info(s),get_temp(i),x.base),Empty())
+ }
+ }
+ }
+ temp
+ }
+ }
+ } else { e}
+ }
+ }
+ }
+
+ val sx = s match {
+ case (s:Connect) => {
+ if (has_access(s.loc)) {
+ val ls = get_locations(s.loc)
+ val locx =
+ if (ls.size == 1 & ls(0).guard == one) s.loc
+ else {
+ val temp = create_temp(s.loc)
+ for (x <- ls) {
+ stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),Empty())
+ }
+ temp
+ }
+ Connect(s.info,locx,remove_e(s.exp))
+ } else {
+ Connect(s.info,s.loc,remove_e(s.exp))
+ }
+ }
+ case (s) => sMap(remove_s,eMap(remove_e,s))
+ }
+ stmts += sx
+ if (stmts.size != 1) Begin(stmts) else stmts(0)
+ }
+ InModule(m.info,m.name,m.ports,remove_s(m.body))
+ }
+
+ val modulesx = c.modules.map{
+ m => {
+ m match {
+ case (m:ExModule) => m
+ case (m:InModule) => remove_m(m)
+ }
+ }
+ }
+ Circuit(c.info,modulesx,c.main)
+ }
/** INFER TYPES
*