aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/passes/RemoveCHIRRTL.scala')
-rw-r--r--src/main/scala/firrtl/passes/RemoveCHIRRTL.scala296
1 files changed, 296 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
new file mode 100644
index 00000000..47e6bbbc
--- /dev/null
+++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala
@@ -0,0 +1,296 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtl.passes
+
+import com.typesafe.scalalogging.LazyLogging
+import java.nio.file.{Paths, Files}
+
+// Datastructures
+import scala.collection.mutable.LinkedHashMap
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.HashSet
+import scala.collection.mutable.ArrayBuffer
+
+import firrtl._
+import firrtl.ir._
+import firrtl.Utils._
+import firrtl.Mappers._
+import firrtl.PrimOps._
+import firrtl.WrappedExpression._
+
+case class MPort( val name : String, val clk : Expression)
+case class MPorts( val readers : ArrayBuffer[MPort], val writers : ArrayBuffer[MPort], val readwriters : ArrayBuffer[MPort])
+case class DataRef( val exp : Expression, val male : String, val female : String, val mask : String, val rdwrite : Boolean)
+
+object RemoveCHIRRTL extends Pass {
+ def name = "Remove CHIRRTL"
+ var mname = ""
+ def create_exps (e:Expression) : Seq[Expression] = e match {
+ case (e:Mux) =>
+ val e1s = create_exps(e.tval)
+ val e2s = create_exps(e.fval)
+ (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2)))
+ case (e:ValidIf) =>
+ create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe))
+ case (e) => (e.tpe) match {
+ case (_:GroundType) => Seq(e)
+ case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) =>
+ exps ++ create_exps(SubField(e,f.name,f.tpe)))
+ case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) =>
+ exps ++ create_exps(SubIndex(e,i,t.tpe)))
+ case UnknownType => Seq(e)
+ }
+ }
+ def run (c:Circuit) : Circuit = {
+ def remove_chirrtl_m (m:Module) : Module = {
+ val hash = LinkedHashMap[String,MPorts]()
+ val repl = LinkedHashMap[String,DataRef]()
+ val raddrs = HashMap[String, Expression]()
+ val ut = UnknownType
+ val mport_types = LinkedHashMap[String,Type]()
+ val smems = HashSet[String]()
+ def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]())
+ def collect_smems_and_mports (s:Statement) : Statement = {
+ (s) match {
+ case (s:CDefMemory) if s.seq =>
+ smems += s.name
+ s
+ case (s:CDefMPort) => {
+ val mports = hash.getOrElse(s.mem,EMPs())
+ s.direction match {
+ case MRead => mports.readers += MPort(s.name,s.exps(1))
+ case MWrite => mports.writers += MPort(s.name,s.exps(1))
+ case MReadWrite => mports.readwriters += MPort(s.name,s.exps(1))
+ }
+ hash(s.mem) = mports
+ s
+ }
+ case (s) => s map (collect_smems_and_mports)
+ }
+ }
+ def collect_refs (s:Statement) : Statement = {
+ (s) match {
+ case (s:CDefMemory) => {
+ mport_types(s.name) = s.tpe
+ val stmts = ArrayBuffer[Statement]()
+ val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size))))
+ val tdata = s.tpe
+ def set_poison (vec:Seq[MPort],addr:String) : Unit = {
+ for (r <- vec ) {
+ stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr))
+ stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr))
+ }
+ }
+ def set_enable (vec:Seq[MPort],en:String) : Unit = {
+ for (r <- vec ) {
+ stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero)
+ }}
+ def set_wmode (vec:Seq[MPort],wmode:String) : Unit = {
+ for (r <- vec) {
+ stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero)
+ }}
+ def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = {
+ val tmask = create_mask(s.tpe)
+ for (r <- vec ) {
+ stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata))
+ for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) {
+ stmts += Connect(s.info,x,zero)
+ }}}
+ val rds = (hash.getOrElse(s.name,EMPs())).readers
+ set_poison(rds,"addr")
+ set_enable(rds,"en")
+ val wrs = (hash.getOrElse(s.name,EMPs())).writers
+ set_poison(wrs,"addr")
+ set_enable(wrs,"en")
+ set_write(wrs,"data","mask")
+ val rws = (hash.getOrElse(s.name,EMPs())).readwriters
+ set_poison(rws,"addr")
+ set_wmode(rws,"wmode")
+ set_enable(rws,"en")
+ set_write(rws,"wdata","wmask")
+ val read_l = if (s.seq) 1 else 0
+ val mem = DefMemory(s.info,s.name,s.tpe,s.size,1,read_l,rds.map(_.name),wrs.map(_.name),rws.map(_.name))
+ Block(Seq(mem,Block(stmts)))
+ }
+ case (s:CDefMPort) => {
+ mport_types(s.name) = mport_types(s.mem)
+ val addrs = ArrayBuffer[String]()
+ val clks = ArrayBuffer[String]()
+ val ens = ArrayBuffer[String]()
+ val masks = ArrayBuffer[String]()
+ s.direction match {
+ case MReadWrite => {
+ repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","wdata","wmask",true)
+ addrs += "addr"
+ clks += "clk"
+ ens += "en"
+ masks += "wmask"
+ }
+ case MWrite => {
+ repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false)
+ addrs += "addr"
+ clks += "clk"
+ ens += "en"
+ masks += "mask"
+ }
+ case MRead => {
+ repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false)
+ addrs += "addr"
+ clks += "clk"
+ s.exps(0) match {
+ case e: Reference if smems(s.mem) =>
+ raddrs(e.name) = SubField(SubField(Reference(s.mem,ut),s.name,ut),"en",ut)
+ case _ => ens += "en"
+ }
+ }
+ }
+ val stmts = ArrayBuffer[Statement]()
+ for (x <- addrs ) {
+ stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0))
+ }
+ for (x <- clks ) {
+ stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1))
+ }
+ for (x <- ens ) {
+ stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one)
+ }
+ Block(stmts)
+ }
+ case (s) => s map (collect_refs)
+ }
+ }
+ def remove_chirrtl_s (s:Statement) : Statement = {
+ var has_write_mport = false
+ var has_read_mport: Option[Expression] = None
+ var has_readwrite_mport: Option[Expression] = None
+ def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = {
+ (e) match {
+ case (e:Reference) if repl contains e.name =>
+ val vt = repl(e.name)
+ g match {
+ case MALE => SubField(vt.exp,vt.male,e.tpe)
+ case FEMALE => {
+ has_write_mport = true
+ if (vt.rdwrite)
+ has_readwrite_mport = Some(SubField(vt.exp,"wmode",UIntType(IntWidth(1))))
+ SubField(vt.exp,vt.female,e.tpe)
+ }
+ }
+ case (e:Reference) if g == FEMALE && (raddrs contains e.name) =>
+ has_read_mport = Some(raddrs(e.name))
+ e
+ case (e:Reference) => e
+ case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe)
+ case (e) => e map (remove_chirrtl_e(g))
+ }
+ }
+ def get_mask (e:Expression) : Expression = {
+ (e map (get_mask)) match {
+ case (e:Reference) => {
+ if (repl.contains(e.name)) {
+ val vt = repl(e.name)
+ val t = create_mask(e.tpe)
+ SubField(vt.exp,vt.mask,t)
+ } else e
+ }
+ case (e) => e
+ }
+ }
+ (s) match {
+ case (s:DefNode) => {
+ val stmts = ArrayBuffer[Statement]()
+ val valuex = remove_chirrtl_e(MALE)(s.value)
+ stmts += DefNode(s.info,s.name,valuex)
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(s.info,en,one)
+ }
+ if (stmts.size > 1) Block(stmts)
+ else stmts(0)
+ }
+ case (s:Connect) => {
+ val stmts = ArrayBuffer[Statement]()
+ val rocx = remove_chirrtl_e(MALE)(s.expr)
+ val locx = remove_chirrtl_e(FEMALE)(s.loc)
+ stmts += Connect(s.info,locx,rocx)
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(s.info,en,one)
+ }
+ if (has_write_mport) {
+ val e = get_mask(s.loc)
+ for (x <- create_exps(e) ) {
+ stmts += Connect(s.info,x,one)
+ }
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(s.info,wmode,one)
+ }
+ }
+ if (stmts.size > 1) Block(stmts)
+ else stmts(0)
+ }
+ case (s:PartialConnect) => {
+ val stmts = ArrayBuffer[Statement]()
+ val locx = remove_chirrtl_e(FEMALE)(s.loc)
+ val rocx = remove_chirrtl_e(MALE)(s.expr)
+ stmts += PartialConnect(s.info,locx,rocx)
+ has_read_mport match {
+ case None =>
+ case Some(en) => stmts += Connect(s.info,en,one)
+ }
+ if (has_write_mport) {
+ val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default)
+ val locs = create_exps(get_mask(s.loc))
+ for (x <- ls ) {
+ val locx = locs(x._1)
+ stmts += Connect(s.info,locx,one)
+ }
+ has_readwrite_mport match {
+ case None =>
+ case Some(wmode) => stmts += Connect(s.info,wmode,one)
+ }
+ }
+ if (stmts.size > 1) Block(stmts)
+ else stmts(0)
+ }
+ case (s) => s map (remove_chirrtl_s) map (remove_chirrtl_e(MALE))
+ }
+ }
+ collect_smems_and_mports(m.body)
+ val sx = collect_refs(m.body)
+ Module(m.info,m.name, m.ports, remove_chirrtl_s(sx))
+ }
+ val modulesx = c.modules.map{ m => {
+ (m) match {
+ case (m:Module) => remove_chirrtl_m(m)
+ case (m:ExtModule) => m
+ }}}
+ Circuit(c.info,modulesx, c.main)
+ }
+}