diff options
| author | Adam Izraelevitz | 2016-06-13 10:52:04 -0700 |
|---|---|---|
| committer | GitHub | 2016-06-13 10:52:04 -0700 |
| commit | 860b04eff7758c3efae09fb0b5b908abad3b4593 (patch) | |
| tree | abb1a5b81693da98818a9bee79c776d421a820d1 /src | |
| parent | 83f53a3a0cdcfc7537e923b827ab820205025d45 (diff) | |
| parent | c13ad522ae226bad57f341d0f93865194fb0bf76 (diff) | |
Merge pull request #191 from ucb-bar/ir-cleanup-fix
Ir cleanup fix
Diffstat (limited to 'src')
38 files changed, 1236 insertions, 1193 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 8efd010c..49bf9395 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -30,6 +30,7 @@ package firrtl import com.typesafe.scalalogging.LazyLogging import java.io.Writer +import firrtl.ir._ import Utils._ import firrtl.passes._ diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index f871c82a..18074f7c 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -39,6 +39,8 @@ import Utils._ import firrtl.Serialize._ import firrtl.Mappers._ import firrtl.passes._ +import firrtl.PrimOps._ +import firrtl.ir._ import WrappedExpression._ // Datastructures import scala.collection.mutable.LinkedHashMap @@ -56,7 +58,7 @@ object FIRRTLEmitter extends Emitter { case class VIndent() case object VRandom extends Expression { - def tpe = UIntType(UnknownWidth()) + def tpe = UIntType(UnknownWidth) } class VerilogEmitter extends Emitter { val tab = " " @@ -76,7 +78,7 @@ class VerilogEmitter extends Emitter { e.tpe match { case (t:UIntType) => e case (t:SIntType) => Seq("$signed(",e,")") - case (t:ClockType) => e + case ClockType => e } } (x) match { @@ -89,9 +91,8 @@ class VerilogEmitter extends Emitter { case (e:WSubField) => w.get.write(LowerTypes.loweredName(e)) case (e:WSubAccess) => w.get.write(LowerTypes.loweredName(e.exp) + "[" + LowerTypes.loweredName(e.index) + "]") case (e:WSubIndex) => w.get.write(e.serialize) - case (_:UIntValue|_:SIntValue) => v_print(e) + case (e:Literal) => v_print(e) case VRandom => w.get.write("$random") - } } case (t:Type) => { @@ -99,7 +100,7 @@ class VerilogEmitter extends Emitter { case (_:UIntType|_:SIntType) => val wx = long_BANG(t) - 1 if (wx > 0) w.get.write("[" + wx + ":0]") else w.get.write("") - case (t:ClockType) => w.get.write("") + case ClockType => w.get.write("") case (t:VectorType) => emit2(t.tpe, top + 1) w.get.write("[" + (t.size - 1) + ":0]") @@ -108,8 +109,8 @@ class VerilogEmitter extends Emitter { } case (p:Direction) => { p match { - case INPUT => w.get.write("input") - case OUTPUT => w.get.write("output") + case Input => w.get.write("input") + case Output => w.get.write("output") } } case (s:String) => w.get.write(s) @@ -126,11 +127,11 @@ class VerilogEmitter extends Emitter { //;------------- PASS ----------------- def v_print (e:Expression) = { e match { - case (e:UIntValue) => { + case (e:UIntLiteral) => { val str = e.value.toString(16) w.get.write(long_BANG(tpe(e)).toString + "'h" + str) } - case (e:SIntValue) => { + case (e:SIntLiteral) => { val str = e.value.toString(16) w.get.write(long_BANG(tpe(e)).toString + "'sh" + str) } @@ -164,8 +165,8 @@ class VerilogEmitter extends Emitter { def c1 () : Int = doprim.consts(1).toInt def checkArgumentLegality(e: Expression) = e match { - case _: UIntValue => - case _: SIntValue => + case _: UIntLiteral => + case _: SIntLiteral => case _: WRef => case _: WSubField => case _ => throw new EmitterException(s"Can't emit ${e.getClass.getName} as PrimOp argument") @@ -174,20 +175,20 @@ class VerilogEmitter extends Emitter { doprim.args foreach checkArgumentLegality doprim.op match { - case ADD_OP => Seq(cast_if(a0())," + ", cast_if(a1())) - case ADDW_OP => Seq(cast_if(a0())," + ", cast_if(a1())) - case SUB_OP => Seq(cast_if(a0())," - ", cast_if(a1())) - case SUBW_OP => Seq(cast_if(a0())," - ", cast_if(a1())) - case MUL_OP => Seq(cast_if(a0())," * ", cast_if(a1()) ) - case DIV_OP => Seq(cast_if(a0())," / ", cast_if(a1()) ) - case REM_OP => Seq(cast_if(a0())," % ", cast_if(a1()) ) - case LESS_OP => Seq(cast_if(a0())," < ", cast_if(a1())) - case LESS_EQ_OP => Seq(cast_if(a0())," <= ", cast_if(a1())) - case GREATER_OP => Seq(cast_if(a0())," > ", cast_if(a1())) - case GREATER_EQ_OP => Seq(cast_if(a0())," >= ", cast_if(a1())) - case EQUAL_OP => Seq(cast_if(a0())," == ", cast_if(a1())) - case NEQUAL_OP => Seq(cast_if(a0())," != ", cast_if(a1())) - case PAD_OP => { + case Add => Seq(cast_if(a0())," + ", cast_if(a1())) + case Addw => Seq(cast_if(a0())," + ", cast_if(a1())) + case Sub => Seq(cast_if(a0())," - ", cast_if(a1())) + case Subw => Seq(cast_if(a0())," - ", cast_if(a1())) + case Mul => Seq(cast_if(a0())," * ", cast_if(a1()) ) + case Div => Seq(cast_if(a0())," / ", cast_if(a1()) ) + case Rem => Seq(cast_if(a0())," % ", cast_if(a1()) ) + case Lt => Seq(cast_if(a0())," < ", cast_if(a1())) + case Leq => Seq(cast_if(a0())," <= ", cast_if(a1())) + case Gt => Seq(cast_if(a0())," > ", cast_if(a1())) + case Geq => Seq(cast_if(a0())," >= ", cast_if(a1())) + case Eq => Seq(cast_if(a0())," == ", cast_if(a1())) + case Neq => Seq(cast_if(a0())," != ", cast_if(a1())) + case Pad => { val w = long_BANG(tpe(a0())) val diff = (c0() - w) if (w == 0) Seq(a0()) @@ -201,70 +202,70 @@ class VerilogEmitter extends Emitter { case (t) => Seq("{{", diff, "'d0}, ", a0(), "}") } } - case AS_UINT_OP => Seq("$unsigned(",a0(),")") - case AS_SINT_OP => Seq("$signed(",a0(),")") - case AS_CLOCK_OP => Seq("$unsigned(",a0(),")") - case DSHLW_OP => Seq(cast(a0())," << ", a1()) - case DYN_SHIFT_LEFT_OP => Seq(cast(a0())," << ", a1()) - case DYN_SHIFT_RIGHT_OP => { + case AsUInt => Seq("$unsigned(",a0(),")") + case AsSInt => Seq("$signed(",a0(),")") + case AsClock => Seq("$unsigned(",a0(),")") + case Dshlw => Seq(cast(a0())," << ", a1()) + case Dshl => Seq(cast(a0())," << ", a1()) + case Dshr => { (doprim.tpe) match { case (t:SIntType) => Seq(cast(a0())," >>> ",a1()) case (t) => Seq(cast(a0())," >> ",a1()) } } - case SHLW_OP => Seq(cast(a0())," << ", c0()) - case SHIFT_LEFT_OP => Seq(cast(a0())," << ",c0()) - case SHIFT_RIGHT_OP => { + case Shlw => Seq(cast(a0())," << ", c0()) + case Shl => Seq(cast(a0())," << ",c0()) + case Shr => { if (c0 >= long_BANG(tpe(a0))) error("Verilog emitter does not support SHIFT_RIGHT >= arg width") Seq(a0(),"[", long_BANG(tpe(a0())) - 1,":",c0(),"]") } - case NEG_OP => Seq("-{",cast(a0()),"}") - case CONVERT_OP => { + case Neg => Seq("-{",cast(a0()),"}") + case Cvt => { tpe(a0()) match { case (t:UIntType) => Seq("{1'b0,",cast(a0()),"}") case (t:SIntType) => Seq(cast(a0())) } } - case NOT_OP => Seq("~ ",a0()) - case AND_OP => Seq(cast_as(a0())," & ", cast_as(a1())) - case OR_OP => Seq(cast_as(a0())," | ", cast_as(a1())) - case XOR_OP => Seq(cast_as(a0())," ^ ", cast_as(a1())) - case AND_REDUCE_OP => { + case Not => Seq("~ ",a0()) + case And => Seq(cast_as(a0())," & ", cast_as(a1())) + case Or => Seq(cast_as(a0())," | ", cast_as(a1())) + case Xor => Seq(cast_as(a0())," ^ ", cast_as(a1())) + case Andr => { val v = ArrayBuffer[Seq[Any]]() for (b <- 0 until long_BANG(doprim.tpe).toInt) { v += Seq(cast(a0()),"[",b,"]") } v.reduce(_ + " & " + _) } - case OR_REDUCE_OP => { + case Orr => { val v = ArrayBuffer[Seq[Any]]() for (b <- 0 until long_BANG(doprim.tpe).toInt) { v += Seq(cast(a0()),"[",b,"]") } v.reduce(_ + " | " + _) } - case XOR_REDUCE_OP => { + case Xorr => { val v = ArrayBuffer[Seq[Any]]() for (b <- 0 until long_BANG(doprim.tpe).toInt) { v += Seq(cast(a0()),"[",b,"]") } v.reduce(_ + " ^ " + _) } - case CONCAT_OP => Seq("{",cast(a0()),",",cast(a1()),"}") - case BITS_SELECT_OP => { + case Cat => Seq("{",cast(a0()),",",cast(a1()),"}") + case Bits => { // If selecting zeroth bit and single-bit wire, just emit the wire if (c0() == 0 && c1() == 0 && long_BANG(tpe(a0())) == 1) Seq(a0()) else if (c0() == c1()) Seq(a0(),"[",c0(),"]") else Seq(a0(),"[",c0(),":",c1(),"]") } - case HEAD_OP => { + case Head => { val w = long_BANG(tpe(a0())) val high = w - 1 val low = w - c0() Seq(a0(),"[",high,":",low,"]") } - case TAIL_OP => { + case Tail => { val w = long_BANG(tpe(a0())) val low = w - c0() - 1 Seq(a0(),"[",low,":",0,"]") @@ -272,18 +273,18 @@ class VerilogEmitter extends Emitter { } } - def emit_verilog (m:InModule) : Module = { + def emit_verilog (m:Module) : DefModule = { mname = m.name val netlist = LinkedHashMap[WrappedExpression,Expression]() - val simlist = ArrayBuffer[Stmt]() + val simlist = ArrayBuffer[Statement]() val namespace = Namespace(m) - def build_netlist (s:Stmt) : Stmt = { + def build_netlist (s:Statement) : Statement = { s match { - case (s:Connect) => netlist(s.loc) = s.exp + case (s:Connect) => netlist(s.loc) = s.expr case (s:IsInvalid) => { val n = namespace.newTemp - val e = wref(n,tpe(s.exp)) - netlist(s.exp) = e + val e = wref(n,tpe(s.expr)) + netlist(s.expr) = e } case (s:Conditionally) => simlist += s case (s:DefNode) => { @@ -379,10 +380,10 @@ class VerilogEmitter extends Emitter { } def initialize (e:Expression) = initials += Seq(e," = ",rand_string(tpe(e)),";") def initialize_mem(s: DefMemory) = { - val index = WRef("initvar", s.data_type, ExpKind(), UNKNOWNGENDER) - val rstring = rand_string(s.data_type) + val index = WRef("initvar", s.dataType, ExpKind(), UNKNOWNGENDER) + val rstring = rand_string(s.dataType) initials += Seq("for (initvar = 0; initvar < ", s.depth, "; initvar = initvar+1)") - initials += Seq(tab, WSubAccess(wref(s.name, s.data_type), index, s.data_type, FEMALE), " = ", rstring,";") + initials += Seq(tab, WSubAccess(wref(s.name, s.dataType), index, s.dataType, FEMALE), " = ", rstring,";") } def instantiate (n:String,m:String,es:Seq[Expression]) = { instdeclares += Seq(m," ",n," (") @@ -438,8 +439,8 @@ class VerilogEmitter extends Emitter { def build_ports () = { (m.ports,0 until m.ports.size).zipped.foreach{(p,i) => { p.direction match { - case INPUT => portdefs += Seq(p.direction," ",p.tpe," ",p.name) - case OUTPUT => { + case Input => portdefs += Seq(p.direction," ",p.tpe," ",p.name) + case Output => { portdefs += Seq(p.direction," ",p.tpe," ",p.name) val ex = WRef(p.name,p.tpe,PortKind(),FEMALE) assign(ex,netlist(ex)) @@ -447,9 +448,9 @@ class VerilogEmitter extends Emitter { } }} } - def build_streams (s:Stmt) : Stmt = { + def build_streams (s:Statement) : Statement = { s match { - case (s:Empty) => s + case EmptyStmt => s case (s:Connect) => s case (s:DefWire) => declare("wire",s.name,s.tpe) @@ -462,16 +463,10 @@ class VerilogEmitter extends Emitter { initialize(e) } case (s:IsInvalid) => { - val wref = netlist(s.exp).as[WRef].get - declare("reg",wref.name,tpe(s.exp)) + val wref = netlist(s.expr).as[WRef].get + declare("reg",wref.name,tpe(s.expr)) initialize(wref) } - case (s:DefPoison) => { - val n = s.name - val e = wref(n,s.tpe) - declare("reg",n,tpe(e)) - initialize(e) - } case (s:DefNode) => { declare("wire",s.name,tpe(s.value)) assign(WRef(s.name,tpe(s.value),NodeKind(),MALE),s.value) @@ -491,7 +486,7 @@ class VerilogEmitter extends Emitter { WSubField(x,f,t2,UNKNOWNGENDER) } - declare("reg",s.name,VectorType(s.data_type,s.depth)) + declare("reg",s.name,VectorType(s.dataType,s.depth)) initialize_mem(s) for (r <- s.readers ) { val data = mem_exp(r,"data") @@ -507,12 +502,12 @@ class VerilogEmitter extends Emitter { //; Read port assign(addr,netlist(addr)) //;Connects value to m.r.addr assign(en,netlist(en)) //;Connects value to m.r.en - val addrx = delay(addr,s.read_latency,clk) - val enx = delay(en,s.read_latency,clk) - val mem_port = WSubAccess(mem,addrx,s.data_type,UNKNOWNGENDER) - val depthValue = UIntValue(s.depth, IntWidth(BigInt(s.depth).bitLength)) - val garbageGuard = DoPrim(GREATER_EQ_OP, Seq(addrx, depthValue), Seq(), UnknownType()) - val garbageMux = Mux(garbageGuard, VRandom, mem_port, UnknownType()) + val addrx = delay(addr,s.readLatency,clk) + val enx = delay(en,s.readLatency,clk) + val mem_port = WSubAccess(mem,addrx,s.dataType,UNKNOWNGENDER) + val depthValue = UIntLiteral(s.depth, IntWidth(BigInt(s.depth).bitLength)) + val garbageGuard = DoPrim(Geq, Seq(addrx, depthValue), Seq(), UnknownType) + val garbageMux = Mux(garbageGuard, VRandom, mem_port, UnknownType) synSimAssign(data, mem_port, garbageMux) } @@ -535,11 +530,11 @@ class VerilogEmitter extends Emitter { assign(mask,netlist(mask)) assign(en,netlist(en)) - val datax = delay(data,s.write_latency - 1,clk) - val addrx = delay(addr,s.write_latency - 1,clk) - val maskx = delay(mask,s.write_latency - 1,clk) - val enx = delay(en,s.write_latency - 1,clk) - val mem_port = WSubAccess(mem,addrx,s.data_type,UNKNOWNGENDER) + val datax = delay(data,s.writeLatency - 1,clk) + val addrx = delay(addr,s.writeLatency - 1,clk) + val maskx = delay(mask,s.writeLatency - 1,clk) + val enx = delay(en,s.writeLatency - 1,clk) + val mem_port = WSubAccess(mem,addrx,s.dataType,UNKNOWNGENDER) update(mem_port,datax,clk,AND(enx,maskx)) } @@ -569,18 +564,18 @@ class VerilogEmitter extends Emitter { assign(wmode,netlist(wmode)) //; Delay new signals by latency - val raddrx = delay(addr,s.read_latency,clk) - val waddrx = delay(addr,s.write_latency - 1,clk) - val enx = delay(en,s.write_latency - 1,clk) - val rmodx = delay(wmode,s.write_latency - 1,clk) - val datax = delay(data,s.write_latency - 1,clk) - val maskx = delay(mask,s.write_latency - 1,clk) + val raddrx = delay(addr,s.readLatency,clk) + val waddrx = delay(addr,s.writeLatency - 1,clk) + val enx = delay(en,s.writeLatency - 1,clk) + val rmodx = delay(wmode,s.writeLatency - 1,clk) + val datax = delay(data,s.writeLatency - 1,clk) + val maskx = delay(mask,s.writeLatency - 1,clk) //; Write - val rmem_port = WSubAccess(mem,raddrx,s.data_type,UNKNOWNGENDER) + val rmem_port = WSubAccess(mem,raddrx,s.dataType,UNKNOWNGENDER) assign(rdata,rmem_port) - val wmem_port = WSubAccess(mem,waddrx,s.data_type,UNKNOWNGENDER) + val wmem_port = WSubAccess(mem,waddrx,s.dataType,UNKNOWNGENDER) val tempName = namespace.newTemp val tempExp = AND(enx,maskx) @@ -655,8 +650,8 @@ class VerilogEmitter extends Emitter { this.w = Some(w) for (m <- c.modules) { m match { - case (m:InModule) => emit_verilog(m) - case (m:ExModule) => false + case (m:Module) => emit_verilog(m) + case (m:ExtModule) => false } } } diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala deleted file mode 100644 index c762b198..00000000 --- a/src/main/scala/firrtl/IR.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* -Copyright (c) 2014 - 2016 The Regents of the University of -California (Regents). All Rights Reserved. Redistribution and use in -source and binary forms, with or without modification, are permitted -provided that the following conditions are met: - * Redistributions of source code must retain the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer. - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - two paragraphs of disclaimer in the documentation and/or other materials - provided with the distribution. - * Neither the name of the Regents nor the names of its contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. -IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, -SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, -ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF -REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF -ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION -TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR -MODIFICATIONS. -*/ - -package firrtl - -import scala.collection.Seq - -// Should this be defined elsewhere? -/* -Structure containing source locator information. -Member of most Stmt case classes. -*/ -trait Info -case object NoInfo extends Info { - override def toString(): String = "" -} -case class FileInfo(info: StringLit) extends Info { - override def toString(): String = " @[" + info.serialize + "]" -} - -class FIRRTLException(str: String) extends Exception(str) - -trait AST { - def serialize: String = firrtl.Serialize.serialize(this) -} - -trait HasName { - val name: String -} -trait HasInfo { - val info: Info -} -trait IsDeclaration extends HasName with HasInfo - -case class StringLit(array: Array[Byte]) extends AST - -trait PrimOp extends AST -case object ADD_OP extends PrimOp -case object SUB_OP extends PrimOp -case object MUL_OP extends PrimOp -case object DIV_OP extends PrimOp -case object REM_OP extends PrimOp -case object LESS_OP extends PrimOp -case object LESS_EQ_OP extends PrimOp -case object GREATER_OP extends PrimOp -case object GREATER_EQ_OP extends PrimOp -case object EQUAL_OP extends PrimOp -case object NEQUAL_OP extends PrimOp -case object PAD_OP extends PrimOp -case object AS_UINT_OP extends PrimOp -case object AS_SINT_OP extends PrimOp -case object AS_CLOCK_OP extends PrimOp -case object SHIFT_LEFT_OP extends PrimOp -case object SHIFT_RIGHT_OP extends PrimOp -case object DYN_SHIFT_LEFT_OP extends PrimOp -case object DYN_SHIFT_RIGHT_OP extends PrimOp -case object CONVERT_OP extends PrimOp -case object NEG_OP extends PrimOp -case object NOT_OP extends PrimOp -case object AND_OP extends PrimOp -case object OR_OP extends PrimOp -case object XOR_OP extends PrimOp -case object AND_REDUCE_OP extends PrimOp -case object OR_REDUCE_OP extends PrimOp -case object XOR_REDUCE_OP extends PrimOp -case object CONCAT_OP extends PrimOp -case object BITS_SELECT_OP extends PrimOp -case object HEAD_OP extends PrimOp -case object TAIL_OP extends PrimOp - -trait Expression extends AST { - def tpe: Type -} -case class Ref(name: String, tpe: Type) extends Expression with HasName -case class SubField(exp: Expression, name: String, tpe: Type) extends Expression with HasName -case class SubIndex(exp: Expression, value: Int, tpe: Type) extends Expression -case class SubAccess(exp: Expression, index: Expression, tpe: Type) extends Expression -case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression -case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression -case class UIntValue(value: BigInt, width: Width) extends Expression { - def tpe = UIntType(width) -} -case class SIntValue(value: BigInt, width: Width) extends Expression { - def tpe = SIntType(width) -} -case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression - -trait Stmt extends AST -case class DefWire(info: Info, name: String, tpe: Type) extends Stmt with IsDeclaration -case class DefPoison(info: Info, name: String, tpe: Type) extends Stmt with IsDeclaration -case class DefRegister(info: Info, name: String, tpe: Type, clock: Expression, reset: Expression, init: Expression) extends Stmt with IsDeclaration -case class DefInstance(info: Info, name: String, module: String) extends Stmt with IsDeclaration -case class DefMemory(info: Info, name: String, data_type: Type, depth: Int, write_latency: Int, - read_latency: Int, readers: Seq[String], writers: Seq[String], readwriters: Seq[String]) extends Stmt with IsDeclaration -case class DefNode(info: Info, name: String, value: Expression) extends Stmt with IsDeclaration -case class Conditionally(info: Info, pred: Expression, conseq: Stmt, alt: Stmt) extends Stmt with HasInfo -case class Begin(stmts: Seq[Stmt]) extends Stmt -case class BulkConnect(info: Info, loc: Expression, exp: Expression) extends Stmt with HasInfo -case class Connect(info: Info, loc: Expression, exp: Expression) extends Stmt with HasInfo -case class IsInvalid(info: Info, exp: Expression) extends Stmt with HasInfo -case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Stmt with HasInfo -case class Print(info: Info, string: StringLit, args: Seq[Expression], clk: Expression, en: Expression) extends Stmt with HasInfo -case class Empty() extends Stmt - -trait Width extends AST { - def +(x: Width): Width = (this, x) match { - case (a: IntWidth, b: IntWidth) => IntWidth(a.width + b.width) - case _ => UnknownWidth() - } - def -(x: Width): Width = (this, x) match { - case (a: IntWidth, b: IntWidth) => IntWidth(a.width - b.width) - case _ => UnknownWidth() - } - def max(x: Width): Width = (this, x) match { - case (a: IntWidth, b: IntWidth) => IntWidth(a.width max b.width) - case _ => UnknownWidth() - } - def min(x: Width): Width = (this, x) match { - case (a: IntWidth, b: IntWidth) => IntWidth(a.width min b.width) - case _ => UnknownWidth() - } -} -case class IntWidth(width: BigInt) extends Width -case class UnknownWidth() extends Width - -trait Flip extends AST -case object DEFAULT extends Flip -case object REVERSE extends Flip - -case class Field(name: String, flip: Flip, tpe: Type) extends AST with HasName - -trait Type extends AST -case class UIntType(width: Width) extends Type -case class SIntType(width: Width) extends Type -case class BundleType(fields: Seq[Field]) extends Type -case class VectorType(tpe: Type, size: Int) extends Type -case class ClockType() extends Type -case class UnknownType() extends Type - -trait Direction extends AST -case object INPUT extends Direction -case object OUTPUT extends Direction - -case class Port(info: Info, name: String, direction: Direction, tpe: Type) extends AST with IsDeclaration - -trait Module extends AST with IsDeclaration { - val info : Info - val name : String - val ports : Seq[Port] -} -case class InModule(info: Info, name: String, ports: Seq[Port], body: Stmt) extends Module -case class ExModule(info: Info, name: String, ports: Seq[Port]) extends Module - -case class Circuit(info: Info, modules: Seq[Module], main: String) extends AST with HasInfo - diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 3db83406..33cb70db 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -30,6 +30,7 @@ package firrtl import com.typesafe.scalalogging.LazyLogging import java.io.Writer import firrtl.passes.Pass +import firrtl.ir.Circuit // =========================================== // Utility Traits @@ -67,7 +68,7 @@ class Chisel3ToHighFirrtl () extends Transform with SimpleRun { run(circuit, passSeq) } -// Converts from the bare intermediate representation (IR.scala) +// Converts from the bare intermediate representation (ir.scala) // to a working representation (WIR.scala) class IRToWorkingIR () extends Transform with SimpleRun { val passSeq = Seq(passes.ToWorkingIR) diff --git a/src/main/scala/firrtl/Mappers.scala b/src/main/scala/firrtl/Mappers.scala index 335bc4aa..c00ca855 100644 --- a/src/main/scala/firrtl/Mappers.scala +++ b/src/main/scala/firrtl/Mappers.scala @@ -27,46 +27,47 @@ MODIFICATIONS. package firrtl +import firrtl.ir._ + // TODO: Implement remaining mappers and recursive mappers object Mappers { // ********** Stmt Mappers ********** private trait StmtMagnet { - def map(stmt: Stmt): Stmt + def map(stmt: Statement): Statement } private object StmtMagnet { - implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = { + implicit def forStmt(f: Statement => Statement) = new StmtMagnet { + override def map(stmt: Statement): Statement = { stmt match { case s: Conditionally => Conditionally(s.info, s.pred, f(s.conseq), f(s.alt)) case s: Begin => Begin(s.stmts.map(f)) - case s: Stmt => s + case s: Statement => s } } } implicit def forExp(f: Expression => Expression) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = { + override def map(stmt: Statement): Statement = { stmt match { case s: DefRegister => DefRegister(s.info, s.name, s.tpe, f(s.clock), f(s.reset), f(s.init)) case s: DefNode => DefNode(s.info, s.name, f(s.value)) - case s: Connect => Connect(s.info, f(s.loc), f(s.exp)) - case s: BulkConnect => BulkConnect(s.info, f(s.loc), f(s.exp)) + case s: Connect => Connect(s.info, f(s.loc), f(s.expr)) + case s: PartialConnect => PartialConnect(s.info, f(s.loc), f(s.expr)) case s: Conditionally => Conditionally(s.info, f(s.pred), s.conseq, s.alt) - case s: IsInvalid => IsInvalid(s.info, f(s.exp)) + case s: IsInvalid => IsInvalid(s.info, f(s.expr)) case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) case s: Print => Print(s.info, s.string, s.args.map(f), f(s.clk), f(s.en)) case s: CDefMPort => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps.map(f),s.direction) - case s: Stmt => s + case s: Statement => s } } } implicit def forType(f: Type => Type) = new StmtMagnet { - override def map(stmt: Stmt) : Stmt = { + override def map(stmt: Statement) : Statement = { stmt match { - case s:DefPoison => DefPoison(s.info,s.name,f(s.tpe)) case s:DefWire => DefWire(s.info,s.name,f(s.tpe)) case s:DefRegister => DefRegister(s.info,s.name,f(s.tpe),s.clock,s.reset,s.init) - case s:DefMemory => DefMemory(s.info,s.name, f(s.data_type), s.depth, s.write_latency, s.read_latency, s.readers, s.writers, s.readwriters) + case s:DefMemory => DefMemory(s.info,s.name, f(s.dataType), s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) case s:CDefMemory => CDefMemory(s.info,s.name, f(s.tpe), s.size, s.seq) case s:CDefMPort => CDefMPort(s.info,s.name, f(s.tpe), s.mem, s.exps,s.direction) case s => s @@ -74,12 +75,11 @@ object Mappers { } } implicit def forString(f: String => String) = new StmtMagnet { - override def map(stmt: Stmt): Stmt = { + override def map(stmt: Statement): Statement = { stmt match { case s: DefWire => DefWire(s.info,f(s.name),s.tpe) - case s: DefPoison => DefPoison(s.info,f(s.name),s.tpe) case s: DefRegister => DefRegister(s.info,f(s.name), s.tpe, s.clock, s.reset, s.init) - case s: DefMemory => DefMemory(s.info,f(s.name), s.data_type, s.depth, s.write_latency, s.read_latency, s.readers, s.writers, s.readwriters) + case s: DefMemory => DefMemory(s.info,f(s.name), s.dataType, s.depth, s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) case s: DefNode => DefNode(s.info,f(s.name),s.value) case s: DefInstance => DefInstance(s.info,f(s.name), s.module) case s: WDefInstance => WDefInstance(s.info,f(s.name), s.module,s.tpe) @@ -90,9 +90,9 @@ object Mappers { } } } - implicit class StmtMap(stmt: Stmt) { + implicit class StmtMap(stmt: Statement) { // Using implicit types to allow overloading of function type to map, see StmtMagnet above - def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) + def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Statement = magnet(f).map(stmt) } // ********** Expression Mappers ********** @@ -103,9 +103,9 @@ object Mappers { implicit def forExp(f: Expression => Expression) = new ExpMagnet { override def map(exp: Expression): Expression = { exp match { - case e: SubField => SubField(f(e.exp), e.name, e.tpe) - case e: SubIndex => SubIndex(f(e.exp), e.value, e.tpe) - case e: SubAccess => SubAccess(f(e.exp), f(e.index), e.tpe) + case e: SubField => SubField(f(e.expr), e.name, e.tpe) + case e: SubIndex => SubIndex(f(e.expr), e.value, e.tpe) + case e: SubAccess => SubAccess(f(e.expr), f(e.index), e.tpe) case e: Mux => Mux(f(e.cond), f(e.tval), f(e.fval), e.tpe) case e: ValidIf => ValidIf(f(e.cond), f(e.value), e.tpe) case e: DoPrim => DoPrim(e.op, e.args.map(f), e.consts, e.tpe) @@ -133,8 +133,8 @@ object Mappers { implicit def forWidth(f: Width => Width) = new ExpMagnet { override def map(exp: Expression): Expression = { exp match { - case e: UIntValue => UIntValue(e.value,f(e.width)) - case e: SIntValue => SIntValue(e.value,f(e.width)) + case e: UIntLiteral => UIntLiteral(e.value,f(e.width)) + case e: SIntLiteral => SIntLiteral(e.value,f(e.width)) case e => e } } @@ -196,36 +196,36 @@ object Mappers { // ********** Module Mappers ********** private trait ModuleMagnet { - def map(module: Module): Module + def map(module: DefModule): DefModule } private object ModuleMagnet { - implicit def forStmt(f: Stmt => Stmt) = new ModuleMagnet { - override def map(module: Module): Module = { + implicit def forStmt(f: Statement => Statement) = new ModuleMagnet { + override def map(module: DefModule): DefModule = { module match { - case m: InModule => InModule(m.info, m.name, m.ports, f(m.body)) - case m: ExModule => m + case m: Module => Module(m.info, m.name, m.ports, f(m.body)) + case m: ExtModule => m } } } implicit def forPorts(f: Port => Port) = new ModuleMagnet { - override def map(module: Module): Module = { + override def map(module: DefModule): DefModule = { module match { - case m: InModule => InModule(m.info, m.name, m.ports.map(f), m.body) - case m: ExModule => ExModule(m.info, m.name, m.ports.map(f)) + case m: Module => Module(m.info, m.name, m.ports.map(f), m.body) + case m: ExtModule => ExtModule(m.info, m.name, m.ports.map(f)) } } } implicit def forString(f: String => String) = new ModuleMagnet { - override def map(module: Module): Module = { + override def map(module: DefModule): DefModule = { module match { - case m: InModule => InModule(m.info, f(m.name), m.ports, m.body) - case m: ExModule => ExModule(m.info, f(m.name), m.ports) + case m: Module => Module(m.info, f(m.name), m.ports, m.body) + case m: ExtModule => ExtModule(m.info, f(m.name), m.ports) } } } } - implicit class ModuleMap(module: Module) { - def map[T](f: T => T)(implicit magnet: (T => T) => ModuleMagnet): Module = magnet(f).map(module) + implicit class ModuleMap(module: DefModule) { + def map[T](f: T => T)(implicit magnet: (T => T) => ModuleMagnet): DefModule = magnet(f).map(module) } } diff --git a/src/main/scala/firrtl/Namespace.scala b/src/main/scala/firrtl/Namespace.scala index 01cc59fd..7d4758c5 100644 --- a/src/main/scala/firrtl/Namespace.scala +++ b/src/main/scala/firrtl/Namespace.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl import scala.collection.mutable.HashSet +import firrtl.ir._ import Mappers._ class Namespace private { @@ -59,10 +60,10 @@ object Namespace { def apply(): Namespace = new Namespace // Initializes a namespace from a Module - def apply(m: Module): Namespace = { + def apply(m: DefModule): Namespace = { val namespace = new Namespace - def buildNamespaceStmt(s: Stmt): Stmt = + def buildNamespaceStmt(s: Statement): Statement = s map buildNamespaceStmt match { case dec: IsDeclaration => namespace.namespace += dec.name @@ -77,7 +78,7 @@ object Namespace { } m.ports map buildNamespacePort m match { - case in: InModule => buildNamespaceStmt(in.body) + case in: Module => buildNamespaceStmt(in.body) case _ => // Do nothing } diff --git a/src/main/scala/firrtl/Parser.scala b/src/main/scala/firrtl/Parser.scala index e6489b2f..dc8d6875 100644 --- a/src/main/scala/firrtl/Parser.scala +++ b/src/main/scala/firrtl/Parser.scala @@ -29,6 +29,7 @@ package firrtl import org.antlr.v4.runtime._; import org.antlr.v4.runtime.atn._; import com.typesafe.scalalogging.LazyLogging +import firrtl.ir._ import Utils.{time} import antlr._ @@ -40,7 +41,7 @@ case class InvalidEscapeCharException(message: String) extends ParserException(m object Parser extends LazyLogging { - /** Takes Iterator over lines of FIRRTL, returns AST (root node is Circuit) */ + /** Takes Iterator over lines of FIRRTL, returns FirrtlNode (root node is Circuit) */ def parse(lines: Iterator[String], infoMode: InfoMode = UseInfo): Circuit = { val fixedInput = time("Translator") { Translator.addBrackets(lines) } val antlrStream = new ANTLRInputStream(fixedInput.result) diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index ed3752f9..7e8d77db 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -27,57 +27,86 @@ MODIFICATIONS. package firrtl -import com.typesafe.scalalogging.LazyLogging +import firrtl.ir._ -import Utils._ +import com.typesafe.scalalogging.LazyLogging +/** Definitions and Utility functions for [[ir.PrimOp]]s */ object PrimOps extends LazyLogging { + /** Addition */ + case object Add extends PrimOp { override def toString = "add" } + /** Subtraction */ + case object Sub extends PrimOp { override def toString = "sub" } + /** Multiplication */ + case object Mul extends PrimOp { override def toString = "mul" } + /** Division */ + case object Div extends PrimOp { override def toString = "div" } + /** Remainder */ + case object Rem extends PrimOp { override def toString = "rem" } + /** Less Than */ + case object Lt extends PrimOp { override def toString = "lt" } + /** Less Than Or Equal To */ + case object Leq extends PrimOp { override def toString = "leq" } + /** Greater Than */ + case object Gt extends PrimOp { override def toString = "gt" } + /** Greater Than Or Equal To */ + case object Geq extends PrimOp { override def toString = "geq" } + /** Equal To */ + case object Eq extends PrimOp { override def toString = "eq" } + /** Not Equal To */ + case object Neq extends PrimOp { override def toString = "neq" } + /** Padding */ + case object Pad extends PrimOp { override def toString = "pad" } + /** Interpret As UInt */ + case object AsUInt extends PrimOp { override def toString = "asUInt" } + /** Interpret As SInt */ + case object AsSInt extends PrimOp { override def toString = "asSInt" } + /** Interpret As Clock */ + case object AsClock extends PrimOp { override def toString = "asClock" } + /** Static Shift Left */ + case object Shl extends PrimOp { override def toString = "shl" } + /** Static Shift Right */ + case object Shr extends PrimOp { override def toString = "shr" } + /** Dynamic Shift Left */ + case object Dshl extends PrimOp { override def toString = "dshl" } + /** Dynamic Shift Right */ + case object Dshr extends PrimOp { override def toString = "dshr" } + /** Arithmetic Convert to Signed */ + case object Cvt extends PrimOp { override def toString = "cvt" } + /** Negate */ + case object Neg extends PrimOp { override def toString = "neg" } + /** Bitwise Complement */ + case object Not extends PrimOp { override def toString = "not" } + /** Bitwise And */ + case object And extends PrimOp { override def toString = "and" } + /** Bitwise Or */ + case object Or extends PrimOp { override def toString = "or" } + /** Bitwise Exclusive Or */ + case object Xor extends PrimOp { override def toString = "xor" } + /** Bitwise And Reduce */ + case object Andr extends PrimOp { override def toString = "andr" } + /** Bitwise Or Reduce */ + case object Orr extends PrimOp { override def toString = "orr" } + /** Bitwise Exclusive Or Reduce */ + case object Xorr extends PrimOp { override def toString = "xorr" } + /** Concatenate */ + case object Cat extends PrimOp { override def toString = "cat" } + /** Bit Extraction */ + case object Bits extends PrimOp { override def toString = "bits" } + /** Head */ + case object Head extends PrimOp { override def toString = "head" } + /** Tail */ + case object Tail extends PrimOp { override def toString = "tail" } - private val mapPrimOp2String = Map[PrimOp, String]( - ADD_OP -> "add", - SUB_OP -> "sub", - MUL_OP -> "mul", - DIV_OP -> "div", - REM_OP -> "rem", - LESS_OP -> "lt", - LESS_EQ_OP -> "leq", - GREATER_OP -> "gt", - GREATER_EQ_OP -> "geq", - EQUAL_OP -> "eq", - NEQUAL_OP -> "neq", - PAD_OP -> "pad", - AS_UINT_OP -> "asUInt", - AS_SINT_OP -> "asSInt", - AS_CLOCK_OP -> "asClock", - SHIFT_LEFT_OP -> "shl", - SHIFT_RIGHT_OP -> "shr", - DYN_SHIFT_LEFT_OP -> "dshl", - DYN_SHIFT_RIGHT_OP -> "dshr", - NEG_OP -> "neg", - CONVERT_OP -> "cvt", - NOT_OP -> "not", - AND_OP -> "and", - OR_OP -> "or", - XOR_OP -> "xor", - AND_REDUCE_OP -> "andr", - OR_REDUCE_OP -> "orr", - XOR_REDUCE_OP -> "xorr", - CONCAT_OP -> "cat", - BITS_SELECT_OP -> "bits", - HEAD_OP -> "head", - TAIL_OP -> "tail", - - //This are custom, we need to refactor to enable easily extending FIRRTL with custom primops - ADDW_OP -> "addw", - SUBW_OP -> "subw" - ) - lazy val listing: Seq[String] = PrimOps.mapPrimOp2String.map { case (k,v) => v } toSeq - private val mapString2PrimOp = mapPrimOp2String.map(_.swap) - def fromString(op: String): PrimOp = mapString2PrimOp(op) + private lazy val builtinPrimOps: Seq[PrimOp] = + Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsClock, Shl, Shr, + Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail) + private lazy val strToPrimOp: Map[String, PrimOp] = builtinPrimOps map (op => op.toString -> op) toMap - implicit class PrimOpImplicits(op: PrimOp){ - def getString(): String = mapPrimOp2String(op) - } + /** Seq of String representations of [[ir.PrimOp]]s */ + lazy val listing: Seq[String] = builtinPrimOps map (_.toString) + /** Gets the corresponding [[ir.PrimOp]] from its String representation */ + def fromString(op: String): PrimOp = strToPrimOp(op) // Borrowed from Stanza implementation def set_primop_type (e:DoPrim) : DoPrim = { @@ -90,283 +119,283 @@ object PrimOps extends LazyLogging { val o = e.op val a = e.args val c = e.consts - def t1 () = tpe(a(0)) - def t2 () = tpe(a(1)) - def t3 () = tpe(a(2)) - def w1 () = widthBANG(tpe(a(0))) - def w2 () = widthBANG(tpe(a(1))) - def w3 () = widthBANG(tpe(a(2))) + def t1 () = a(0).tpe + def t2 () = a(1).tpe + def t3 () = a(2).tpe + def w1 () = Utils.widthBANG(a(0).tpe) + def w2 () = Utils.widthBANG(a(1).tpe) + def w3 () = Utils.widthBANG(a(2).tpe) def c1 () = IntWidth(c(0)) def c2 () = IntWidth(c(1)) o match { - case ADD_OP => { + case Add => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case SUB_OP => { + case Sub => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case MUL_OP => { + case Mul => { val t = (t1(),t2()) match { case (t1:UIntType, t2:UIntType) => UIntType(PLUS(w1(),w2())) case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) case (t1:SIntType, t2:UIntType) => SIntType(PLUS(w1(),w2())) case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) - case (t1, t2) => UnknownType() + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case DIV_OP => { + case Div => { val t = (t1(),t2()) match { case (t1:UIntType, t2:UIntType) => UIntType(w1()) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) case (t1:SIntType, t2:UIntType) => SIntType(w1()) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),ONE)) - case (t1, t2) => UnknownType() + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case REM_OP => { + case Rem => { val t = (t1(),t2()) match { case (t1:UIntType, t2:UIntType) => UIntType(MIN(w1(),w2())) case (t1:UIntType, t2:SIntType) => UIntType(MIN(w1(),w2())) - case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),ONE))) + case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),Utils.ONE))) case (t1:SIntType, t2:SIntType) => SIntType(MIN(w1(),w2())) - case (t1, t2) => UnknownType() + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case LESS_OP => { + case Lt => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => BoolType() - case (t1:SIntType, t2:UIntType) => BoolType() - case (t1:UIntType, t2:SIntType) => BoolType() - case (t1:SIntType, t2:SIntType) => BoolType() - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case LESS_EQ_OP => { + case Leq => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => BoolType() - case (t1:SIntType, t2:UIntType) => BoolType() - case (t1:UIntType, t2:SIntType) => BoolType() - case (t1:SIntType, t2:SIntType) => BoolType() - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case GREATER_OP => { + case Gt => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => BoolType() - case (t1:SIntType, t2:UIntType) => BoolType() - case (t1:UIntType, t2:SIntType) => BoolType() - case (t1:SIntType, t2:SIntType) => BoolType() - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case GREATER_EQ_OP => { + case Geq => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => BoolType() - case (t1:SIntType, t2:UIntType) => BoolType() - case (t1:UIntType, t2:SIntType) => BoolType() - case (t1:SIntType, t2:SIntType) => BoolType() - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case EQUAL_OP => { + case Eq => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => BoolType() - case (t1:SIntType, t2:UIntType) => BoolType() - case (t1:UIntType, t2:SIntType) => BoolType() - case (t1:SIntType, t2:SIntType) => BoolType() - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case NEQUAL_OP => { + case Neq => { val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => BoolType() - case (t1:SIntType, t2:UIntType) => BoolType() - case (t1:UIntType, t2:SIntType) => BoolType() - case (t1:SIntType, t2:SIntType) => BoolType() - case (t1, t2) => UnknownType() + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case PAD_OP => { + case Pad => { val t = (t1()) match { case (t1:UIntType) => UIntType(MAX(w1(),c1())) case (t1:SIntType) => SIntType(MAX(w1(),c1())) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case AS_UINT_OP => { + case AsUInt => { val t = (t1()) match { case (t1:UIntType) => UIntType(w1()) case (t1:SIntType) => UIntType(w1()) - case (t1:ClockType) => UIntType(ONE) - case (t1) => UnknownType() + case ClockType => UIntType(Utils.ONE) + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case AS_SINT_OP => { + case AsSInt => { val t = (t1()) match { case (t1:UIntType) => SIntType(w1()) case (t1:SIntType) => SIntType(w1()) - case (t1:ClockType) => SIntType(ONE) - case (t1) => UnknownType() + case ClockType => SIntType(Utils.ONE) + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case AS_CLOCK_OP => { + case AsClock => { val t = (t1()) match { - case (t1:UIntType) => ClockType() - case (t1:SIntType) => ClockType() - case (t1:ClockType) => ClockType() - case (t1) => UnknownType() + case (t1:UIntType) => ClockType + case (t1:SIntType) => ClockType + case ClockType => ClockType + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case SHIFT_LEFT_OP => { + case Shl => { val t = (t1()) match { case (t1:UIntType) => UIntType(PLUS(w1(),c1())) case (t1:SIntType) => SIntType(PLUS(w1(),c1())) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case SHIFT_RIGHT_OP => { + case Shr => { val t = (t1()) match { - case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),ONE)) - case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),ONE)) - case (t1) => UnknownType() + case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) + case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case DYN_SHIFT_LEFT_OP => { + case Dshl => { val t = (t1()) match { case (t1:UIntType) => UIntType(PLUS(w1(),POW(w2()))) case (t1:SIntType) => SIntType(PLUS(w1(),POW(w2()))) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case DYN_SHIFT_RIGHT_OP => { + case Dshr => { val t = (t1()) match { case (t1:UIntType) => UIntType(w1()) case (t1:SIntType) => SIntType(w1()) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case CONVERT_OP => { + case Cvt => { val t = (t1()) match { - case (t1:UIntType) => SIntType(PLUS(w1(),ONE)) + case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) case (t1:SIntType) => SIntType(w1()) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case NEG_OP => { + case Neg => { val t = (t1()) match { - case (t1:UIntType) => SIntType(PLUS(w1(),ONE)) - case (t1:SIntType) => SIntType(PLUS(w1(),ONE)) - case (t1) => UnknownType() + case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case NOT_OP => { + case Not => { val t = (t1()) match { case (t1:UIntType) => UIntType(w1()) case (t1:SIntType) => UIntType(w1()) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case AND_OP => { + case And => { val t = (t1(),t2()) match { case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) - case (t1,t2) => UnknownType() + case (t1,t2) => UnknownType } DoPrim(o,a,c,t) } - case OR_OP => { + case Or => { val t = (t1(),t2()) match { case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) - case (t1,t2) => UnknownType() + case (t1,t2) => UnknownType } DoPrim(o,a,c,t) } - case XOR_OP => { + case Xor => { val t = (t1(),t2()) match { case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) - case (t1,t2) => UnknownType() + case (t1,t2) => UnknownType } DoPrim(o,a,c,t) } - case AND_REDUCE_OP => { + case Andr => { val t = (t1()) match { - case (_:UIntType|_:SIntType) => BoolType() - case (t1) => UnknownType() + case (_:UIntType|_:SIntType) => Utils.BoolType + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case OR_REDUCE_OP => { + case Orr => { val t = (t1()) match { - case (_:UIntType|_:SIntType) => BoolType() - case (t1) => UnknownType() + case (_:UIntType|_:SIntType) => Utils.BoolType + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case XOR_REDUCE_OP => { + case Xorr => { val t = (t1()) match { - case (_:UIntType|_:SIntType) => BoolType() - case (t1) => UnknownType() + case (_:UIntType|_:SIntType) => Utils.BoolType + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case CONCAT_OP => { + case Cat => { val t = (t1(),t2()) match { case (_:UIntType|_:SIntType,_:UIntType|_:SIntType) => UIntType(PLUS(w1(),w2())) - case (t1, t2) => UnknownType() + case (t1, t2) => UnknownType } DoPrim(o,a,c,t) } - case BITS_SELECT_OP => { + case Bits => { val t = (t1()) match { - case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),ONE)) - case (t1) => UnknownType() + case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),Utils.ONE)) + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case HEAD_OP => { + case Head => { val t = (t1()) match { case (_:UIntType|_:SIntType) => UIntType(c1()) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } - case TAIL_OP => { + case Tail => { val t = (t1()) match { case (_:UIntType|_:SIntType) => UIntType(MINUS(w1(),c1())) - case (t1) => UnknownType() + case (t1) => UnknownType } DoPrim(o,a,c,t) } diff --git a/src/main/scala/firrtl/Serialize.scala b/src/main/scala/firrtl/Serialize.scala index 1735c270..2c45c6ec 100644 --- a/src/main/scala/firrtl/Serialize.scala +++ b/src/main/scala/firrtl/Serialize.scala @@ -27,26 +27,27 @@ MODIFICATIONS. package firrtl +import firrtl.ir._ import firrtl.PrimOps._ import firrtl.Utils._ private object Serialize { - def serialize(root: AST): String = { + def serialize(root: FirrtlNode): String = { lazy val ser = new Serialize root match { case r: PrimOp => ser.serialize(r) case r: Expression => ser.serialize(r) - case r: Stmt => ser.serialize(r) + case r: Statement => ser.serialize(r) case r: Width => ser.serialize(r) - case r: Flip => ser.serialize(r) + case r: Orientation => ser.serialize(r) case r: Field => ser.serialize(r) case r: Type => ser.serialize(r) case r: Direction => ser.serialize(r) case r: Port => ser.serialize(r) - case r: Module => ser.serialize(r) + case r: DefModule => ser.serialize(r) case r: Circuit => ser.serialize(r) case r: StringLit => ser.serialize(r) - case _ => throw new Exception("serialize called on unknown AST node!") + case _ => throw new Exception("serialize called on unknown FirrtlNode!") } } /** Creates new instance of Serialize */ @@ -60,18 +61,18 @@ class Serialize { def serialize(info: Info): String = " " + info.toString - def serialize(op: PrimOp): String = op.getString + def serialize(op: PrimOp): String = op.toString def serialize(lit: StringLit): String = FIRRTLStringLitHandler.escape(lit) def serialize(exp: Expression): String = { exp match { - case v: UIntValue => s"UInt${serialize(v.width)}(${serialize(v.value)})" - case v: SIntValue => s"SInt${serialize(v.width)}(${serialize(v.value)})" - case r: Ref => r.name - case s: SubField => s"${serialize(s.exp)}.${s.name}" - case s: SubIndex => s"${serialize(s.exp)}[${s.value}]" - case s: SubAccess => s"${serialize(s.exp)}[${serialize(s.index)}]" + case v: UIntLiteral => s"UInt${serialize(v.width)}(${serialize(v.value)})" + case v: SIntLiteral => s"SInt${serialize(v.width)}(${serialize(v.value)})" + case r: Reference => r.name + case s: SubField => s"${serialize(s.expr)}.${s.name}" + case s: SubIndex => s"${serialize(s.expr)}[${s.value}]" + case s: SubAccess => s"${serialize(s.expr)}[${serialize(s.index)}]" case m: Mux => s"mux(${serialize(m.cond)}, ${serialize(m.tval)}, ${serialize(m.fval)})" case v: ValidIf => s"validif(${serialize(v.cond)}, ${serialize(v.value)})" case p: DoPrim => @@ -84,7 +85,7 @@ class Serialize { } } - def serialize(stmt: Stmt): String = { + def serialize(stmt: Statement): String = { stmt match { case w: DefWire => s"wire ${w.name} : ${serialize(w.tpe)}${w.info}" case r: DefRegister => @@ -99,10 +100,10 @@ class Serialize { val str = new StringBuilder(s"mem ${m.name} :${m.info}") withIndent { str ++= newline + - s"data-type => ${serialize(m.data_type)}" + newline + + s"data-type => ${serialize(m.dataType)}" + newline + s"depth => ${m.depth}" + newline + - s"read-latency => ${m.read_latency}" + newline + - s"write-latency => ${m.write_latency}" + newline + + s"read-latency => ${m.readLatency}" + newline + + s"write-latency => ${m.writeLatency}" + newline + (if (m.readers.nonEmpty) m.readers.map(r => s"reader => ${r}").mkString(newline) + newline else "") + (if (m.writers.nonEmpty) m.writers.map(w => s"writer => ${w}").mkString(newline) + newline @@ -114,13 +115,13 @@ class Serialize { str.result } case n: DefNode => s"node ${n.name} = ${serialize(n.value)}${n.info}" - case c: Connect => s"${serialize(c.loc)} <= ${serialize(c.exp)}${c.info}" - case b: BulkConnect => s"${serialize(b.loc)} <- ${serialize(b.exp)}${b.info}" + case c: Connect => s"${serialize(c.loc)} <= ${serialize(c.expr)}${c.info}" + case p: PartialConnect => s"${serialize(p.loc)} <- ${serialize(p.expr)}${p.info}" case w: Conditionally => { var str = new StringBuilder(s"when ${serialize(w.pred)} :${w.info}") withIndent { str ++= newline + serialize(w.conseq) } w.alt match { - case s:Empty => str.result + case EmptyStmt => str.result case s => { str ++= newline + "else :" withIndent { str ++= newline + serialize(w.alt) } @@ -136,7 +137,7 @@ class Serialize { } s.result } - case i: IsInvalid => s"${serialize(i.exp)} is invalid${i.info}" + case i: IsInvalid => s"${serialize(i.expr)} is invalid${i.info}" case s: Stop => s"stop(${serialize(s.clk)}, ${serialize(s.en)}, ${s.ret})${s.info}" case p: Print => { val q = '"'.toString @@ -144,7 +145,7 @@ class Serialize { (if (p.args.nonEmpty) p.args.map(serialize).mkString(", ", ", ", "") else "") + s")${p.info}" } - case s: Empty => "skip" + case EmptyStmt => "skip" case s: CDefMemory => { if (s.seq) s"smem ${s.name} : ${serialize(s.tpe)} [${s.size}]${s.info}" else s"cmem ${s.name} : ${serialize(s.tpe)} [${s.size}]${s.info}" @@ -163,16 +164,16 @@ class Serialize { def serialize(w: Width): String = { w match { - case w:UnknownWidth => "" + case UnknownWidth => "" case w: IntWidth => s"<${w.width.toString}>" case w: VarWidth => s"<${w.name}>" } } - def serialize(f: Flip): String = { + def serialize(f: Orientation): String = { f match { - case REVERSE => "flip " - case DEFAULT => "" + case Flip => "flip " + case Default => "" } } @@ -182,8 +183,8 @@ class Serialize { def serialize(t: Type): String = { val commas = ", " // for mkString in BundleType t match { - case c:ClockType => "Clock" - case u:UnknownType => "?" + case ClockType => "Clock" + case UnknownType => "?" case t: UIntType => s"UInt${serialize(t.width)}" case t: SIntType => s"SInt${serialize(t.width)}" case t: BundleType => s"{ ${t.fields.map(serialize).mkString(commas)}}" @@ -193,17 +194,17 @@ class Serialize { def serialize(d: Direction): String = { d match { - case INPUT => "input" - case OUTPUT => "output" + case Input => "input" + case Output => "output" } } def serialize(p: Port): String = s"${serialize(p.direction)} ${p.name} : ${serialize(p.tpe)}${p.info}" - def serialize(m: Module): String = { + def serialize(m: DefModule): String = { m match { - case m: InModule => { + case m: Module => { var s = new StringBuilder(s"module ${m.name} :${m.info}") withIndent { s ++= m.ports.map(newline ++ serialize(_)).mkString @@ -211,7 +212,7 @@ class Serialize { } s.toString } - case m: ExModule => { + case m: ExtModule => { var s = new StringBuilder(s"extmodule ${m.name} :${m.info}") withIndent { s ++= m.ports.map(newline ++ serialize(_)).mkString diff --git a/src/main/scala/firrtl/StringLit.scala b/src/main/scala/firrtl/StringLit.scala index b3d67064..501e9686 100644 --- a/src/main/scala/firrtl/StringLit.scala +++ b/src/main/scala/firrtl/StringLit.scala @@ -27,6 +27,8 @@ MODIFICATIONS. package firrtl +import firrtl.ir._ + import java.nio.charset.StandardCharsets.UTF_8 import scala.annotation.tailrec diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index a2ca3103..a5253e84 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -42,10 +42,14 @@ import com.typesafe.scalalogging.LazyLogging import WrappedExpression._ import firrtl.WrappedType._ import firrtl.Mappers._ +import firrtl.PrimOps._ +import firrtl.ir._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.LinkedHashMap //import scala.reflect.runtime.universe._ +class FIRRTLException(str: String) extends Exception(str) + object Utils extends LazyLogging { private[firrtl] def time[R](name: String)(block: => R): R = { logger.info(s"Starting $name") @@ -72,13 +76,13 @@ object Utils extends LazyLogging { def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt val gen_names = Map[String,Int]() val delin = "_" - def BoolType () = { UIntType(IntWidth(1)) } - val one = UIntValue(BigInt(1),IntWidth(1)) - val zero = UIntValue(BigInt(0),IntWidth(1)) - def uint (i:Int) : UIntValue = { + val BoolType = UIntType(IntWidth(1)) + val one = UIntLiteral(BigInt(1),IntWidth(1)) + val zero = UIntLiteral(BigInt(0),IntWidth(1)) + def uint (i:Int) : UIntLiteral = { val num_bits = req_num_bits(i) val w = IntWidth(scala.math.max(1,num_bits - 1)) - UIntValue(BigInt(i),w) + UIntLiteral(BigInt(i),w) } def req_num_bits (i: Int) : Int = { val ix = if (i < 0) ((-1 * i) - 1) else i @@ -89,7 +93,7 @@ object Utils extends LazyLogging { else if ((e1 == we(zero)) | (e2 == we(zero))) zero else if (e1 == we(one)) e2.e1 else if (e2 == we(one)) e1.e1 - else DoPrim(AND_OP,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) + else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) } def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { @@ -97,13 +101,13 @@ object Utils extends LazyLogging { else if ((e1 == we(one)) | (e2 == we(one))) one else if (e1 == we(zero)) e2.e1 else if (e2 == we(zero)) e1.e1 - else DoPrim(OR_OP,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) + else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) } - def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(EQUAL_OP,Seq(e1,e2),Seq(),tpe(e1)) } + def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(Eq,Seq(e1,e2),Seq(),tpe(e1)) } def NOT (e1:WrappedExpression) : Expression = { if (e1 == we(one)) zero else if (e1 == we(zero)) one - else DoPrim(EQUAL_OP,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) + else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) } @@ -118,8 +122,8 @@ object Utils extends LazyLogging { val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } BundleType(fieldss) } - case t:UIntType => BoolType() - case t:SIntType => BoolType() + case t:UIntType => BoolType + case t:SIntType => BoolType } } def create_exps (n:String, t:Type) : Seq[Expression] = @@ -136,7 +140,7 @@ object Utils extends LazyLogging { tpe(e) match { case (t:UIntType) => Seq(e) case (t:SIntType) => Seq(e) - case (t:ClockType) => Seq(e) + case ClockType => Seq(e) case (t:BundleType) => { t.fields.flatMap { f => create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip))) } } @@ -147,15 +151,15 @@ object Utils extends LazyLogging { } } } - def get_flip (t:Type, i:Int, f:Flip) : Flip = { + def get_flip (t:Type, i:Int, f:Orientation) : Orientation = { if (i >= get_size(t)) error("Shouldn't be here") val x = t match { case (t:UIntType) => f case (t:SIntType) => f - case (t:ClockType) => f + case ClockType => f case (t:BundleType) => { var n = i - var ret:Option[Flip] = None + var ret:Option[Orientation] = None t.fields.foreach { x => { if (n < get_size(x.tpe)) { ret match { @@ -164,11 +168,11 @@ object Utils extends LazyLogging { } } else { n = n - get_size(x.tpe) } }} - ret.asInstanceOf[Some[Flip]].get + ret.asInstanceOf[Some[Orientation]].get } case (t:VectorType) => { var n = i - var ret:Option[Flip] = None + var ret:Option[Orientation] = None for (j <- 0 until t.size) { if (n < get_size(t.tpe)) { ret = Some(get_flip(t.tpe,n,f)) @@ -176,7 +180,7 @@ object Utils extends LazyLogging { n = n - get_size(t.tpe) } } - ret.asInstanceOf[Some[Flip]].get + ret.asInstanceOf[Some[Orientation]].get } } x @@ -204,15 +208,15 @@ object Utils extends LazyLogging { def mux_type (t1:Type,t2:Type) : Type = { if (wt(t1) == wt(t2)) { (t1,t2) match { - case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth()) - case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth()) + case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth) + case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth) case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size) case (t1:BundleType,t2:BundleType) => BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => { Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe)) })) } - } else UnknownType() + } else UnknownType } def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2)) def mux_type_and_widths (t1:Type,t2:Type) : Type = { @@ -231,15 +235,15 @@ object Utils extends LazyLogging { case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size) case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))}) } - } else UnknownType() + } else UnknownType } - def module_type (m:Module) : Type = { + def module_type (m:DefModule) : Type = { BundleType(m.ports.map(p => p.toField)) } def sub_type (v:Type) : Type = { v match { case v:VectorType => v.tpe - case v => UnknownType() + case v => UnknownType } } def field_type (v:Type,s:String) : Type = { @@ -248,47 +252,43 @@ object Utils extends LazyLogging { val ft = v.fields.find(p => p.name == s) ft match { case ft:Some[Field] => ft.get.tpe - case ft => UnknownType() + case ft => UnknownType } } - case v => UnknownType() + case v => UnknownType } } ////===================================== def widthBANG (t:Type) : Width = { t match { - case t:UIntType => t.width - case t:SIntType => t.width - case t:ClockType => IntWidth(1) + case g: GroundType => g.width case t => error("No width!") } } def long_BANG (t:Type) : Long = { (t) match { - case (t:UIntType) => t.width.as[IntWidth].get.width.toLong - case (t:SIntType) => t.width.as[IntWidth].get.width.toLong + case g: GroundType => g.width.as[IntWidth].get.width.toLong case (t:BundleType) => { var w = 0 for (f <- t.fields) { w = w + long_BANG(f.tpe).toInt } w } case (t:VectorType) => t.size * long_BANG(t.tpe) - case (t:ClockType) => 1 } } // ================================= def error(str:String) = throw new FIRRTLException(str) - implicit class ASTUtils(ast: AST) { + implicit class FirrtlNodeUtils(node: FirrtlNode) { def getType(): Type = - ast match { + node match { case e: Expression => e.getType - case s: Stmt => s.getType + case s: Statement => s.getType //case f: Field => f.getType case t: Type => t.getType case p: Port => p.getType - case _ => UnknownType() + case _ => UnknownType } } @@ -306,7 +306,7 @@ object Utils extends LazyLogging { case (t) => 1 } } - def get_valid_points (t1:Type,t2:Type,flip1:Flip,flip2:Flip) : Seq[(Int,Int)] = { + def get_valid_points (t1:Type, t2:Type, flip1:Orientation, flip2:Orientation) : Seq[(Int,Int)] = { //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) (t1,t2) match { case (t1:UIntType,t2:UIntType) => if (flip1 == flip2) Seq((0, 0)) else Seq() @@ -360,47 +360,47 @@ object Utils extends LazyLogging { } def swap (d:Direction) : Direction = { d match { - case OUTPUT => INPUT - case INPUT => OUTPUT + case Output => Input + case Input => Output } } - def swap (f:Flip) : Flip = { + def swap (f:Orientation) : Orientation = { f match { - case DEFAULT => REVERSE - case REVERSE => DEFAULT + case Default => Flip + case Flip => Default } } def to_dir (g:Gender) : Direction = { g match { - case MALE => INPUT - case FEMALE => OUTPUT + case MALE => Input + case FEMALE => Output } } def to_gender (d:Direction) : Gender = { d match { - case INPUT => MALE - case OUTPUT => FEMALE + case Input => MALE + case Output => FEMALE } } - def toGender(f: Flip): Gender = f match { - case DEFAULT => FEMALE - case REVERSE => MALE + def toGender(f: Orientation): Gender = f match { + case Default => FEMALE + case Flip => MALE } - def toFlip(g: Gender): Flip = g match { - case MALE => REVERSE - case FEMALE => DEFAULT + def toFlip(g: Gender): Orientation = g match { + case MALE => Flip + case FEMALE => Default } - def field_flip (v:Type,s:String) : Flip = { + def field_flip (v:Type,s:String) : Orientation = { v match { case v:BundleType => { val ft = v.fields.find {p => p.name == s} ft match { case ft:Some[Field] => ft.get.flip - case ft => DEFAULT + case ft => Default } } - case v => DEFAULT + case v => Default } } def get_field (v:Type,s:String) : Field = { @@ -409,17 +409,17 @@ object Utils extends LazyLogging { val ft = v.fields.find {p => p.name == s} ft match { case ft:Some[Field] => ft.get - case ft => error("Shouldn't be here"); Field("blah",DEFAULT,UnknownType()) + case ft => error("Shouldn't be here"); Field("blah",Default,UnknownType) } } - case v => error("Shouldn't be here"); Field("blah",DEFAULT,UnknownType()) + case v => error("Shouldn't be here"); Field("blah",Default,UnknownType) } } - def times (flip:Flip,d:Direction) : Direction = times(flip, d) - def times (d:Direction,flip:Flip) : Direction = { + def times (flip:Orientation, d:Direction) : Direction = times(flip, d) + def times (d:Direction,flip:Orientation) : Direction = { flip match { - case DEFAULT => d - case REVERSE => swap(d) + case Default => d + case Flip => swap(d) } } def times (g: Gender, d: Direction): Direction = times(d, g) @@ -428,39 +428,38 @@ object Utils extends LazyLogging { case MALE => swap(d) // MALE == INPUT == REVERSE } - def times (g:Gender,flip:Flip) : Gender = times(flip, g) - def times (flip:Flip,g:Gender) : Gender = { + def times (g:Gender,flip:Orientation) : Gender = times(flip, g) + def times (flip:Orientation, g:Gender) : Gender = { flip match { - case DEFAULT => g - case REVERSE => swap(g) + case Default => g + case Flip => swap(g) } } - def times (f1:Flip,f2:Flip) : Flip = { + def times (f1:Orientation, f2:Orientation) : Orientation = { f2 match { - case DEFAULT => f1 - case REVERSE => swap(f1) + case Default => f1 + case Flip => swap(f1) } } // =========== ACCESSORS ========= - def info (s:Stmt) : Info = { + def info (s:Statement) : Info = { s match { case s:DefWire => s.info - case s:DefPoison => s.info case s:DefRegister => s.info case s:DefInstance => s.info case s:WDefInstance => s.info case s:DefMemory => s.info case s:DefNode => s.info case s:Conditionally => s.info - case s:BulkConnect => s.info + case s:PartialConnect => s.info case s:Connect => s.info case s:IsInvalid => s.info case s:Stop => s.info case s:Print => s.info case s:Begin => NoInfo - case s:Empty => NoInfo + case EmptyStmt => NoInfo } } def gender (e:Expression) : Gender = { @@ -470,32 +469,31 @@ object Utils extends LazyLogging { case e:WSubIndex => e.gender case e:WSubAccess => e.gender case e:DoPrim => MALE - case e:UIntValue => MALE - case e:SIntValue => MALE + case e:UIntLiteral => MALE + case e:SIntLiteral => MALE case e:Mux => MALE case e:ValidIf => MALE case e:WInvalid => MALE case e => println(e); error("Shouldn't be here") }} - def get_gender (s:Stmt) : Gender = + def get_gender (s:Statement) : Gender = s match { case s:DefWire => BIGENDER case s:DefRegister => BIGENDER case s:WDefInstance => MALE case s:DefNode => MALE case s:DefInstance => MALE - case s:DefPoison => UNKNOWNGENDER case s:DefMemory => MALE case s:Begin => UNKNOWNGENDER case s:Connect => UNKNOWNGENDER - case s:BulkConnect => UNKNOWNGENDER + case s:PartialConnect => UNKNOWNGENDER case s:Stop => UNKNOWNGENDER case s:Print => UNKNOWNGENDER - case s:Empty => UNKNOWNGENDER + case EmptyStmt => UNKNOWNGENDER case s:IsInvalid => UNKNOWNGENDER } def get_gender (p:Port) : Gender = - if (p.direction == INPUT) MALE else FEMALE + if (p.direction == Input) MALE else FEMALE def kind (e:Expression) : Kind = e match { case e:WRef => e.kind @@ -505,7 +503,7 @@ object Utils extends LazyLogging { } def tpe (e:Expression) : Type = e match { - case e:Ref => e.tpe + case e:Reference => e.tpe case e:SubField => e.tpe case e:SubIndex => e.tpe case e:SubAccess => e.tpe @@ -516,45 +514,43 @@ object Utils extends LazyLogging { case e:DoPrim => e.tpe case e:Mux => e.tpe case e:ValidIf => e.tpe - case e:UIntValue => UIntType(e.width) - case e:SIntValue => SIntType(e.width) - case e:WVoid => UnknownType() - case e:WInvalid => UnknownType() + case e:UIntLiteral => UIntType(e.width) + case e:SIntLiteral => SIntType(e.width) + case e:WVoid => UnknownType + case e:WInvalid => UnknownType } - def get_type (s:Stmt) : Type = { + def get_type (s:Statement) : Type = { s match { case s:DefWire => s.tpe - case s:DefPoison => s.tpe case s:DefRegister => s.tpe case s:DefNode => tpe(s.value) case s:DefMemory => { val depth = s.depth - val addr = Field("addr",DEFAULT,UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) - val en = Field("en",DEFAULT,BoolType()) - val clk = Field("clk",DEFAULT,ClockType()) - val def_data = Field("data",DEFAULT,s.data_type) - val rev_data = Field("data",REVERSE,s.data_type) - val mask = Field("mask",DEFAULT,create_mask(s.data_type)) - val wmode = Field("wmode",DEFAULT,UIntType(IntWidth(1))) - val rdata = Field("rdata",REVERSE,s.data_type) + val addr = Field("addr",Default,UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) + val en = Field("en",Default,BoolType) + val clk = Field("clk",Default,ClockType) + val def_data = Field("data",Default,s.dataType) + val rev_data = Field("data",Flip,s.dataType) + val mask = Field("mask",Default,create_mask(s.dataType)) + val wmode = Field("wmode",Default,UIntType(IntWidth(1))) + val rdata = Field("rdata",Flip,s.dataType) val read_type = BundleType(Seq(rev_data,addr,en,clk)) val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) val readwrite_type = BundleType(Seq(wmode,rdata,def_data,mask,addr,en,clk)) val mem_fields = ArrayBuffer[Field]() - s.readers.foreach {x => mem_fields += Field(x,REVERSE,read_type)} - s.writers.foreach {x => mem_fields += Field(x,REVERSE,write_type)} - s.readwriters.foreach {x => mem_fields += Field(x,REVERSE,readwrite_type)} + s.readers.foreach {x => mem_fields += Field(x,Flip,read_type)} + s.writers.foreach {x => mem_fields += Field(x,Flip,write_type)} + s.readwriters.foreach {x => mem_fields += Field(x,Flip,readwrite_type)} BundleType(mem_fields) } - case s:DefInstance => UnknownType() + case s:DefInstance => UnknownType case s:WDefInstance => s.tpe - case _ => UnknownType() + case _ => UnknownType }} - def get_name (s:Stmt) : String = { + def get_name (s:Statement) : String = { s match { case s:DefWire => s.name - case s:DefPoison => s.name case s:DefRegister => s.name case s:DefNode => s.name case s:DefMemory => s.name @@ -562,17 +558,16 @@ object Utils extends LazyLogging { case s:WDefInstance => s.name case _ => error("Shouldn't be here"); "blah" }} - def get_info (s:Stmt) : Info = { + def get_info (s:Statement) : Info = { s match { case s:DefWire => s.info - case s:DefPoison => s.info case s:DefRegister => s.info case s:DefInstance => s.info case s:WDefInstance => s.info case s:DefMemory => s.info case s:DefNode => s.info case s:Conditionally => s.info - case s:BulkConnect => s.info + case s:PartialConnect => s.info case s:Connect => s.info case s:IsInvalid => s.info case s:Stop => s.info @@ -620,13 +615,13 @@ object Utils extends LazyLogging { /** Gets the root declaration of an expression * - * @param m the [[firrtl.InModule]] to search - * @param expr the [[firrtl.Expression]] that refers to some declaration - * @return the [[firrtl.IsDeclaration]] of `expr` + * @param m the [[firrtl.ir.Module]] to search + * @param expr the [[firrtl.ir.Expression]] that refers to some declaration + * @return the [[firrtl.ir.IsDeclaration]] of `expr` * @throws DeclarationNotFoundException if no declaration of `expr` is found */ - def getDeclaration(m: InModule, expr: Expression): IsDeclaration = { - def getRootDecl(name: String)(s: Stmt): Option[IsDeclaration] = s match { + def getDeclaration(m: Module, expr: Expression): IsDeclaration = { + def getRootDecl(name: String)(s: Statement): Option[IsDeclaration] = s match { case decl: IsDeclaration => if (decl.name == name) Some(decl) else None case c: Conditionally => val m = (getRootDecl(name)(c.conseq), getRootDecl(name)(c.alt)) @@ -661,10 +656,10 @@ object Utils extends LazyLogging { def apply_t (t:Type) : Type = t map (apply_t) map (f) apply_t(t) } - def mapr (f: Width => Width, s:Stmt) : Stmt = { + def mapr (f: Width => Width, s:Statement) : Statement = { def apply_t (t:Type) : Type = mapr(f,t) def apply_e (e:Expression) : Expression = e map (apply_e) map (apply_t) map (f) - def apply_s (s:Stmt) : Stmt = s map (apply_s) map (apply_e) map (apply_t) + def apply_s (s:Statement) : Statement = s map (apply_s) map (apply_e) map (apply_t) apply_s(s) } val ONE = IntWidth(1) @@ -718,26 +713,25 @@ object Utils extends LazyLogging { // to-stmt(body(m)) // map(to-port,ports(m)) // sym-hash - implicit class StmtUtils(stmt: Stmt) { + implicit class StmtUtils(stmt: Statement) { def getType(): Type = stmt match { case s: DefWire => s.tpe case s: DefRegister => s.tpe - case s: DefMemory => s.data_type - case _ => UnknownType() + case s: DefMemory => s.dataType + case _ => UnknownType } def getInfo: Info = stmt match { case s: DefWire => s.info - case s: DefPoison => s.info case s: DefRegister => s.info case s: DefInstance => s.info case s: DefMemory => s.info case s: DefNode => s.info case s: Conditionally => s.info - case s: BulkConnect => s.info + case s: PartialConnect => s.info case s: Connect => s.info case s: IsInvalid => s.info case s: Stop => s.info @@ -746,18 +740,18 @@ object Utils extends LazyLogging { } } - implicit class FlipUtils(f: Flip) { - def flip(): Flip = { + implicit class FlipUtils(f: Orientation) { + def flip(): Orientation = { f match { - case REVERSE => DEFAULT - case DEFAULT => REVERSE + case Flip => Default + case Default => Flip } } def toDirection(): Direction = { f match { - case DEFAULT => OUTPUT - case REVERSE => INPUT + case Default => Output + case Flip => Input } } } @@ -772,7 +766,7 @@ object Utils extends LazyLogging { implicit class TypeUtils(t: Type) { def isGround: Boolean = t match { - case (_: UIntType | _: SIntType | _: ClockType) => true + case (_: UIntType | _: SIntType | ClockType) => true case (_: BundleType | _: VectorType) => false } def isAggregate: Boolean = !t.isGround @@ -780,22 +774,22 @@ object Utils extends LazyLogging { def getType(): Type = t match { case v: VectorType => v.tpe - case tpe: Type => UnknownType() + case tpe: Type => UnknownType } def wipeWidth(): Type = t match { - case t: UIntType => UIntType(UnknownWidth()) - case t: SIntType => SIntType(UnknownWidth()) + case t: UIntType => UIntType(UnknownWidth) + case t: SIntType => SIntType(UnknownWidth) case _ => t } } implicit class DirectionUtils(d: Direction) { - def toFlip(): Flip = { + def toFlip(): Orientation = { d match { - case INPUT => REVERSE - case OUTPUT => DEFAULT + case Input => Flip + case Output => Default } } } diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index f2a3953b..91f9a0ce 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -42,14 +42,15 @@ import antlr._ import PrimOps._ import FIRRTLParser._ import Parser.{InfoMode, IgnoreInfo, UseInfo, GenInfo, AppendInfo} +import firrtl.ir._ import scala.annotation.tailrec -class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] +class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] { // Strip file path private def stripPath(filename: String) = filename.drop(filename.lastIndexOf("/")+1) - def visit[AST](ctx: FIRRTLParser.CircuitContext): Circuit = visitCircuit(ctx) + def visit[FirrtlNode](ctx: FIRRTLParser.CircuitContext): Circuit = visitCircuit(ctx) // These regex have to change if grammar changes private def string2BigInt(s: String): BigInt = { @@ -86,35 +87,37 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] } infoMode match { case UseInfo => - if (useInfo.length == 0) NoInfo else FileInfo(FIRRTLStringLitHandler.unescape(useInfo)) + if (useInfo.length == 0) NoInfo + else ir.FileInfo(FIRRTLStringLitHandler.unescape(useInfo)) case AppendInfo(filename) => val newInfo = useInfo + ":" + genInfo(filename) - FileInfo(FIRRTLStringLitHandler.unescape(newInfo)) - case GenInfo(filename) => FileInfo(FIRRTLStringLitHandler.unescape(genInfo(filename))) + ir.FileInfo(FIRRTLStringLitHandler.unescape(newInfo)) + case GenInfo(filename) => + ir.FileInfo(FIRRTLStringLitHandler.unescape(genInfo(filename))) case IgnoreInfo => NoInfo } } - private def visitCircuit[AST](ctx: FIRRTLParser.CircuitContext): Circuit = + private def visitCircuit[FirrtlNode](ctx: FIRRTLParser.CircuitContext): Circuit = Circuit(visitInfo(Option(ctx.info), ctx), ctx.module.map(visitModule), (ctx.id.getText)) - private def visitModule[AST](ctx: FIRRTLParser.ModuleContext): Module = { + private def visitModule[FirrtlNode](ctx: FIRRTLParser.ModuleContext): DefModule = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0).getText match { - case "module" => InModule(info, ctx.id.getText, ctx.port.map(visitPort), visitBlock(ctx.block)) - case "extmodule" => ExModule(info, ctx.id.getText, ctx.port.map(visitPort)) + case "module" => Module(info, ctx.id.getText, ctx.port.map(visitPort), visitBlock(ctx.block)) + case "extmodule" => ExtModule(info, ctx.id.getText, ctx.port.map(visitPort)) } } - private def visitPort[AST](ctx: FIRRTLParser.PortContext): Port = { + private def visitPort[FirrtlNode](ctx: FIRRTLParser.PortContext): Port = { Port(visitInfo(Option(ctx.info), ctx), (ctx.id.getText), visitDir(ctx.dir), visitType(ctx.`type`)) } - private def visitDir[AST](ctx: FIRRTLParser.DirContext): Direction = + private def visitDir[FirrtlNode](ctx: FIRRTLParser.DirContext): Direction = ctx.getText match { - case "input" => INPUT - case "output" => OUTPUT + case "input" => Input + case "output" => Output } - private def visitMdir[AST](ctx: FIRRTLParser.MdirContext): MPortDir = + private def visitMdir[FirrtlNode](ctx: FIRRTLParser.MdirContext): MPortDir = ctx.getText match { case "infer" => MInfer case "read" => MRead @@ -123,33 +126,33 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] } // Match on a type instead of on strings? - private def visitType[AST](ctx: FIRRTLParser.TypeContext): Type = { + private def visitType[FirrtlNode](ctx: FIRRTLParser.TypeContext): Type = { ctx.getChild(0) match { case term: TerminalNode => term.getText match { case "UInt" => if (ctx.getChildCount > 1) UIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) - else UIntType( UnknownWidth() ) + else UIntType( UnknownWidth ) case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) - else SIntType( UnknownWidth() ) - case "Clock" => ClockType() + else SIntType( UnknownWidth ) + case "Clock" => ClockType case "{" => BundleType(ctx.field.map(visitField)) } case tpe: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.IntLit.getText)) } } - private def visitField[AST](ctx: FIRRTLParser.FieldContext): Field = { - val flip = if(ctx.getChild(0).getText == "flip") REVERSE else DEFAULT + private def visitField[FirrtlNode](ctx: FIRRTLParser.FieldContext): Field = { + val flip = if(ctx.getChild(0).getText == "flip") Flip else Default Field((ctx.id.getText), flip, visitType(ctx.`type`)) } // visitBlock - private def visitBlock[AST](ctx: FIRRTLParser.BlockContext): Stmt = + private def visitBlock[FirrtlNode](ctx: FIRRTLParser.BlockContext): Statement = Begin(ctx.stmt.map(visitStmt)) // Memories are fairly complicated to translate thus have a dedicated method - private def visitMem[AST](ctx: FIRRTLParser.StmtContext): Stmt = { + private def visitMem[FirrtlNode](ctx: FIRRTLParser.StmtContext): Statement = { def parseChildren(children: Seq[ParseTree], map: Map[String, Seq[ParseTree]]): Map[String, Seq[ParseTree]] = { val field = children(0).getText if (field == "}") map @@ -186,13 +189,13 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] } // visitStringLit - private def visitStringLit[AST](node: TerminalNode): StringLit = { + private def visitStringLit[FirrtlNode](node: TerminalNode): StringLit = { val raw = node.getText.tail.init // Remove surrounding double quotes FIRRTLStringLitHandler.unescape(raw) } // visitStmt - private def visitStmt[AST](ctx: FIRRTLParser.StmtContext): Stmt = { + private def visitStmt[FirrtlNode](ctx: FIRRTLParser.StmtContext): Statement = { val info = visitInfo(Option(ctx.info), ctx) ctx.getChild(0) match { case term: TerminalNode => term.getText match { @@ -200,8 +203,8 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] case "reg" => { val name = (ctx.id(0).getText) val tpe = visitType(ctx.`type`(0)) - val reset = if (ctx.exp(1) != null) visitExp(ctx.exp(1)) else UIntValue(0, IntWidth(1)) - val init = if (ctx.exp(2) != null) visitExp(ctx.exp(2)) else Ref(name, tpe) + val reset = if (ctx.exp(1) != null) visitExp(ctx.exp(1)) else UIntLiteral(0, IntWidth(1)) + val init = if (ctx.exp(2) != null) visitExp(ctx.exp(2)) else Reference(name, tpe) DefRegister(info, name, tpe, visitExp(ctx.exp(0)), reset, init) } case "mem" => visitMem(ctx) @@ -222,21 +225,21 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] case "inst" => DefInstance(info, (ctx.id(0).getText), (ctx.id(1).getText)) case "node" => DefNode(info, (ctx.id(0).getText), visitExp(ctx.exp(0))) case "when" => { - val alt = if (ctx.block.length > 1) visitBlock(ctx.block(1)) else Empty() + val alt = if (ctx.block.length > 1) visitBlock(ctx.block(1)) else EmptyStmt Conditionally(info, visitExp(ctx.exp(0)), visitBlock(ctx.block(0)), alt) } case "stop(" => Stop(info, string2Int(ctx.IntLit(0).getText), visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) case "printf(" => Print(info, visitStringLit(ctx.StringLit), ctx.exp.drop(2).map(visitExp), visitExp(ctx.exp(0)), visitExp(ctx.exp(1))) - case "skip" => Empty() + case "skip" => EmptyStmt } // If we don't match on the first child, try the next one case _ => { ctx.getChild(1).getText match { case "<=" => Connect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1)) ) - case "<-" => BulkConnect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1)) ) + case "<-" => PartialConnect(info, visitExp(ctx.exp(0)), visitExp(ctx.exp(1)) ) case "is" => IsInvalid(info, visitExp(ctx.exp(0))) - case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType(),ctx.id(1).getText,Seq(visitExp(ctx.exp(0)),visitExp(ctx.exp(1))),visitMdir(ctx.mdir)) + case "mport" => CDefMPort(info, ctx.id(0).getText, UnknownType,ctx.id(1).getText,Seq(visitExp(ctx.exp(0)),visitExp(ctx.exp(1))),visitMdir(ctx.mdir)) } } } @@ -244,14 +247,14 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] // add visitRuw ? //T visitRuw(FIRRTLParser.RuwContext ctx); - //private def visitRuw[AST](ctx: FIRRTLParser.RuwContext): + //private def visitRuw[FirrtlNode](ctx: FIRRTLParser.RuwContext): // TODO // - Add mux // - Add validif - private def visitExp[AST](ctx: FIRRTLParser.ExpContext): Expression = + private def visitExp[FirrtlNode](ctx: FIRRTLParser.ExpContext): Expression = if( ctx.getChildCount == 1 ) - Ref((ctx.getText), UnknownType()) + Reference((ctx.getText), UnknownType) else ctx.getChild(0).getText match { case "UInt" => { // This could be better @@ -262,7 +265,7 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] val bigint = string2BigInt(ctx.IntLit(0).getText) (IntWidth(BigInt(scala.math.max(bigint.bitLength,1))),bigint) } - UIntValue(value, width) + UIntLiteral(value, width) } case "SInt" => { val (width, value) = @@ -272,25 +275,25 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[AST] val bigint = string2BigInt(ctx.IntLit(0).getText) (IntWidth(BigInt(bigint.bitLength + 1)),bigint) } - SIntValue(value, width) + SIntLiteral(value, width) } - case "validif(" => ValidIf(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType()) - case "mux(" => Mux(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), visitExp(ctx.exp(2)), UnknownType()) + case "validif(" => ValidIf(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) + case "mux(" => Mux(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), visitExp(ctx.exp(2)), UnknownType) case _ => ctx.getChild(1).getText match { - case "." => new SubField(visitExp(ctx.exp(0)), (ctx.id.getText), UnknownType()) + case "." => new SubField(visitExp(ctx.exp(0)), (ctx.id.getText), UnknownType) case "[" => if (ctx.exp(1) == null) - new SubIndex(visitExp(ctx.exp(0)), string2Int(ctx.IntLit(0).getText), UnknownType()) - else new SubAccess(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType()) + new SubIndex(visitExp(ctx.exp(0)), string2Int(ctx.IntLit(0).getText), UnknownType) + else new SubAccess(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType) // Assume primop case _ => DoPrim(visitPrimop(ctx.primop), ctx.exp.map(visitExp), - ctx.IntLit.map(x => string2BigInt(x.getText)), UnknownType()) + ctx.IntLit.map(x => string2BigInt(x.getText)), UnknownType) } } // stripSuffix("(") is included because in ANTLR concrete syntax we have to include open parentheses, // see grammar file for more details - private def visitPrimop[AST](ctx: FIRRTLParser.PrimopContext): PrimOp = fromString(ctx.getText.stripSuffix("(")) + private def visitPrimop[FirrtlNode](ctx: FIRRTLParser.PrimopContext): PrimOp = fromString(ctx.getText.stripSuffix("(")) // visit Id and Keyword? } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index bb555112..f0c56358 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -30,6 +30,7 @@ package firrtl import scala.collection.Seq import Utils._ import firrtl.Serialize._ +import firrtl.ir._ import WrappedExpression._ import WrappedWidth._ @@ -53,21 +54,20 @@ case class WRef(name:String,tpe:Type,kind:Kind,gender:Gender) extends Expression case class WSubField(exp:Expression,name:String,tpe:Type,gender:Gender) extends Expression case class WSubIndex(exp:Expression,value:Int,tpe:Type,gender:Gender) extends Expression case class WSubAccess(exp:Expression,index:Expression,tpe:Type,gender:Gender) extends Expression -case class WVoid() extends Expression { def tpe = UnknownType() } -case class WInvalid() extends Expression { def tpe = UnknownType() } +case class WVoid() extends Expression { def tpe = UnknownType } +case class WInvalid() extends Expression { def tpe = UnknownType } // Useful for splitting then remerging references -case object EmptyExpression extends Expression { def tpe = UnknownType() } -case class WDefInstance(info:Info,name:String,module:String,tpe:Type) extends Stmt with IsDeclaration +case object EmptyExpression extends Expression { def tpe = UnknownType } +case class WDefInstance(info:Info,name:String,module:String,tpe:Type) extends Statement with IsDeclaration // Resultant width is the same as the maximum input width -case object ADDW_OP extends PrimOp +case object Addw extends PrimOp { override def toString = "addw" } // Resultant width is the same as the maximum input width -case object SUBW_OP extends PrimOp +case object Subw extends PrimOp { override def toString = "subw" } // Resultant width is the same as input argument width -case object DSHLW_OP extends PrimOp +case object Dshlw extends PrimOp { override def toString = "dshlw" } // Resultant width is the same as input argument width -case object SHLW_OP extends PrimOp - +case object Shlw extends PrimOp { override def toString = "shlw" } object WrappedExpression { def apply (e:Expression) = new WrappedExpression(e) @@ -79,8 +79,8 @@ class WrappedExpression (val e1:Expression) { we match { case (we:WrappedExpression) => { (e1,we.e1) match { - case (e1:UIntValue,e2:UIntValue) => if (e1.value == e2.value) eqw(e1.width,e2.width) else false - case (e1:SIntValue,e2:SIntValue) => if (e1.value == e2.value) eqw(e1.width,e2.width) else false + case (e1:UIntLiteral,e2:UIntLiteral) => if (e1.value == e2.value) eqw(e1.width,e2.width) else false + case (e1:SIntLiteral,e2:SIntLiteral) => if (e1.value == e2.value) eqw(e1.width,e2.width) else false case (e1:WRef,e2:WRef) => e1.name equals e2.name case (e1:WSubField,e2:WSubField) => (e1.name equals e2.name) && weq(e1.exp,e2.exp) case (e1:WSubIndex,e2:WSubIndex) => (e1.value == e2.value) && weq(e1.exp,e2.exp) @@ -125,7 +125,7 @@ class WrappedType (val t:Type) { (t,t2.t) match { case (t1:UIntType,t2:UIntType) => true case (t1:SIntType,t2:SIntType) => true - case (t1:ClockType,t2:ClockType) => true + case (ClockType, ClockType) => true case (t1:VectorType,t2:VectorType) => (wt(t1.tpe) == wt(t2.tpe) && t1.size == t2.size) case (t1:BundleType,t2:BundleType) => { var ret = true @@ -161,7 +161,7 @@ class WrappedWidth (val w:Width) { case (w:MinusWidth) => "(" + w.arg1 + " - " + w.arg2 + ")" case (w:ExpWidth) => "exp(" + w.arg1 + ")" case (w:IntWidth) => w.width.toString - case (w:UnknownWidth) => "?" + case UnknownWidth => "?" } } def ww (w:Width) : WrappedWidth = new WrappedWidth(w) @@ -200,7 +200,7 @@ class WrappedWidth (val w:Width) { case (w1:MinusWidth,w2:MinusWidth) => (ww(w1.arg1) == ww(w2.arg1) && ww(w1.arg2) == ww(w2.arg2)) || (ww(w1.arg1) == ww(w2.arg2) && ww(w1.arg2) == ww(w2.arg1)) case (w1:ExpWidth,w2:ExpWidth) => ww(w1.arg1) == ww(w2.arg1) - case (w1:UnknownWidth,w2:UnknownWidth) => true + case (UnknownWidth, UnknownWidth) => true case (w1,w2) => false } } @@ -227,6 +227,6 @@ case object MRead extends MPortDir case object MWrite extends MPortDir case object MReadWrite extends MPortDir -case class CDefMemory (val info: Info, val name: String, val tpe: Type, val size: Int, val seq: Boolean) extends Stmt -case class CDefMPort (val info: Info, val name: String, val tpe: Type, val mem: String, val exps: Seq[Expression], val direction: MPortDir) extends Stmt +case class CDefMemory (val info: Info, val name: String, val tpe: Type, val size: Int, val seq: Boolean) extends Statement +case class CDefMPort (val info: Info, val name: String, val tpe: Type, val mem: String, val exps: Seq[Expression], val direction: MPortDir) extends Statement diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala new file mode 100644 index 00000000..f25ab144 --- /dev/null +++ b/src/main/scala/firrtl/ir/IR.scala @@ -0,0 +1,189 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtl +package ir + +trait Info +case object NoInfo extends Info { + override def toString(): String = "" +} +case class FileInfo(info: StringLit) extends Info { + override def toString(): String = " @[" + info.serialize + "]" +} + +/** Intermediate Representation */ +abstract class FirrtlNode { + def serialize: String = firrtl.Serialize.serialize(this) +} + +trait HasName { + val name: String +} +trait HasInfo { + val info: Info +} +trait IsDeclaration extends HasName with HasInfo + +case class StringLit(array: Array[Byte]) extends FirrtlNode + +/** Primitive Operation + * + * See [[PrimOps]] + */ +abstract class PrimOp extends FirrtlNode + +abstract class Expression extends FirrtlNode { + def tpe: Type +} +case class Reference(name: String, tpe: Type) extends Expression with HasName +case class SubField(expr: Expression, name: String, tpe: Type) extends Expression with HasName +case class SubIndex(expr: Expression, value: Int, tpe: Type) extends Expression +case class SubAccess(expr: Expression, index: Expression, tpe: Type) extends Expression +case class Mux(cond: Expression, tval: Expression, fval: Expression, tpe: Type) extends Expression +case class ValidIf(cond: Expression, value: Expression, tpe: Type) extends Expression +abstract class Literal extends Expression { + val value: BigInt + val width: Width +} +case class UIntLiteral(value: BigInt, width: Width) extends Literal { + def tpe = UIntType(width) +} +case class SIntLiteral(value: BigInt, width: Width) extends Literal { + def tpe = SIntType(width) +} +case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression + +abstract class Statement extends FirrtlNode +case class DefWire(info: Info, name: String, tpe: Type) extends Statement with IsDeclaration +case class DefRegister( + info: Info, + name: String, + tpe: Type, + clock: Expression, + reset: Expression, + init: Expression) extends Statement with IsDeclaration +case class DefInstance(info: Info, name: String, module: String) extends Statement with IsDeclaration +case class DefMemory( + info: Info, + name: String, + dataType: Type, + depth: Int, + writeLatency: Int, + readLatency: Int, + readers: Seq[String], + writers: Seq[String], + readwriters: Seq[String]) extends Statement with IsDeclaration +case class DefNode(info: Info, name: String, value: Expression) extends Statement with IsDeclaration +case class Conditionally( + info: Info, + pred: Expression, + conseq: Statement, + alt: Statement) extends Statement with HasInfo +case class Begin(stmts: Seq[Statement]) extends Statement +case class PartialConnect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo +case class Connect(info: Info, loc: Expression, expr: Expression) extends Statement with HasInfo +case class IsInvalid(info: Info, expr: Expression) extends Statement with HasInfo +case class Stop(info: Info, ret: Int, clk: Expression, en: Expression) extends Statement with HasInfo +case class Print( + info: Info, + string: StringLit, + args: Seq[Expression], + clk: Expression, + en: Expression) extends Statement with HasInfo +case object EmptyStmt extends Statement + +abstract class Width extends FirrtlNode { + def +(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width + b.width) + case _ => UnknownWidth + } + def -(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width - b.width) + case _ => UnknownWidth + } + def max(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width max b.width) + case _ => UnknownWidth + } + def min(x: Width): Width = (this, x) match { + case (a: IntWidth, b: IntWidth) => IntWidth(a.width min b.width) + case _ => UnknownWidth + } +} +/** Positive Integer Bit Width of a [[GroundType]] */ +case class IntWidth(width: BigInt) extends Width +case object UnknownWidth extends Width + +/** Orientation of [[Field]] */ +abstract class Orientation extends FirrtlNode +case object Default extends Orientation +case object Flip extends Orientation + +/** Field of [[BundleType]] */ +case class Field(name: String, flip: Orientation, tpe: Type) extends FirrtlNode with HasName + +abstract class Type extends FirrtlNode +abstract class GroundType extends Type { + val width: Width +} +abstract class AggregateType extends Type +case class UIntType(width: Width) extends GroundType +case class SIntType(width: Width) extends GroundType +case class BundleType(fields: Seq[Field]) extends AggregateType +case class VectorType(tpe: Type, size: Int) extends AggregateType +case object ClockType extends GroundType { + val width = IntWidth(1) +} +case object UnknownType extends Type + +/** [[Port]] Direction */ +abstract class Direction extends FirrtlNode +case object Input extends Direction +case object Output extends Direction + +/** [[DefModule]] Port */ +case class Port(info: Info, name: String, direction: Direction, tpe: Type) extends FirrtlNode with IsDeclaration + +/** Base class for modules */ +abstract class DefModule extends FirrtlNode with IsDeclaration { + val info : Info + val name : String + val ports : Seq[Port] +} +/** Internal Module + * + * An instantiable hardware block + */ +case class Module(info: Info, name: String, ports: Seq[Port], body: Statement) extends DefModule +/** External Module + * + * Generally used for Verilog black boxes + */ +case class ExtModule(info: Info, name: String, ports: Seq[Port]) extends DefModule + +case class Circuit(info: Info, modules: Seq[DefModule], main: String) extends FirrtlNode with HasInfo diff --git a/src/main/scala/firrtl/passes/CheckInitialization.scala b/src/main/scala/firrtl/passes/CheckInitialization.scala index 27857768..6d69b792 100644 --- a/src/main/scala/firrtl/passes/CheckInitialization.scala +++ b/src/main/scala/firrtl/passes/CheckInitialization.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl.passes import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -41,16 +42,16 @@ import annotation.tailrec object CheckInitialization extends Pass { def name = "Check Initialization" - private case class VoidExpr(stmt: Stmt, voidDeps: Seq[Expression]) + private case class VoidExpr(stmt: Statement, voidDeps: Seq[Expression]) - class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Stmt]) extends PassException( + class RefNotInitializedException(info: Info, mname: String, name: String, trace: Seq[Statement]) extends PassException( s"$info : [module $mname] Reference $name is not fully initialized.\n" + trace.map(s => s" ${get_info(s)} : ${s.serialize}").mkString("\n") ) - private def getTrace(expr: WrappedExpression, voidExprs: Map[WrappedExpression, VoidExpr]): Seq[Stmt] = { + private def getTrace(expr: WrappedExpression, voidExprs: Map[WrappedExpression, VoidExpr]): Seq[Statement] = { @tailrec - def rec(e: WrappedExpression, map: Map[WrappedExpression, VoidExpr], trace: Seq[Stmt]): Seq[Stmt] = { + def rec(e: WrappedExpression, map: Map[WrappedExpression, VoidExpr], trace: Seq[Statement]): Seq[Statement] = { val voidExpr = map(e) val newTrace = voidExpr.stmt +: trace if (voidExpr.voidDeps.nonEmpty) rec(voidExpr.voidDeps.head, map, newTrace) else newTrace @@ -62,7 +63,7 @@ object CheckInitialization extends Pass { val errors = collection.mutable.ArrayBuffer[PassException]() - def checkInitM(m: InModule): Unit = { + def checkInitM(m: Module): Unit = { val voidExprs = collection.mutable.HashMap[WrappedExpression, VoidExpr]() def hasVoidExpr(e: Expression): (Boolean, Seq[Expression]) = { @@ -85,10 +86,10 @@ object CheckInitialization extends Pass { hasVoid(e) (void, voidDeps) } - def checkInitS(s: Stmt): Stmt = { + def checkInitS(s: Statement): Statement = { s match { case con: Connect => - val (hasVoid, voidDeps) = hasVoidExpr(con.exp) + val (hasVoid, voidDeps) = hasVoidExpr(con.expr) if (hasVoid) voidExprs(con.loc) = VoidExpr(con, voidDeps) con case node: DefNode => @@ -116,7 +117,7 @@ object CheckInitialization extends Pass { c.modules foreach { m => m match { - case m: InModule => checkInitM(m) + case m: Module => checkInitM(m) case m => // Do nothing } } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 23613e65..ebdd2469 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -34,6 +34,7 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.ArrayBuffer import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.Serialize._ @@ -46,7 +47,7 @@ object CheckHighForm extends Pass with LazyLogging { // Custom Exceptions class NotUniqueException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} does not have a unique name.") class InvalidLOCException extends PassException(s"${sinfo}: [module ${mname}] Invalid connect to an expression that is not a reference or a WritePort.") - class NegUIntException extends PassException(s"${sinfo}: [module ${mname}] UIntValue cannot be negative.") + class NegUIntException extends PassException(s"${sinfo}: [module ${mname}] UIntLiteral cannot be negative.") class UndeclaredReferenceException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Reference ${name} is not declared.") class PoisonWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Poison ${name} cannot be a bundle type with flips.") class MemWithFlipException(name: String) extends PassException(s"${sinfo}: [module ${mname}] Memory ${name} cannot be a bundle type with flips.") @@ -69,7 +70,7 @@ object CheckHighForm extends Pass with LazyLogging { t map (findFlip) match { case t: BundleType => { for (f <- t.fields) { - if (f.flip == REVERSE) has = true + if (f.flip == Flip) has = true } t } @@ -90,45 +91,45 @@ object CheckHighForm extends Pass with LazyLogging { def checkHighFormPrimop(e: DoPrim) = { def correctNum(ne: Option[Int], nc: Int) = { ne match { - case Some(i) => if(e.args.length != i) errors.append(new IncorrectNumArgsException(e.op.getString, i)) + case Some(i) => if(e.args.length != i) errors.append(new IncorrectNumArgsException(e.op.toString, i)) case None => // Do Nothing } - if (e.consts.length != nc) errors.append(new IncorrectNumConstsException(e.op.getString, nc)) + if (e.consts.length != nc) errors.append(new IncorrectNumConstsException(e.op.toString, nc)) } e.op match { - case ADD_OP => correctNum(Option(2),0) - case SUB_OP => correctNum(Option(2),0) - case MUL_OP => correctNum(Option(2),0) - case DIV_OP => correctNum(Option(2),0) - case REM_OP => correctNum(Option(2),0) - case LESS_OP => correctNum(Option(2),0) - case LESS_EQ_OP => correctNum(Option(2),0) - case GREATER_OP => correctNum(Option(2),0) - case GREATER_EQ_OP => correctNum(Option(2),0) - case EQUAL_OP => correctNum(Option(2),0) - case NEQUAL_OP => correctNum(Option(2),0) - case PAD_OP => correctNum(Option(1),1) - case AS_UINT_OP => correctNum(Option(1),0) - case AS_SINT_OP => correctNum(Option(1),0) - case AS_CLOCK_OP => correctNum(Option(1),0) - case SHIFT_LEFT_OP => correctNum(Option(1),1) - case SHIFT_RIGHT_OP => correctNum(Option(1),1) - case DYN_SHIFT_LEFT_OP => correctNum(Option(2),0) - case DYN_SHIFT_RIGHT_OP => correctNum(Option(2),0) - case CONVERT_OP => correctNum(Option(1),0) - case NEG_OP => correctNum(Option(1),0) - case NOT_OP => correctNum(Option(1),0) - case AND_OP => correctNum(Option(2),0) - case OR_OP => correctNum(Option(2),0) - case XOR_OP => correctNum(Option(2),0) - case AND_REDUCE_OP => correctNum(None,0) - case OR_REDUCE_OP => correctNum(None,0) - case XOR_REDUCE_OP => correctNum(None,0) - case CONCAT_OP => correctNum(Option(2),0) - case BITS_SELECT_OP => correctNum(Option(1),2) - case HEAD_OP => correctNum(Option(1),1) - case TAIL_OP => correctNum(Option(1),1) + case Add => correctNum(Option(2),0) + case Sub => correctNum(Option(2),0) + case Mul => correctNum(Option(2),0) + case Div => correctNum(Option(2),0) + case Rem => correctNum(Option(2),0) + case Lt => correctNum(Option(2),0) + case Leq => correctNum(Option(2),0) + case Gt => correctNum(Option(2),0) + case Geq => correctNum(Option(2),0) + case Eq => correctNum(Option(2),0) + case Neq => correctNum(Option(2),0) + case Pad => correctNum(Option(1),1) + case AsUInt => correctNum(Option(1),0) + case AsSInt => correctNum(Option(1),0) + case AsClock => correctNum(Option(1),0) + case Shl => correctNum(Option(1),1) + case Shr => correctNum(Option(1),1) + case Dshl => correctNum(Option(2),0) + case Dshr => correctNum(Option(2),0) + case Cvt => correctNum(Option(1),0) + case Neg => correctNum(Option(1),0) + case Not => correctNum(Option(1),0) + case And => correctNum(Option(2),0) + case Or => correctNum(Option(2),0) + case Xor => correctNum(Option(2),0) + case Andr => correctNum(None,0) + case Orr => correctNum(None,0) + case Xorr => correctNum(None,0) + case Cat => correctNum(Option(2),0) + case Bits => correctNum(Option(1),2) + case Head => correctNum(Option(1),1) + case Tail => correctNum(Option(1),1) } } @@ -148,7 +149,7 @@ object CheckHighForm extends Pass with LazyLogging { } def checkValidLoc(e: Expression) = { e match { - case e @ ( _: UIntValue | _: SIntValue | _: DoPrim ) => errors.append(new InvalidLOCException) + case e @ (_: UIntLiteral | _: SIntLiteral | _: DoPrim ) => errors.append(new InvalidLOCException) case _ => // Do Nothing } } @@ -169,7 +170,7 @@ object CheckHighForm extends Pass with LazyLogging { t map (checkHighFormW) } - def checkHighFormM(m: Module): Module = { + def checkHighFormM(m: DefModule): DefModule = { val names = HashMap[String, Boolean]() val mnames = HashMap[String, Boolean]() def checkHighFormE(e: Expression): Expression = { @@ -189,7 +190,7 @@ object CheckHighForm extends Pass with LazyLogging { validSubexp(e.exp) e } - case e: UIntValue => + case e: UIntLiteral => if (e.value < 0) errors.append(new NegUIntException) case e => e map (validSubexp) } @@ -197,7 +198,7 @@ object CheckHighForm extends Pass with LazyLogging { e map (checkHighFormT) e } - def checkHighFormS(s: Stmt): Stmt = { + def checkHighFormS(s: Statement): Statement = { def checkName(name: String): String = { if (names.contains(name)) errors.append(new NotUniqueException(name)) else names(name) = true @@ -209,12 +210,8 @@ object CheckHighForm extends Pass with LazyLogging { s map (checkHighFormT) s map (checkHighFormE) s match { - case s: DefPoison => { - if (hasFlip(s.tpe)) errors.append(new PoisonWithFlipException(s.name)) - checkHighFormT(s.tpe) - } case s: DefMemory => { - if (hasFlip(s.data_type)) errors.append(new MemWithFlipException(s.name)) + if (hasFlip(s.dataType)) errors.append(new MemWithFlipException(s.name)) if (s.depth <= 0) errors.append(new NegMemSizeException) } case s: WDefInstance => { @@ -222,7 +219,7 @@ object CheckHighForm extends Pass with LazyLogging { errors.append(new ModuleNotDefinedException(s.module)) } case s: Connect => checkValidLoc(s.loc) - case s: BulkConnect => checkValidLoc(s.loc) + case s: PartialConnect => checkValidLoc(s.loc) case s: Print => checkFstring(s.string, s.args.length) case _ => // Do Nothing } @@ -243,8 +240,8 @@ object CheckHighForm extends Pass with LazyLogging { } m match { - case m: InModule => checkHighFormS(m.body) - case m: ExModule => // Do Nothing + case m: Module => checkHighFormS(m.body) + case m: ExtModule => // Do Nothing } m } @@ -290,8 +287,8 @@ object CheckTypes extends Pass with LazyLogging { class ValidIfPassiveTypes(info:Info) extends PassException(s"${info}: [module ${mname}] Must validif a passive type.") class ValidIfCondUInt(info:Info) extends PassException(s"${info}: [module ${mname}] A validif condition must be of type UInt.") //;---------------- Helper Functions -------------- - def ut () : UIntType = UIntType(UnknownWidth()) - def st () : SIntType = SIntType(UnknownWidth()) + def ut () : UIntType = UIntType(UnknownWidth) + def st () : SIntType = SIntType(UnknownWidth) def check_types_primop (e:DoPrim, errors:Errors, info:Info) : Unit = { def all_same_type (ls:Seq[Expression]) : Unit = { @@ -322,38 +319,38 @@ object CheckTypes extends Pass with LazyLogging { } e.op match { - case AS_UINT_OP => {} - case AS_SINT_OP => {} - case AS_CLOCK_OP => {} - case DYN_SHIFT_LEFT_OP => is_uint(e.args(1)); all_ground(e.args) - case DYN_SHIFT_RIGHT_OP => is_uint(e.args(1)); all_ground(e.args) - case ADD_OP => all_ground(e.args) - case SUB_OP => all_ground(e.args) - case MUL_OP => all_ground(e.args) - case DIV_OP => all_ground(e.args) - case REM_OP => all_ground(e.args) - case LESS_OP => all_ground(e.args) - case LESS_EQ_OP => all_ground(e.args) - case GREATER_OP => all_ground(e.args) - case GREATER_EQ_OP => all_ground(e.args) - case EQUAL_OP => all_ground(e.args) - case NEQUAL_OP => all_ground(e.args) - case PAD_OP => all_ground(e.args) - case SHIFT_LEFT_OP => all_ground(e.args) - case SHIFT_RIGHT_OP => all_ground(e.args) - case CONVERT_OP => all_ground(e.args) - case NEG_OP => all_ground(e.args) - case NOT_OP => all_ground(e.args) - case AND_OP => all_ground(e.args) - case OR_OP => all_ground(e.args) - case XOR_OP => all_ground(e.args) - case AND_REDUCE_OP => all_ground(e.args) - case OR_REDUCE_OP => all_ground(e.args) - case XOR_REDUCE_OP => all_ground(e.args) - case CONCAT_OP => all_ground(e.args) - case BITS_SELECT_OP => all_ground(e.args) - case HEAD_OP => all_ground(e.args) - case TAIL_OP => all_ground(e.args) + case AsUInt => + case AsSInt => + case AsClock => + case Dshl => is_uint(e.args(1)); all_ground(e.args) + case Dshr => is_uint(e.args(1)); all_ground(e.args) + case Add => all_ground(e.args) + case Sub => all_ground(e.args) + case Mul => all_ground(e.args) + case Div => all_ground(e.args) + case Rem => all_ground(e.args) + case Lt => all_ground(e.args) + case Leq => all_ground(e.args) + case Gt => all_ground(e.args) + case Geq => all_ground(e.args) + case Eq => all_ground(e.args) + case Neq => all_ground(e.args) + case Pad => all_ground(e.args) + case Shl => all_ground(e.args) + case Shr => all_ground(e.args) + case Cvt => all_ground(e.args) + case Neg => all_ground(e.args) + case Not => all_ground(e.args) + case And => all_ground(e.args) + case Or => all_ground(e.args) + case Xor => all_ground(e.args) + case Andr => all_ground(e.args) + case Orr => all_ground(e.args) + case Xorr => all_ground(e.args) + case Cat => all_ground(e.args) + case Bits => all_ground(e.args) + case Head => all_ground(e.args) + case Tail => all_ground(e.args) } } @@ -366,7 +363,7 @@ object CheckTypes extends Pass with LazyLogging { case (t:BundleType) => { var p = true for (x <- t.fields ) { - if (x.flip == REVERSE) p = false + if (x.flip == Flip) p = false if (!passive(x.tpe)) p = false } p @@ -415,15 +412,15 @@ object CheckTypes extends Pass with LazyLogging { if (!passive(tpe(e))) errors.append(new ValidIfPassiveTypes(info)) if (!(tpe(e.cond).typeof[UIntType])) errors.append(new ValidIfCondUInt(info)) } - case (_:UIntValue|_:SIntValue) => false + case (_:UIntLiteral | _:SIntLiteral) => false } e } - def bulk_equals (t1: Type, t2: Type, flip1: Flip, flip2: Flip): Boolean = { + def bulk_equals (t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Boolean = { //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) (t1,t2) match { - case (t1:ClockType,t2:ClockType) => flip1 == flip2 + case (ClockType, ClockType) => flip1 == flip2 case (t1:UIntType,t2:UIntType) => flip1 == flip2 case (t1:SIntType,t2:SIntType) => flip1 == flip2 case (t1:BundleType,t2:BundleType) => { @@ -444,20 +441,20 @@ object CheckTypes extends Pass with LazyLogging { } } - def check_types_s (s:Stmt) : Stmt = { + def check_types_s (s:Statement) : Statement = { s map (check_types_e(get_info(s))) match { - case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.exp))) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.exp.serialize)) + case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.expr))) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) case (s:DefRegister) => if (wt(s.tpe) != wt(tpe(s.init))) errors.append(new InvalidRegInit(s.info)) - case (s:BulkConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.exp.serialize)) + case (s:PartialConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.expr),Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) case (s:Stop) => { - if (wt(tpe(s.clk)) != wt(ClockType()) ) errors.append(new ReqClk(s.info)) + if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info)) if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) } case (s:Print)=> { for (x <- s.args ) { if (wt(tpe(x)) != wt(ut()) && wt(tpe(x)) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info)) } - if (wt(tpe(s.clk)) != wt(ClockType()) ) errors.append(new ReqClk(s.info)) + if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info)) if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) } case (s:Conditionally) => if (wt(tpe(s.pred)) != wt(ut()) ) errors.append(new PredNotUInt(s.info)) @@ -470,8 +467,8 @@ object CheckTypes extends Pass with LazyLogging { for (m <- c.modules ) { mname = m.name (m) match { - case (m:ExModule) => false - case (m:InModule) => check_types_s(m.body) + case (m:ExtModule) => false + case (m:Module) => check_types_s(m.body) } } errors.trigger @@ -486,8 +483,8 @@ object CheckGenders extends Pass { def dir_to_gender (d:Direction) : Gender = { d match { - case INPUT => MALE - case OUTPUT => FEMALE //BI-GENDER + case Input => MALE + case Output => FEMALE //BI-GENDER } } @@ -517,7 +514,7 @@ object CheckGenders extends Pass { val kindx = get_kind(e) def flipQ (t:Type) : Boolean = { var fQ = false - def flip_rec (t:Type,f:Flip) : Type = { + def flip_rec (t:Type,f:Orientation) : Type = { (t) match { case (t:BundleType) => { for (field <- t.fields) { @@ -525,11 +522,11 @@ object CheckGenders extends Pass { } } case (t:VectorType) => flip_rec(t.tpe,f) - case (t) => if (f == REVERSE) fQ = true + case (t) => if (f == Flip) fQ = true } t } - flip_rec(t,DEFAULT) + flip_rec(t,Default) fQ } @@ -564,8 +561,8 @@ object CheckGenders extends Pass { case (e:WSubIndex) => get_gender(e.exp,genders) case (e:WSubAccess) => get_gender(e.exp,genders) case (e:DoPrim) => MALE - case (e:UIntValue) => MALE - case (e:SIntValue) => MALE + case (e:UIntLiteral) => MALE + case (e:SIntLiteral) => MALE case (e:Mux) => MALE case (e:ValidIf) => MALE } @@ -581,18 +578,17 @@ object CheckGenders extends Pass { case (e:DoPrim) => for (e <- e.args ) { check_gender(info,genders,MALE)(e) } case (e:Mux) => e map (check_gender(info,genders,MALE)) case (e:ValidIf) => e map (check_gender(info,genders,MALE)) - case (e:UIntValue) => false - case (e:SIntValue) => false + case (e:UIntLiteral) => false + case (e:SIntLiteral) => false } e } - def check_genders_s (genders:HashMap[String,Gender])(s:Stmt) : Stmt = { + def check_genders_s (genders:HashMap[String,Gender])(s:Statement) : Statement = { s map (check_genders_e(get_info(s),genders)) s map (check_genders_s(genders)) (s) match { case (s:DefWire) => genders(s.name) = BIGENDER - case (s:DefPoison) => genders(s.name) = MALE case (s:DefRegister) => genders(s.name) = BIGENDER case (s:DefNode) => { check_gender(s.info,genders,MALE)(s.value) @@ -602,7 +598,7 @@ object CheckGenders extends Pass { case (s:WDefInstance) => genders(s.name) = MALE case (s:Connect) => { check_gender(s.info,genders,FEMALE)(s.loc) - check_gender(s.info,genders,MALE)(s.exp) + check_gender(s.info,genders,MALE)(s.expr) } case (s:Print) => { for (x <- s.args ) { @@ -611,14 +607,14 @@ object CheckGenders extends Pass { check_gender(s.info,genders,MALE)(s.en) check_gender(s.info,genders,MALE)(s.clk) } - case (s:BulkConnect) => { + case (s:PartialConnect) => { check_gender(s.info,genders,FEMALE)(s.loc) - check_gender(s.info,genders,MALE)(s.exp) + check_gender(s.info,genders,MALE)(s.expr) } case (s:Conditionally) => { check_gender(s.info,genders,MALE)(s.pred) } - case (s:Empty) => false + case EmptyStmt => false case (s:Stop) => { check_gender(s.info,genders,MALE)(s.en) check_gender(s.info,genders,MALE)(s.clk) @@ -635,8 +631,8 @@ object CheckGenders extends Pass { genders(p.name) = dir_to_gender(p.direction) } (m) match { - case (m:ExModule) => false - case (m:InModule) => check_genders_s(genders)(m.body) + case (m:ExtModule) => false + case (m:Module) => check_genders_s(genders)(m.body) } } errors.trigger @@ -654,7 +650,7 @@ object CheckWidths extends Pass { class NegWidthException(info:Info) extends PassException(s"${info}: [module ${mname}] Width cannot be negative or zero.") def run (c:Circuit): Circuit = { val errors = new Errors() - def check_width_m (m:Module) : Unit = { + def check_width_m (m:DefModule) : Unit = { def check_width_w (info:Info)(w:Width) : Width = { (w) match { case (w:IntWidth)=> if (w.width <= 0) errors.append(new NegWidthException(info)) @@ -664,7 +660,7 @@ object CheckWidths extends Pass { } def check_width_e (info:Info)(e:Expression) : Expression = { (e map (check_width_e(info))) match { - case (e:UIntValue) => { + case (e:UIntLiteral) => { (e.width) match { case (w:IntWidth) => if (scala.math.max(1,e.value.bitLength) > w.width) { @@ -674,7 +670,7 @@ object CheckWidths extends Pass { } check_width_w(info)(e.width) } - case (e:SIntValue) => { + case (e:SIntLiteral) => { (e.width) match { case (w:IntWidth) => if (e.value.bitLength + 1 > w.width) errors.append(new WidthTooSmall(info, e.value)) @@ -687,7 +683,7 @@ object CheckWidths extends Pass { } e } - def check_width_s (s:Stmt) : Stmt = { + def check_width_s (s:Statement) : Statement = { s map (check_width_s) map (check_width_e(get_info(s))) def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t) s map (tm) @@ -698,8 +694,8 @@ object CheckWidths extends Pass { } (m) match { - case (m:ExModule) => {} - case (m:InModule) => check_width_s(m.body) + case (m:ExtModule) => {} + case (m:Module) => check_width_s(m.body) } } diff --git a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala index 717d95e8..7d4c96b2 100644 --- a/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala +++ b/src/main/scala/firrtl/passes/CommonSubexpressionElimination.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl.passes import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -36,12 +37,12 @@ import annotation.tailrec object CommonSubexpressionElimination extends Pass { def name = "Common Subexpression Elimination" - private def cseOnce(s: Stmt): (Stmt, Long) = { + private def cseOnce(s: Statement): (Statement, Long) = { var nEliminated = 0L val expressions = collection.mutable.HashMap[MemoizedHash[Expression], String]() val nodes = collection.mutable.HashMap[String, Expression]() - def recordNodes(s: Stmt): Stmt = s match { + def recordNodes(s: Statement): Statement = s match { case x: DefNode => nodes(x.name) = x.value expressions.getOrElseUpdate(x.value, x.name) @@ -62,22 +63,22 @@ object CommonSubexpressionElimination extends Pass { case _ => e map eliminateNodeRef } - def eliminateNodeRefs(s: Stmt): Stmt = s map eliminateNodeRefs map eliminateNodeRef + def eliminateNodeRefs(s: Statement): Statement = s map eliminateNodeRefs map eliminateNodeRef recordNodes(s) (eliminateNodeRefs(s), nEliminated) } @tailrec - private def cse(s: Stmt): Stmt = { + private def cse(s: Statement): Statement = { val (res, n) = cseOnce(s) if (n > 0) cse(res) else res } def run(c: Circuit): Circuit = { val modulesx = c.modules.map { - case m: ExModule => m - case m: InModule => InModule(m.info, m.name, m.ports, cse(m.body)) + case m: ExtModule => m + case m: Module => Module(m.info, m.name, m.ports, cse(m.body)) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 216e94b0..618a96c0 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -28,8 +28,10 @@ MODIFICATIONS. package firrtl.passes import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.PrimOps._ import annotation.tailrec @@ -37,20 +39,20 @@ object ConstProp extends Pass { def name = "Constant Propagation" trait FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue): UIntValue - def simplify(e: Expression, lhs: UIntValue, rhs: Expression): Expression + def fold(c1: UIntLiteral, c2: UIntLiteral): UIntLiteral + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression): Expression def apply(e: DoPrim): Expression = (e.args(0), e.args(1)) match { - case (lhs: UIntValue, rhs: UIntValue) => fold(lhs, rhs) - case (lhs: UIntValue, rhs) => simplify(e, lhs, rhs) - case (lhs, rhs: UIntValue) => simplify(e, rhs, lhs) + case (lhs: UIntLiteral, rhs: UIntLiteral) => fold(lhs, rhs) + case (lhs: UIntLiteral, rhs) => simplify(e, lhs, rhs) + case (lhs, rhs: UIntLiteral) => simplify(e, rhs, lhs) case _ => e } } object FoldAND extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value & c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value & c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if long_BANG(tpe(rhs)) == w => if (lhs.value == 0) lhs // and(x, 0) => 0 else if (lhs.value == (BigInt(1) << w.toInt) - 1) rhs // and(x, 1) => x @@ -60,8 +62,8 @@ object ConstProp extends Pass { } object FoldOR extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value | c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value | c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if long_BANG(tpe(rhs)) == w => if (lhs.value == 0) rhs // or(x, 0) => x else if (lhs.value == (BigInt(1) << w.toInt) - 1) lhs // or(x, 1) => 1 @@ -71,8 +73,8 @@ object ConstProp extends Pass { } object FoldXOR extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(c1.value ^ c2.value, c1.width max c2.width) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(c1.value ^ c2.value, c1.width max c2.width) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if long_BANG(tpe(rhs)) == w => if (lhs.value == 0) rhs // xor(x, 0) => x else e @@ -81,8 +83,8 @@ object ConstProp extends Pass { } object FoldEqual extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value == c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value == c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == 1 => if (lhs.value == 1) rhs // eq(x, 1) => x else e @@ -91,8 +93,8 @@ object ConstProp extends Pass { } object FoldNotEqual extends FoldLogicalOp { - def fold(c1: UIntValue, c2: UIntValue) = UIntValue(if (c1.value != c2.value) 1 else 0, IntWidth(1)) - def simplify(e: Expression, lhs: UIntValue, rhs: Expression) = lhs.width match { + def fold(c1: UIntLiteral, c2: UIntLiteral) = UIntLiteral(if (c1.value != c2.value) 1 else 0, IntWidth(1)) + def simplify(e: Expression, lhs: UIntLiteral, rhs: Expression) = lhs.width match { case IntWidth(w) if w == 1 && long_BANG(tpe(rhs)) == w => if (lhs.value == 0) rhs // neq(x, 0) => x else e @@ -101,15 +103,15 @@ object ConstProp extends Pass { } private def foldConcat(e: DoPrim) = (e.args(0), e.args(1)) match { - case (UIntValue(xv, IntWidth(xw)), UIntValue(yv, IntWidth(yw))) => UIntValue(xv << yw.toInt | yv, IntWidth(xw + yw)) + case (UIntLiteral(xv, IntWidth(xw)), UIntLiteral(yv, IntWidth(yw))) => UIntLiteral(xv << yw.toInt | yv, IntWidth(xw + yw)) case _ => e } private def foldShiftLeft(e: DoPrim) = e.consts(0).toInt match { case 0 => e.args(0) case x => e.args(0) match { - case UIntValue(v, IntWidth(w)) => UIntValue(v << x, IntWidth(w + x)) - case SIntValue(v, IntWidth(w)) => SIntValue(v << x, IntWidth(w + x)) + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v << x, IntWidth(w + x)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v << x, IntWidth(w + x)) case _ => e } } @@ -118,9 +120,9 @@ object ConstProp extends Pass { case 0 => e.args(0) case x => e.args(0) match { // TODO when amount >= x.width, return a zero-width wire - case UIntValue(v, IntWidth(w)) => UIntValue(v >> x, IntWidth((w - x) max 1)) + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v >> x, IntWidth((w - x) max 1)) // take sign bit if shift amount is larger than arg width - case SIntValue(v, IntWidth(w)) => SIntValue(v >> x, IntWidth((w - x) max 1)) + case SIntLiteral(v, IntWidth(w)) => SIntLiteral(v >> x, IntWidth((w - x) max 1)) case _ => e } } @@ -132,15 +134,15 @@ object ConstProp extends Pass { case _ => false } def isZero(e: Expression) = e match { - case UIntValue(value,_) => value == BigInt(0) - case SIntValue(value,_) => value == BigInt(0) + case UIntLiteral(value, _) => value == BigInt(0) + case SIntLiteral(value, _) => value == BigInt(0) case _ => false } x match { - case DoPrim(LESS_OP, Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => zero - case DoPrim(LESS_EQ_OP, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => one - case DoPrim(GREATER_OP, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => zero - case DoPrim(GREATER_EQ_OP,Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => one + case DoPrim(Lt, Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => zero + case DoPrim(Leq, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => one + case DoPrim(Gt, Seq(a,b),_,_) if(isZero(a) && isUInt(b)) => zero + case DoPrim(Geq, Seq(a,b),_,_) if(isUInt(a) && isZero(b)) => one case e => e } } @@ -159,8 +161,8 @@ object ConstProp extends Pass { def <= (that: Range) = this.max <= that.min } def range(e: Expression): Range = e match { - case UIntValue(value, _) => Range(value, value) - case SIntValue(value, _) => Range(value, value) + case UIntLiteral(value, _) => Range(value, value) + case SIntLiteral(value, _) => Range(value, value) case _ => tpe(e) match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), @@ -179,15 +181,15 @@ object ConstProp extends Pass { def r1 = range(e.args(1)) e.op match { // Always true - case LESS_OP if (r0 < r1) => one - case LESS_EQ_OP if (r0 <= r1) => one - case GREATER_OP if (r0 > r1) => one - case GREATER_EQ_OP if (r0 >= r1) => one + case Lt if (r0 < r1) => one + case Leq if (r0 <= r1) => one + case Gt if (r0 > r1) => one + case Geq if (r0 >= r1) => one // Always false - case LESS_OP if (r0 >= r1) => zero - case LESS_EQ_OP if (r0 > r1) => zero - case GREATER_OP if (r0 <= r1) => zero - case GREATER_EQ_OP if (r0 < r1) => zero + case Lt if (r0 >= r1) => zero + case Leq if (r0 > r1) => zero + case Gt if (r0 <= r1) => zero + case Geq if (r0 < r1) => zero case _ => e } } @@ -198,29 +200,29 @@ object ConstProp extends Pass { } private def constPropPrim(e: DoPrim): Expression = e.op match { - case SHIFT_LEFT_OP => foldShiftLeft(e) - case SHIFT_RIGHT_OP => foldShiftRight(e) - case CONCAT_OP => foldConcat(e) - case AND_OP => FoldAND(e) - case OR_OP => FoldOR(e) - case XOR_OP => FoldXOR(e) - case EQUAL_OP => FoldEqual(e) - case NEQUAL_OP => FoldNotEqual(e) - case LESS_OP|LESS_EQ_OP|GREATER_OP|GREATER_EQ_OP => foldComparison(e) - case NOT_OP => e.args(0) match { - case UIntValue(v, IntWidth(w)) => UIntValue(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) + case Shl => foldShiftLeft(e) + case Shr => foldShiftRight(e) + case Cat => foldConcat(e) + case And => FoldAND(e) + case Or => FoldOR(e) + case Xor => FoldXOR(e) + case Eq => FoldEqual(e) + case Neq => FoldNotEqual(e) + case (Lt | Leq | Gt | Geq) => foldComparison(e) + case Not => e.args(0) match { + case UIntLiteral(v, IntWidth(w)) => UIntLiteral(v ^ ((BigInt(1) << w.toInt) - 1), IntWidth(w)) case _ => e } - case BITS_SELECT_OP => e.args(0) match { - case UIntValue(v, _) => { + case Bits => e.args(0) match { + case UIntLiteral(v, _) => { val hi = e.consts(0).toInt val lo = e.consts(1).toInt require(hi >= lo) - UIntValue((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) + UIntLiteral((v >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) } case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match { case t: UIntType => x - case _ => DoPrim(AS_UINT_OP, Seq(x), Seq(), tpe(e)) + case _ => DoPrim(AsUInt, Seq(x), Seq(), tpe(e)) } case _ => e } @@ -230,33 +232,33 @@ object ConstProp extends Pass { private def constPropMuxCond(m: Mux) = { // Only propagate a value if its width matches the mux width def propagate(e: Expression, muxWidth: BigInt) = e match { - case UIntValue(v, _) => UIntValue(v, IntWidth(muxWidth)) + case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(muxWidth)) case _ => tpe(e) match { case UIntType(IntWidth(w)) if muxWidth == w => e case _ => m } } (m.cond, m.tpe) match { - case (UIntValue(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w) + case (UIntLiteral(c, _), UIntType(IntWidth(w))) => propagate(if (c == 1) m.tval else m.fval, w) case _ => m } } private def constPropMux(m: Mux): Expression = (m.tval, m.fval) match { case _ if m.tval == m.fval => m.tval - case (t: UIntValue, f: UIntValue) => + case (t: UIntLiteral, f: UIntLiteral) => if (t.value == 1 && f.value == 0 && long_BANG(m.tpe) == 1) m.cond else constPropMuxCond(m) case _ => constPropMuxCond(m) } private def constPropNodeRef(r: WRef, e: Expression) = e match { - case _: UIntValue | _: SIntValue | _: WRef => e + case _: UIntLiteral | _: SIntLiteral | _: WRef => e case _ => r } @tailrec - private def constPropModule(m: InModule): InModule = { + private def constPropModule(m: Module): Module = { var nPropagated = 0L val nodeMap = collection.mutable.HashMap[String, Expression]() @@ -273,7 +275,7 @@ object ConstProp extends Pass { propagated } - def constPropStmt(s: Stmt): Stmt = { + def constPropStmt(s: Statement): Statement = { s match { case x: DefNode => nodeMap(x.name) = x.value case _ => @@ -281,14 +283,14 @@ object ConstProp extends Pass { s map constPropStmt map constPropExpression } - val res = InModule(m.info, m.name, m.ports, constPropStmt(m.body)) + val res = Module(m.info, m.name, m.ports, constPropStmt(m.body)) if (nPropagated > 0) constPropModule(res) else res } def run(c: Circuit): Circuit = { val modulesx = c.modules.map { - case m: ExModule => m - case m: InModule => constPropModule(m) + case m: ExtModule => m + case m: Module => constPropModule(m) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/DeadCodeElimination.scala b/src/main/scala/firrtl/passes/DeadCodeElimination.scala index cb772556..80ba0e98 100644 --- a/src/main/scala/firrtl/passes/DeadCodeElimination.scala +++ b/src/main/scala/firrtl/passes/DeadCodeElimination.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl.passes import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ @@ -36,7 +37,7 @@ import annotation.tailrec object DeadCodeElimination extends Pass { def name = "Dead Code Elimination" - private def dceOnce(s: Stmt): (Stmt, Long) = { + private def dceOnce(s: Statement): (Statement, Long) = { val referenced = collection.mutable.HashSet[String]() var nEliminated = 0L @@ -48,16 +49,16 @@ object DeadCodeElimination extends Pass { e } - def checkUse(s: Stmt): Stmt = s map checkUse map checkExpressionUse + def checkUse(s: Statement): Statement = s map checkUse map checkExpressionUse - def maybeEliminate(x: Stmt, name: String) = + def maybeEliminate(x: Statement, name: String) = if (referenced(name)) x else { nEliminated += 1 - Empty() + EmptyStmt } - def removeUnused(s: Stmt): Stmt = s match { + def removeUnused(s: Statement): Statement = s match { case x: DefRegister => maybeEliminate(x, x.name) case x: DefWire => maybeEliminate(x, x.name) case x: DefNode => maybeEliminate(x, x.name) @@ -69,15 +70,15 @@ object DeadCodeElimination extends Pass { } @tailrec - private def dce(s: Stmt): Stmt = { + private def dce(s: Statement): Statement = { val (res, n) = dceOnce(s) if (n > 0) dce(res) else res } def run(c: Circuit): Circuit = { val modulesx = c.modules.map { - case m: ExModule => m - case m: InModule => InModule(m.info, m.name, m.ports, dce(m.body)) + case m: ExtModule => m + case m: Module => Module(m.info, m.name, m.ports, dce(m.body)) } Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 540aab9f..b6e090f4 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl.passes import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.PrimOps._ @@ -59,7 +60,7 @@ object ExpandWhens extends Pass { hashx } private def getFemaleRefs(n: String, t: Type, g: Gender): Seq[Expression] = { - def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, DEFAULT)) + def getGender(t: Type, i: Int, g: Gender): Gender = times(g, get_flip(t, i, Default)) val exps = create_exps(WRef(n, t, ExpKind(), g)) val expsx = ArrayBuffer[Expression]() for (j <- 0 until exps.size) { @@ -70,12 +71,12 @@ object ExpandWhens extends Pass { } expsx } - private def squashEmpty(s: Stmt): Stmt = { + private def squashEmpty(s: Statement): Statement = { s map squashEmpty match { case Begin(stmts) => - val newStmts = stmts filter (_ != Empty()) + val newStmts = stmts filter (_ != EmptyStmt) newStmts.size match { - case 0 => Empty() + case 0 => EmptyStmt case 1 => newStmts.head case _ => Begin(newStmts) } @@ -102,16 +103,16 @@ object ExpandWhens extends Pass { // ------------ Pass ------------------- def run(c: Circuit): Circuit = { - def expandWhens(m: InModule): (LinkedHashMap[WrappedExpression, Expression], ArrayBuffer[Stmt], Stmt) = { + def expandWhens(m: Module): (LinkedHashMap[WrappedExpression, Expression], ArrayBuffer[Statement], Statement) = { val namespace = Namespace(m) - val simlist = ArrayBuffer[Stmt]() + val simlist = ArrayBuffer[Statement]() // defaults ideally would be immutable.Map but conversion from mutable.LinkedHashMap to mutable.Map is VERY slow def expandWhens( netlist: LinkedHashMap[WrappedExpression, Expression], defaults: Seq[collection.mutable.Map[WrappedExpression, Expression]], p: Expression) - (s: Stmt): Stmt = { + (s: Statement): Statement = { s match { case w: DefWire => getFemaleRefs(w.name, w.tpe, BIGENDER) foreach (ref => netlist(ref) = WVoid()) @@ -120,13 +121,13 @@ object ExpandWhens extends Pass { getFemaleRefs(r.name, r.tpe, BIGENDER) foreach (ref => netlist(ref) = ref) r case c: Connect => - netlist(c.loc) = c.exp - Empty() + netlist(c.loc) = c.expr + EmptyStmt case c: IsInvalid => - netlist(c.exp) = WInvalid() - Empty() + netlist(c.expr) = WInvalid() + EmptyStmt case s: Conditionally => - val memos = ArrayBuffer[Stmt]() + val memos = ArrayBuffer[Statement]() val conseqNetlist = LinkedHashMap[WrappedExpression, Expression]() val altNetlist = LinkedHashMap[WrappedExpression, Expression]() @@ -164,14 +165,14 @@ object ExpandWhens extends Pass { } else { simlist += Print(s.info, s.string, s.args, s.clk, AND(p, s.en)) } - Empty() + EmptyStmt case s: Stop => if (weq(p, one)) { simlist += s } else { simlist += Stop(s.info, s.ret, s.clk, AND(p, s.en)) } - Empty() + EmptyStmt case s => s map expandWhens(netlist, defaults, p) } } @@ -187,11 +188,11 @@ object ExpandWhens extends Pass { } val modulesx = c.modules map { m => m match { - case m: ExModule => m - case m: InModule => + case m: ExtModule => m + case m: Module => val (netlist, simlist, bodyx) = expandWhens(m) val newBody = Begin(Seq(bodyx map squashEmpty) ++ expandNetlist(netlist) ++ simlist) - InModule(m.info, m.name, m.ports, newBody) + Module(m.info, m.name, m.ports, newBody) } } Circuit(c.info, modulesx, c.main) diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 5e523a37..786de0eb 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -6,6 +6,7 @@ import scala.collection.mutable import firrtl.Mappers.{ExpMap,StmtMap} import firrtl.Utils.WithAs +import firrtl.ir._ // Tags an annotation to be consumed by this pass @@ -47,12 +48,12 @@ object InlineInstances extends Transform { if (!moduleMap.contains(name)) errors += new PassException(s"Annotated module does not exist: ${name}") def checkExternal(name: String): Unit = moduleMap(name) match { - case m: ExModule => errors += new PassException(s"Annotated module cannot be an external module: ${name}") + case m: ExtModule => errors += new PassException(s"Annotated module cannot be an external module: ${name}") case _ => {} } def checkInstance(cn: ComponentName): Unit = { var containsCN = false - def onStmt(name: String)(s: Stmt): Stmt = { + def onStmt(name: String)(s: Statement): Statement = { s match { case WDefInstance(_, inst_name, module_name, tpe) => if (name == inst_name) { @@ -63,7 +64,7 @@ object InlineInstances extends Transform { } s map onStmt(name) } - onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[InModule].body) + onStmt(cn.name)(moduleMap(cn.module.name).asInstanceOf[Module].body) if (!containsCN) errors += new PassException(s"Annotated instance does not exist: ${cn.module.name}.${cn.name}") } annModuleNames.foreach{n => checkExists(n)} @@ -85,12 +86,12 @@ object InlineInstances extends Transform { // ---- Pass functions/data ---- // Contains all unaltered modules - val originalModules = mutable.HashMap[String,Module]() + val originalModules = mutable.HashMap[String,DefModule]() // Contains modules whose direct/indirect children modules have been inlined, and whose tagged instances have been inlined. - val inlinedModules = mutable.HashMap[String,Module]() + val inlinedModules = mutable.HashMap[String,DefModule]() // Recursive. - def onModule(m: Module): Module = { + def onModule(m: DefModule): DefModule = { val inlinedInstances = mutable.ArrayBuffer[String]() // Recursive. Replaces inst.port with inst$port def onExp(e: Expression): Expression = e match { @@ -106,7 +107,7 @@ object InlineInstances extends Transform { case e => e map onExp } // Recursive. Inlines tagged instances - def onStmt(s: Stmt): Stmt = s match { + def onStmt(s: Statement): Statement = s match { case WDefInstance(info, instName, moduleName, instTpe) => { def rename(name:String): String = { val newName = instName + inlineDelim + name @@ -114,7 +115,7 @@ object InlineInstances extends Transform { newName } // Rewrites references in inlined statements from ref to inst$ref - def renameStmt(s: Stmt): Stmt = { + def renameStmt(s: Statement): Statement = { def renameExp(e: Expression): Expression = { e map renameExp match { case WRef(name, tpe, kind, gen) => WRef(rename(name), tpe, kind, gen) @@ -136,10 +137,10 @@ object InlineInstances extends Transform { if (shouldInline) { inlinedInstances += instName val instInModule = instModule match { - case m: ExModule => throw new PassException("Cannot inline external module") - case m: InModule => m + case m: ExtModule => throw new PassException("Cannot inline external module") + case m: Module => m } - val stmts = mutable.ArrayBuffer[Stmt]() + val stmts = mutable.ArrayBuffer[Statement]() for (p <- instInModule.ports) { stmts += DefWire(p.info, rename(p.name), p.tpe) } @@ -150,12 +151,12 @@ object InlineInstances extends Transform { case s => s map onExp map onStmt } m match { - case InModule(info, name, ports, body) => { - val mx = InModule(info, name, ports, onStmt(body)) + case Module(info, name, ports, body) => { + val mx = Module(info, name, ports, onStmt(body)) inlinedModules(name) = mx mx } - case m: ExModule => { + case m: ExtModule => { inlinedModules(m.name) = m m } diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 1dc3f782..b86b0651 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -30,16 +30,17 @@ package firrtl.passes import com.typesafe.scalalogging.LazyLogging import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ // Datastructures import scala.collection.mutable.HashMap -/** Removes all aggregate types from a [[Circuit]] +/** Removes all aggregate types from a [[firrtl.ir.Circuit]] * - * @note Assumes [[firrtl.SubAccess]]es have been removed - * @note Assumes [[firrtl.Connect]]s and [[firrtl.IsInvalid]]s only operate on [[firrtl.Expression]]s of ground type + * @note Assumes [[firrtl.ir.SubAccess]]es have been removed + * @note Assumes [[firrtl.ir.Connect]]s and [[firrtl.ir.IsInvalid]]s only operate on [[firrtl.ir.Expression]]s of ground type * @example * {{{ * wire foo : { a : UInt<32>, b : UInt<16> } @@ -54,8 +55,8 @@ object LowerTypes extends Pass { /** Delimiter used in lowering names */ val delim = "_" - /** Expands a chain of referential [[firrtl.Expression]]s into the equivalent lowered name - * @param e [[firrtl.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]] + /** Expands a chain of referential [[firrtl.ir.Expression]]s into the equivalent lowered name + * @param e [[firrtl.ir.Expression]] made up of _only_ [[firrtl.WRef]], [[firrtl.WSubField]], and [[firrtl.WSubIndex]] * @return Lowered name of e */ def loweredName(e: Expression): String = e match { @@ -88,7 +89,7 @@ object LowerTypes extends Pass { implicit var mname: String = "" implicit var sinfo: Info = NoInfo - def lowerTypes(m: Module): Module = { + def lowerTypes(m: DefModule): DefModule = { val memDataTypeMap = HashMap[String, Type]() // Lowers an expression of MemKind @@ -110,7 +111,7 @@ object LowerTypes extends Pass { val exps = create_exps(mem.name, memType) exps map { e => val loMemName = loweredName(e) - val loMem = WRef(loMemName, UnknownType(), kind(mem), UNKNOWNGENDER) + val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) mergeRef(loMem, mergeRef(port, field)) } } @@ -122,7 +123,7 @@ object LowerTypes extends Pass { case Some(e) => val loMemExp = mergeRef(mem, e) val loMemName = loweredName(loMemExp) - WRef(loMemName, UnknownType(), kind(mem), UNKNOWNGENDER) + WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) case None => mem } Seq(mergeRef(loMem, mergeRef(port, field))) @@ -149,11 +150,11 @@ object LowerTypes extends Pass { } case e: Mux => e map (lowerTypesExp) case e: ValidIf => e map (lowerTypesExp) - case (_: UIntValue | _: SIntValue) => e + case (_: UIntLiteral | _: SIntLiteral) => e case e: DoPrim => e map (lowerTypesExp) } - def lowerTypesStmt(s: Stmt): Stmt = { + def lowerTypesStmt(s: Statement): Statement = { s map lowerTypesStmt match { case s: DefWire => sinfo = s.info @@ -195,14 +196,14 @@ object LowerTypes extends Pass { } case s: DefMemory => sinfo = s.info - memDataTypeMap += (s.name -> s.data_type) - if (s.data_type.isGround) { + memDataTypeMap += (s.name -> s.dataType) + if (s.dataType.isGround) { s } else { - val exps = create_exps(s.name, s.data_type) + val exps = create_exps(s.name, s.dataType) val stmts = exps map { e => DefMemory(s.info, loweredName(e), tpe(e), s.depth, - s.write_latency, s.read_latency, s.readers, s.writers, + s.writeLatency, s.readLatency, s.readers, s.writers, s.readwriters) } Begin(stmts) @@ -224,9 +225,9 @@ object LowerTypes extends Pass { Begin(stmts) case s: IsInvalid => sinfo = s.info - kind(s.exp) match { + kind(s.expr) match { case k: MemKind => - val exps = lowerTypesMemExp(s.exp) + val exps = lowerTypesMemExp(s.expr) Begin(exps map (exp => IsInvalid(s.info, exp))) case _ => s map (lowerTypesExp) } @@ -234,7 +235,7 @@ object LowerTypes extends Pass { sinfo = s.info kind(s.loc) match { case k: MemKind => - val exp = lowerTypesExp(s.exp) + val exp = lowerTypesExp(s.expr) val locs = lowerTypesMemExp(s.loc) Begin(locs map (loc => Connect(s.info, loc, exp))) case _ => s map (lowerTypesExp) @@ -251,8 +252,8 @@ object LowerTypes extends Pass { exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), tpe(e)) ) } m match { - case m: ExModule => m.copy(ports = portsx) - case m: InModule => InModule(m.info, m.name, portsx, lowerTypesStmt(m.body)) + case m: ExtModule => m.copy(ports = portsx) + case m: Module => Module(m.info, m.name, portsx, lowerTypesStmt(m.body)) } } diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 049da53b..0cabc293 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -3,6 +3,8 @@ package passes import firrtl.Mappers.{ExpMap, StmtMap} import firrtl.Utils.{tpe, long_BANG} +import firrtl.PrimOps._ +import firrtl.ir._ // Makes all implicit width extensions and truncations explicit object PadWidths extends Pass { @@ -17,17 +19,15 @@ object PadWidths extends Pass { // default case should never be reached } if (i > width(e)) - DoPrim(PAD_OP, Seq(e), Seq(i), tx) + DoPrim(Pad, Seq(e), Seq(i), tx) else if (i < width(e)) - DoPrim(BITS_SELECT_OP, Seq(e), Seq(i - 1, 0), tx) + DoPrim(Bits, Seq(e), Seq(i - 1, 0), tx) else e } // Recursive, updates expression so children exp's have correct widths private def onExp(e: Expression): Expression = { - val sensitiveOps = Seq( - LESS_OP, LESS_EQ_OP, GREATER_OP, GREATER_EQ_OP, EQUAL_OP, - NEQUAL_OP, NOT_OP, AND_OP, OR_OP, XOR_OP, ADD_OP, SUB_OP, - MUL_OP, DIV_OP, REM_OP, SHIFT_RIGHT_OP) + val sensitiveOps = Seq( Lt, Leq, Gt, Geq, Eq, Neq, Not, And, Or, Xor, + Add, Sub, Mul, Div, Rem, Shr) val x = e map onExp x match { case Mux(cond, tval, fval, tpe) => { @@ -40,15 +40,15 @@ object PadWidths extends Pass { val i = args.map(a => width(a)).foldLeft(0) {(a, b) => math.max(a, b)} x map fixup(i) } - case DYN_SHIFT_LEFT_OP => { + case Dshl => { // special case as args aren't all same width val ax = fixup(width(tpe))(args(0)) - DoPrim(DSHLW_OP, Seq(ax, args(1)), consts, tpe) + DoPrim(Dshlw, Seq(ax, args(1)), consts, tpe) } - case SHIFT_LEFT_OP => { + case Shl => { // special case as arg should be same width as result val ax = fixup(width(tpe))(args(0)) - DoPrim(SHLW_OP, Seq(ax), consts, tpe) + DoPrim(Shlw, Seq(ax), consts, tpe) } case _ => x } @@ -57,10 +57,10 @@ object PadWidths extends Pass { } } // Recursive. Fixes assignments and register initialization widths - private def onStmt(s: Stmt): Stmt = { + private def onStmt(s: Statement): Statement = { s map onExp match { case s: Connect => { - val ex = fixup(width(s.loc))(s.exp) + val ex = fixup(width(s.loc))(s.expr) Connect(s.info, s.loc, ex) } case s: DefRegister => { @@ -70,10 +70,10 @@ object PadWidths extends Pass { case s => s map onStmt } } - private def onModule(m: Module): Module = { + private def onModule(m: DefModule): DefModule = { m match { - case m:InModule => InModule(m.info, m.name, m.ports, onStmt(m.body)) - case m:ExModule => m + case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) + case m: ExtModule => m } } def run(c: Circuit): Circuit = Circuit(c.info, c.modules.map(onModule _), c.main) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index abd758bf..6b88c514 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -36,6 +36,7 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.ArrayBuffer import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ import firrtl.Serialize._ @@ -69,24 +70,24 @@ object ToWorkingIR extends Pass { def run (c:Circuit): Circuit = { def toExp (e:Expression) : Expression = { e map (toExp) match { - case e:Ref => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) - case e:SubField => WSubField(e.exp, e.name, e.tpe, UNKNOWNGENDER) - case e:SubIndex => WSubIndex(e.exp, e.value, e.tpe, UNKNOWNGENDER) - case e:SubAccess => WSubAccess(e.exp, e.index, e.tpe, UNKNOWNGENDER) + case e:Reference => WRef(e.name, e.tpe, NodeKind(), UNKNOWNGENDER) + case e:SubField => WSubField(e.expr, e.name, e.tpe, UNKNOWNGENDER) + case e:SubIndex => WSubIndex(e.expr, e.value, e.tpe, UNKNOWNGENDER) + case e:SubAccess => WSubAccess(e.expr, e.index, e.tpe, UNKNOWNGENDER) case e => e } } - def toStmt (s:Stmt) : Stmt = { + def toStmt (s:Statement) : Statement = { s map (toExp) match { - case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType()) + case s:DefInstance => WDefInstance(s.info,s.name,s.module,UnknownType) case s => s map (toStmt) } } val modulesx = c.modules.map { m => mname = m.name m match { - case m:InModule => InModule(m.info,m.name, m.ports, toStmt(m.body)) - case m:ExModule => m + case m:Module => Module(m.info,m.name, m.ports, toStmt(m.body)) + case m:ExtModule => m } } Circuit(c.info,modulesx,c.main) @@ -97,24 +98,23 @@ object ResolveKinds extends Pass { private var mname = "" def name = "Resolve Kinds" def run (c:Circuit): Circuit = { - def resolve_kinds (m:Module, c:Circuit):Module = { + def resolve_kinds (m:DefModule, c:Circuit):DefModule = { val kinds = LinkedHashMap[String,Kind]() - def resolve (body:Stmt) = { + def resolve (body:Statement) = { def resolve_expr (e:Expression):Expression = { e match { case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) case e => e map (resolve_expr) } } - def resolve_stmt (s:Stmt):Stmt = s map (resolve_stmt) map (resolve_expr) + def resolve_stmt (s:Statement):Statement = s map (resolve_stmt) map (resolve_expr) resolve_stmt(body) } - def find (m:Module) = { - def find_stmt (s:Stmt):Stmt = { + def find (m:DefModule) = { + def find_stmt (s:Statement):Statement = { s match { case s:DefWire => kinds(s.name) = WireKind() - case s:DefPoison => kinds(s.name) = PoisonKind() case s:DefNode => kinds(s.name) = NodeKind() case s:DefRegister => kinds(s.name) = RegKind() case s:WDefInstance => kinds(s.name) = InstanceKind() @@ -125,19 +125,19 @@ object ResolveKinds extends Pass { } m.ports.foreach { p => kinds(p.name) = PortKind() } m match { - case m:InModule => find_stmt(m.body) - case m:ExModule => false + case m:Module => find_stmt(m.body) + case m:ExtModule => false } } mname = m.name find(m) m match { - case m:InModule => { + case m:Module => { val bodyx = resolve(m.body) - InModule(m.info,m.name,m.ports,bodyx) + Module(m.info,m.name,m.ports,bodyx) } - case m:ExModule => ExModule(m.info,m.name,m.ports) + case m:ExtModule => ExtModule(m.info,m.name,m.ports) } } val modulesx = c.modules.map(m => resolve_kinds(m,c)) @@ -148,18 +148,17 @@ object ResolveKinds extends Pass { object InferTypes extends Pass { private var mname = "" def name = "Infer Types" - def set_type (s:Stmt,t:Type) : Stmt = { + def set_type (s:Statement, t:Type) : Statement = { s match { case s:DefWire => DefWire(s.info,s.name,t) case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) - case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.write_latency,s.read_latency,s.readers,s.writers,s.readwriters) + case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.writeLatency,s.readLatency,s.readers,s.writers,s.readwriters) case s:DefNode => s - case s:DefPoison => DefPoison(s.info,s.name,t) } } def remove_unknowns_w (w:Width)(implicit namespace: Namespace):Width = { w match { - case w:UnknownWidth => VarWidth(namespace.newName("w")) + case UnknownWidth => VarWidth(namespace.newName("w")) case w => w } } @@ -167,7 +166,7 @@ object InferTypes extends Pass { def run (c:Circuit): Circuit = { val module_types = LinkedHashMap[String,Type]() implicit val wnamespace = Namespace() - def infer_types (m:Module) : Module = { + def infer_types (m:DefModule) : DefModule = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { e map (infer_types_e) match { @@ -178,11 +177,11 @@ object InferTypes extends Pass { case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender) case e:DoPrim => set_primop_type(e) case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) - case e:UIntValue => e - case e:SIntValue => e + case e:UIntLiteral => e + case e:SIntLiteral => e } } - def infer_types_s (s:Stmt) : Stmt = { + def infer_types_s (s:Statement) : Statement = { s match { case s:DefRegister => { val t = remove_unknowns(get_type(s)) @@ -195,12 +194,6 @@ object InferTypes extends Pass { types(s.name) = t set_type(sx,t) } - case s:DefPoison => { - val sx = s map (infer_types_e) - val t = remove_unknowns(get_type(sx)) - types(s.name) = t - set_type(sx,t) - } case s:DefNode => { val sx = s map (infer_types_e) val t = remove_unknowns(get_type(sx)) @@ -210,7 +203,7 @@ object InferTypes extends Pass { case s:DefMemory => { val t = remove_unknowns(get_type(s)) types(s.name) = t - val dt = remove_unknowns(s.data_type) + val dt = remove_unknowns(s.dataType) set_type(s,dt) } case s:WDefInstance => { @@ -224,8 +217,8 @@ object InferTypes extends Pass { mname = m.name m.ports.foreach(p => types(p.name) = p.tpe) m match { - case m:InModule => InModule(m.info,m.name,m.ports,infer_types_s(m.body)) - case m:ExModule => m + case m:Module => Module(m.info,m.name,m.ports,infer_types_s(m.body)) + case m:ExtModule => m } } @@ -234,8 +227,8 @@ object InferTypes extends Pass { mname = m.name val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe))) m match { - case m:InModule => InModule(m.info,m.name,portsx,m.body) - case m:ExModule => ExModule(m.info,m.name,portsx) + case m:Module => Module(m.info,m.name,portsx,m.body) + case m:ExtModule => ExtModule(m.info,m.name,portsx) } } } @@ -254,8 +247,8 @@ object ResolveGenders extends Pass { case e:WSubField => { val expx = field_flip(tpe(e.exp),e.name) match { - case DEFAULT => resolve_e(g)(e.exp) - case REVERSE => resolve_e(swap(g))(e.exp) + case Default => resolve_e(g)(e.exp) + case Flip => resolve_e(swap(g))(e.exp) } WSubField(expx,e.name,e.tpe,g) } @@ -272,21 +265,21 @@ object ResolveGenders extends Pass { } } - def resolve_s (s:Stmt) : Stmt = { + def resolve_s (s:Statement) : Statement = { s match { case s:IsInvalid => { - val expx = resolve_e(FEMALE)(s.exp) + val expx = resolve_e(FEMALE)(s.expr) IsInvalid(s.info,expx) } case s:Connect => { val locx = resolve_e(FEMALE)(s.loc) - val expx = resolve_e(MALE)(s.exp) + val expx = resolve_e(MALE)(s.expr) Connect(s.info,locx,expx) } - case s:BulkConnect => { + case s:PartialConnect => { val locx = resolve_e(FEMALE)(s.loc) - val expx = resolve_e(MALE)(s.exp) - BulkConnect(s.info,locx,expx) + val expx = resolve_e(MALE)(s.expr) + PartialConnect(s.info,locx,expx) } case s => s map (resolve_e(MALE)) map (resolve_s) } @@ -295,11 +288,11 @@ object ResolveGenders extends Pass { m => { mname = m.name m match { - case m:InModule => { + case m:Module => { val bodyx = resolve_s(m.body) - InModule(m.info,m.name,m.ports,bodyx) + Module(m.info,m.name,m.ports,bodyx) } - case m:ExModule => m + case m:ExtModule => m } } } @@ -479,7 +472,7 @@ object InferWidths extends Pass { (t) match { case (t:UIntType) => t.width case (t:SIntType) => t.width - case (t:ClockType) => IntWidth(1) + case ClockType => IntWidth(1) case (t) => error("No width!"); IntWidth(-1) } } def width_BANG (e:Expression) : Width = width_BANG(tpe(e)) @@ -534,15 +527,15 @@ object InferWidths extends Pass { val portsx = m.ports.map{ p => { Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }} (m) match { - case (m:ExModule) => ExModule(m.info,m.name,portsx) - case (m:InModule) => mname = m.name; InModule(m.info,m.name,portsx,mapr(reduce_var_widths_w _,m.body)) }}} + case (m:ExtModule) => ExtModule(m.info,m.name,portsx) + case (m:Module) => mname = m.name; Module(m.info,m.name,portsx,mapr(reduce_var_widths_w _,m.body)) }}} Circuit(c.info,modulesx,c.main) } def run (c:Circuit): Circuit = { val v = ArrayBuffer[WGeq]() def constrain (w1:Width,w2:Width) : Unit = v += WGeq(w1,w2) - def get_constraints_t (t1:Type,t2:Type,f:Flip) : Unit = { + def get_constraints_t (t1:Type,t2:Type,f:Orientation) : Unit = { (t1,t2) match { case (t1:UIntType,t2:UIntType) => constrain(t1.width,t2.width) case (t1:SIntType,t2:SIntType) => constrain(t1.width,t2.width) @@ -557,32 +550,32 @@ object InferWidths extends Pass { constrain(ONE,width_BANG(e.cond)) e } case (e) => e }} - def get_constraints (s:Stmt) : Stmt = { + def get_constraints (s:Statement) : Statement = { (s map (get_constraints_e)) match { case (s:Connect) => { val n = get_size(tpe(s.loc)) val ce_loc = create_exps(s.loc) - val ce_exp = create_exps(s.exp) + val ce_exp = create_exps(s.expr) for (i <- 0 until n) { val locx = ce_loc(i) val expx = ce_exp(i) - get_flip(tpe(s.loc),i,DEFAULT) match { - case DEFAULT => constrain(width_BANG(locx),width_BANG(expx)) - case REVERSE => constrain(width_BANG(expx),width_BANG(locx)) }} + get_flip(tpe(s.loc),i,Default) match { + case Default => constrain(width_BANG(locx),width_BANG(expx)) + case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} s } - case (s:BulkConnect) => { - val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) + case (s:PartialConnect) => { + val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) for (x <- ls) { val locx = create_exps(s.loc)(x._1) - val expx = create_exps(s.exp)(x._2) - get_flip(tpe(s.loc),x._1,DEFAULT) match { - case DEFAULT => constrain(width_BANG(locx),width_BANG(expx)) - case REVERSE => constrain(width_BANG(expx),width_BANG(locx)) }} + val expx = create_exps(s.expr)(x._2) + get_flip(tpe(s.loc),x._1,Default) match { + case Default => constrain(width_BANG(locx),width_BANG(expx)) + case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} s } case (s:DefRegister) => { constrain(width_BANG(s.reset),ONE) constrain(ONE,width_BANG(s.reset)) - get_constraints_t(s.tpe,tpe(s.init),DEFAULT) + get_constraints_t(s.tpe,tpe(s.init),Default) s } case (s:Conditionally) => { v += WGeq(width_BANG(s.pred),ONE) @@ -592,7 +585,7 @@ object InferWidths extends Pass { for (m <- c.modules) { (m) match { - case (m:InModule) => mname = m.name; get_constraints(m.body) + case (m:Module) => mname = m.name; get_constraints(m.body) case (m) => false }} //println-debug("======== ALL CONSTRAINTS ========") //for x in v do : println-debug(x) @@ -639,13 +632,13 @@ object PullMuxes extends Pass { } ex map (pull_muxes_e) } - def pull_muxes (s:Stmt) : Stmt = s map (pull_muxes) map (pull_muxes_e) + def pull_muxes (s:Statement) : Statement = s map (pull_muxes) map (pull_muxes_e) val modulesx = c.modules.map { m => { mname = m.name m match { - case (m:InModule) => InModule(m.info,m.name,m.ports,pull_muxes(m.body)) - case (m:ExModule) => m + case (m:Module) => Module(m.info,m.name,m.ports,pull_muxes(m.body)) + case (m:ExtModule) => m } } } @@ -657,10 +650,10 @@ object ExpandConnects extends Pass { private var mname = "" def name = "Expand Connects" def run (c:Circuit): Circuit = { - def expand_connects (m:InModule) : InModule = { + def expand_connects (m:Module) : Module = { mname = m.name val genders = LinkedHashMap[String,Gender]() - def expand_s (s:Stmt) : Stmt = { + def expand_s (s:Statement) : Statement = { def set_gender (e:Expression) : Expression = { e map (set_gender) match { case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name)) @@ -679,12 +672,11 @@ object ExpandConnects extends Pass { case (s:DefRegister) => { genders(s.name) = BIGENDER; s } case (s:WDefInstance) => { genders(s.name) = MALE; s } case (s:DefMemory) => { genders(s.name) = MALE; s } - case (s:DefPoison) => { genders(s.name) = MALE; s } case (s:DefNode) => { genders(s.name) = MALE; s } case (s:IsInvalid) => { - val n = get_size(tpe(s.exp)) - val invalids = ArrayBuffer[Stmt]() - val exps = create_exps(s.exp) + val n = get_size(tpe(s.expr)) + val invalids = ArrayBuffer[Statement]() + val exps = create_exps(s.expr) for (i <- 0 until n) { val expx = exps(i) val gexpx = set_gender(expx) @@ -695,38 +687,38 @@ object ExpandConnects extends Pass { } } if (invalids.length == 0) { - Empty() + EmptyStmt } else if (invalids.length == 1) { invalids(0) } else Begin(invalids) } case (s:Connect) => { val n = get_size(tpe(s.loc)) - val connects = ArrayBuffer[Stmt]() + val connects = ArrayBuffer[Statement]() val locs = create_exps(s.loc) - val exps = create_exps(s.exp) + val exps = create_exps(s.expr) for (i <- 0 until n) { val locx = locs(i) val expx = exps(i) - val sx = get_flip(tpe(s.loc),i,DEFAULT) match { - case DEFAULT => Connect(s.info,locx,expx) - case REVERSE => Connect(s.info,expx,locx) + val sx = get_flip(tpe(s.loc),i,Default) match { + case Default => Connect(s.info,locx,expx) + case Flip => Connect(s.info,expx,locx) } connects += sx } Begin(connects) } - case (s:BulkConnect) => { - val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) - val connects = ArrayBuffer[Stmt]() + case (s:PartialConnect) => { + val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val connects = ArrayBuffer[Statement]() val locs = create_exps(s.loc) - val exps = create_exps(s.exp) + val exps = create_exps(s.expr) ls.foreach { x => { val locx = locs(x._1) val expx = exps(x._2) - val sx = get_flip(tpe(s.loc),x._1,DEFAULT) match { - case DEFAULT => Connect(s.info,locx,expx) - case REVERSE => Connect(s.info,expx,locx) + val sx = get_flip(tpe(s.loc),x._1,Default) match { + case Default => Connect(s.info,locx,expx) + case Flip => Connect(s.info,expx,locx) } connects += sx }} @@ -737,14 +729,14 @@ object ExpandConnects extends Pass { } m.ports.foreach { p => genders(p.name) = to_gender(p.direction) } - InModule(m.info,m.name,m.ports,expand_s(m.body)) + Module(m.info,m.name,m.ports,expand_s(m.body)) } val modulesx = c.modules.map { m => { m match { - case (m:ExModule) => m - case (m:InModule) => expand_connects(m) + case (m:ExtModule) => m + case (m:Module) => expand_connects(m) } } } @@ -816,11 +808,11 @@ object RemoveAccesses extends Pass { ret } def run (c:Circuit): Circuit = { - def remove_m (m:InModule) : InModule = { + def remove_m (m:Module) : Module = { val namespace = Namespace(m) mname = m.name - def remove_s (s:Stmt) : Stmt = { - val stmts = ArrayBuffer[Stmt]() + def remove_s (s:Statement) : Statement = { + val stmts = ArrayBuffer[Statement]() def create_temp (e:Expression) : Expression = { val n = namespace.newTemp stmts += DefWire(info(s),n,tpe(e)) @@ -831,8 +823,8 @@ object RemoveAccesses extends Pass { case (e:DoPrim) => e map (remove_e) case (e:Mux) => e map (remove_e) case (e:ValidIf) => e map (remove_e) - case (e:SIntValue) => e - case (e:UIntValue) => e + case (e:SIntLiteral) => e + case (e:UIntLiteral) => e case x => { val e = x match { case (w:WSubAccess) => WSubAccess(w.exp,remove_e(w.index),w.tpe,w.gender) @@ -852,7 +844,7 @@ object RemoveAccesses extends Pass { if (i < temps.size) { stmts += Connect(info(s),get_temp(i),x.base) } else { - stmts += Conditionally(info(s),x.guard,Connect(info(s),get_temp(i),x.base),Empty()) + stmts += Conditionally(info(s),x.guard,Connect(info(s),get_temp(i),x.base),EmptyStmt) } } } @@ -872,25 +864,25 @@ object RemoveAccesses extends Pass { if (ls.size == 1 & weq(ls(0).guard,one)) s.loc else { val temp = create_temp(s.loc) - for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),Empty()) } + for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),EmptyStmt) } temp } - Connect(s.info,locx,remove_e(s.exp)) - } else { Connect(s.info,s.loc,remove_e(s.exp)) } + Connect(s.info,locx,remove_e(s.expr)) + } else { Connect(s.info,s.loc,remove_e(s.expr)) } } case (s) => s map (remove_e) map (remove_s) } stmts += sx if (stmts.size != 1) Begin(stmts) else stmts(0) } - InModule(m.info,m.name,m.ports,remove_s(m.body)) + Module(m.info,m.name,m.ports,remove_s(m.body)) } val modulesx = c.modules.map{ m => { m match { - case (m:ExModule) => m - case (m:InModule) => remove_m(m) + case (m:ExtModule) => m + case (m:Module) => remove_m(m) } } } @@ -903,15 +895,15 @@ object RemoveAccesses extends Pass { object Legalize extends Pass { def name = "Legalize" def legalizeShiftRight (e: DoPrim): Expression = e.op match { - case SHIFT_RIGHT_OP => { + case Shr => { val amount = e.consts(0).toInt val width = long_BANG(tpe(e.args(0))) lazy val msb = width - 1 if (amount >= width) { e.tpe match { - case t: UIntType => UIntValue(0, IntWidth(1)) + case t: UIntType => UIntLiteral(0, IntWidth(1)) case t: SIntType => - DoPrim(BITS_SELECT_OP, e.args, Seq(msb, msb), SIntType(IntWidth(1))) + DoPrim(Bits, e.args, Seq(msb, msb), SIntType(IntWidth(1))) case t => error(s"Unsupported type ${t} for Primop Shift Right") } } else { @@ -920,16 +912,16 @@ object Legalize extends Pass { } case _ => e } - def legalizeConnect(c: Connect): Stmt = { + def legalizeConnect(c: Connect): Statement = { val t = tpe(c.loc) val w = long_BANG(t) - if (w >= long_BANG(tpe(c.exp))) c + if (w >= long_BANG(tpe(c.expr))) c else { val newType = t match { case _: UIntType => UIntType(IntWidth(w)) case _: SIntType => SIntType(IntWidth(w)) } - Connect(c.info, c.loc, DoPrim(BITS_SELECT_OP, Seq(c.exp), Seq(w-1, 0), newType)) + Connect(c.info, c.loc, DoPrim(Bits, Seq(c.expr), Seq(w-1, 0), newType)) } } def run (c: Circuit): Circuit = { @@ -939,14 +931,14 @@ object Legalize extends Pass { case e => e } } - def legalizeS (s: Stmt): Stmt = { + def legalizeS (s: Statement): Statement = { val legalizedStmt = s match { case c: Connect => legalizeConnect(c) case _ => s } legalizedStmt map legalizeS map legalizeE } - def legalizeM (m: Module): Module = m map (legalizeS) + def legalizeM (m: DefModule): DefModule = m map (legalizeS) Circuit(c.info, c.modules.map(legalizeM), c.main) } } @@ -958,11 +950,11 @@ object VerilogWrap extends Pass { e map (v_wrap_e) match { case (e:DoPrim) => { def a0 () = e.args(0) - if (e.op == TAIL_OP) { + if (e.op == Tail) { (a0()) match { case (e0:DoPrim) => { - if (e0.op == ADD_OP) DoPrim(ADDW_OP,e0.args,Seq(),tpe(e)) - else if (e0.op == SUB_OP) DoPrim(SUBW_OP,e0.args,Seq(),tpe(e)) + if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),tpe(e)) + else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),tpe(e)) else e } case (e0) => e @@ -973,7 +965,7 @@ object VerilogWrap extends Pass { case (e) => e } } - def v_wrap_s (s:Stmt) : Stmt = { + def v_wrap_s (s:Statement) : Statement = { s map (v_wrap_s) map (v_wrap_e) match { case s: Print => Print(s.info, VerilogStringLitHandler.format(s.string), s.args, s.clk, s.en) @@ -983,11 +975,11 @@ object VerilogWrap extends Pass { def run (c:Circuit): Circuit = { val modulesx = c.modules.map{ m => { (m) match { - case (m:InModule) => { + case (m:Module) => { mname = m.name - InModule(m.info,m.name,m.ports,v_wrap_s(m.body)) + Module(m.info,m.name,m.ports,v_wrap_s(m.body)) } - case (m:ExModule) => m + case (m:ExtModule) => m } }} Circuit(c.info,modulesx,c.main) @@ -1006,7 +998,7 @@ object VerilogRename extends Pass { case (e) => e map (verilog_rename_e) } } - def verilog_rename_s (s:Stmt) : Stmt = { + def verilog_rename_s (s:Statement) : Statement = { s map (verilog_rename_s) map (verilog_rename_e) map (verilog_rename_n) } val modulesx = c.modules.map{ m => { @@ -1014,8 +1006,8 @@ object VerilogRename extends Pass { Port(p.info,verilog_rename_n(p.name),p.direction,p.tpe) }} m match { - case (m:InModule) => InModule(m.info,m.name,portsx,verilog_rename_s(m.body)) - case (m:ExModule) => m + case (m:Module) => Module(m.info,m.name,portsx,verilog_rename_s(m.body)) + case (m:ExtModule) => m } }} Circuit(c.info,modulesx,c.main) @@ -1025,55 +1017,54 @@ object VerilogRename extends Pass { object CInferTypes extends Pass { def name = "CInfer Types" var mname = "" - def set_type (s:Stmt,t:Type) : Stmt = { + def set_type (s:Statement, t:Type) : Statement = { (s) match { case (s:DefWire) => DefWire(s.info,s.name,t) case (s:DefRegister) => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) case (s:CDefMemory) => CDefMemory(s.info,s.name,t,s.size,s.seq) case (s:CDefMPort) => CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) case (s:DefNode) => s - case (s:DefPoison) => DefPoison(s.info,s.name,t) } } def to_field (p:Port) : Field = { - if (p.direction == OUTPUT) Field(p.name,DEFAULT,p.tpe) - else if (p.direction == INPUT) Field(p.name,REVERSE,p.tpe) - else error("Shouldn't be here"); Field(p.name,REVERSE,p.tpe) + if (p.direction == Output) Field(p.name,Default,p.tpe) + else if (p.direction == Input) Field(p.name,Flip,p.tpe) + else error("Shouldn't be here"); Field(p.name,Flip,p.tpe) } - def module_type (m:Module) : Type = BundleType(m.ports.map(p => to_field(p))) + def module_type (m:DefModule) : Type = BundleType(m.ports.map(p => to_field(p))) def field_type (v:Type,s:String) : Type = { (v) match { case (v:BundleType) => { val ft = v.fields.find(p => p.name == s) if (ft != None) ft.get.tpe - else UnknownType() + else UnknownType } - case (v) => UnknownType() + case (v) => UnknownType } } def sub_type (v:Type) : Type = (v) match { case (v:VectorType) => v.tpe - case (v) => UnknownType() + case (v) => UnknownType } def run (c:Circuit) : Circuit = { val module_types = LinkedHashMap[String,Type]() - def infer_types (m:Module) : Module = { + def infer_types (m:DefModule) : DefModule = { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { (e map (infer_types_e)) match { - case (e:Ref) => Ref(e.name, types.getOrElse(e.name,UnknownType())) - case (e:SubField) => SubField(e.exp,e.name,field_type(tpe(e.exp),e.name)) - case (e:SubIndex) => SubIndex(e.exp,e.value,sub_type(tpe(e.exp))) - case (e:SubAccess) => SubAccess(e.exp,e.index,sub_type(tpe(e.exp))) + case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType)) + case (e:SubField) => SubField(e.expr,e.name,field_type(tpe(e.expr),e.name)) + case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(tpe(e.expr))) + case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(tpe(e.expr))) case (e:DoPrim) => set_primop_type(e) case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value)) - case (_:UIntValue|_:SIntValue) => e + case (_:UIntLiteral | _:SIntLiteral) => e } } - def infer_types_s (s:Stmt) : Stmt = { + def infer_types_s (s:Statement) : Statement = { (s) match { case (s:DefRegister) => { types(s.name) = s.tpe @@ -1084,10 +1075,6 @@ object CInferTypes extends Pass { types(s.name) = s.tpe s } - case (s:DefPoison) => { - types(s.name) = s.tpe - s - } case (s:DefNode) => { val sx = s map (infer_types_e) val t = get_type(sx) @@ -1099,7 +1086,7 @@ object CInferTypes extends Pass { s } case (s:CDefMPort) => { - val t = types.getOrElse(s.mem,UnknownType()) + val t = types.getOrElse(s.mem,UnknownType) types(s.name) = t CDefMPort(s.info,s.name,t,s.mem,s.exps,s.direction) } @@ -1108,7 +1095,7 @@ object CInferTypes extends Pass { s } case (s:DefInstance) => { - types(s.name) = module_types.getOrElse(s.module,UnknownType()) + types(s.name) = module_types.getOrElse(s.module,UnknownType) s } case (s) => s map(infer_types_s) map (infer_types_e) @@ -1118,8 +1105,8 @@ object CInferTypes extends Pass { types(p.name) = p.tpe } (m) match { - case (m:InModule) => InModule(m.info,m.name,m.ports,infer_types_s(m.body)) - case (m:ExModule) => m + case (m:Module) => Module(m.info,m.name,m.ports,infer_types_s(m.body)) + case (m:ExtModule) => m } } @@ -1136,11 +1123,11 @@ object CInferMDir extends Pass { def name = "CInfer MDir" var mname = "" def run (c:Circuit) : Circuit = { - def infer_mdir (m:Module) : Module = { + def infer_mdir (m:DefModule) : DefModule = { val mports = LinkedHashMap[String,MPortDir]() def infer_mdir_e (dir:MPortDir)(e:Expression) : Expression = { (e map (infer_mdir_e(dir))) match { - case (e:Ref) => { + case (e:Reference) => { if (mports.contains(e.name)) { val new_mport_dir = { (mports(e.name),dir) match { @@ -1169,26 +1156,26 @@ object CInferMDir extends Pass { case (e) => e } } - def infer_mdir_s (s:Stmt) : Stmt = { + def infer_mdir_s (s:Statement) : Statement = { (s) match { case (s:CDefMPort) => { mports(s.name) = s.direction s map (infer_mdir_e(MRead)) } case (s:Connect) => { - infer_mdir_e(MRead)(s.exp) + infer_mdir_e(MRead)(s.expr) infer_mdir_e(MWrite)(s.loc) s } - case (s:BulkConnect) => { - infer_mdir_e(MRead)(s.exp) + case (s:PartialConnect) => { + infer_mdir_e(MRead)(s.expr) infer_mdir_e(MWrite)(s.loc) s } case (s) => s map (infer_mdir_s) map (infer_mdir_e(MRead)) } } - def set_mdir_s (s:Stmt) : Stmt = { + def set_mdir_s (s:Statement) : Statement = { (s) match { case (s:CDefMPort) => CDefMPort(s.info,s.name,s.tpe,s.mem,s.exps,mports(s.name)) @@ -1196,11 +1183,11 @@ object CInferMDir extends Pass { } } (m) match { - case (m:InModule) => { + case (m:Module) => { infer_mdir_s(m.body) - InModule(m.info,m.name,m.ports,set_mdir_s(m.body)) + Module(m.info,m.name,m.ports,set_mdir_s(m.body)) } - case (m:ExModule) => m + case (m:ExtModule) => m } } @@ -1227,23 +1214,23 @@ object RemoveCHIRRTL extends Pass { ValidIf(e.cond,e1,tpe(e1)) }) case (e) => (tpe(e)) match { - case (_:UIntType|_:SIntType|_:ClockType) => Seq(e) + case (_:UIntType|_:SIntType|ClockType) => Seq(e) case (t:BundleType) => t.fields.flatMap(f => create_exps(SubField(e,f.name,f.tpe))) case (t:VectorType)=> (0 until t.size).flatMap(i => create_exps(SubIndex(e,i,t.tpe))) - case (t:UnknownType) => Seq(e) + case UnknownType => Seq(e) } } } def run (c:Circuit) : Circuit = { - def remove_chirrtl_m (m:InModule) : InModule = { + def remove_chirrtl_m (m:Module) : Module = { val hash = LinkedHashMap[String,MPorts]() val repl = LinkedHashMap[String,DataRef]() - val ut = UnknownType() + val ut = UnknownType val mport_types = LinkedHashMap[String,Type]() def EMPs () : MPorts = MPorts(ArrayBuffer[MPort](),ArrayBuffer[MPort](),ArrayBuffer[MPort]()) - def collect_mports (s:Stmt) : Stmt = { + def collect_mports (s:Statement) : Statement = { (s) match { case (s:CDefMPort) => { val mports = hash.getOrElse(s.mem,EMPs()) @@ -1258,32 +1245,32 @@ object RemoveCHIRRTL extends Pass { case (s) => s map (collect_mports) } } - def collect_refs (s:Stmt) : Stmt = { + def collect_refs (s:Statement) : Statement = { (s) match { case (s:CDefMemory) => { mport_types(s.name) = s.tpe - val stmts = ArrayBuffer[Stmt]() + val stmts = ArrayBuffer[Statement]() val taddr = UIntType(IntWidth(scala.math.max(1,ceil_log2(s.size)))) val tdata = s.tpe def set_poison (vec:Seq[MPort],addr:String) : Unit = { for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),addr,taddr)) - stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),"clk",taddr)) + stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),addr,taddr)) + stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),"clk",taddr)) } } def set_enable (vec:Seq[MPort],en:String) : Unit = { for (r <- vec ) { - stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),en,taddr),zero) + stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),en,taddr),zero) }} def set_wmode (vec:Seq[MPort],wmode:String) : Unit = { for (r <- vec) { - stmts += Connect(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),wmode,taddr),zero) + stmts += Connect(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),wmode,taddr),zero) }} def set_write (vec:Seq[MPort],data:String,mask:String) : Unit = { val tmask = create_mask(s.tpe) for (r <- vec ) { - stmts += IsInvalid(s.info,SubField(SubField(Ref(s.name,ut),r.name,ut),data,tdata)) - for (x <- create_exps(SubField(SubField(Ref(s.name,ut),r.name,ut),mask,tmask)) ) { + stmts += IsInvalid(s.info,SubField(SubField(Reference(s.name,ut),r.name,ut),data,tdata)) + for (x <- create_exps(SubField(SubField(Reference(s.name,ut),r.name,ut),mask,tmask)) ) { stmts += Connect(s.info,x,zero) }}} val rds = (hash.getOrElse(s.name,EMPs())).readers @@ -1310,47 +1297,47 @@ object RemoveCHIRRTL extends Pass { val masks = ArrayBuffer[String]() s.direction match { case MReadWrite => { - repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"rdata","data","mask",true) + repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"rdata","data","mask",true) addrs += "addr" clks += "clk" ens += "en" masks += "mask" } case MWrite => { - repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","mask",false) + repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","mask",false) addrs += "addr" clks += "clk" ens += "en" masks += "mask" } case _ => { - repl(s.name) = DataRef(SubField(Ref(s.mem,ut),s.name,ut),"data","data","blah",false) + repl(s.name) = DataRef(SubField(Reference(s.mem,ut),s.name,ut),"data","data","blah",false) addrs += "addr" clks += "clk" ens += "en" } } - val stmts = ArrayBuffer[Stmt]() + val stmts = ArrayBuffer[Statement]() for (x <- addrs ) { - stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),s.exps(0)) + stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(0)) } for (x <- clks ) { - stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),s.exps(1)) + stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),s.exps(1)) } for (x <- ens ) { - stmts += Connect(s.info,SubField(SubField(Ref(s.mem,ut),s.name,ut),x,ut),one) + stmts += Connect(s.info,SubField(SubField(Reference(s.mem,ut),s.name,ut),x,ut),one) } Begin(stmts) } case (s) => s map (collect_refs) } } - def remove_chirrtl_s (s:Stmt) : Stmt = { + def remove_chirrtl_s (s:Statement) : Statement = { var has_write_mport = false var has_readwrite_mport:Option[Expression] = None def remove_chirrtl_e (g:Gender)(e:Expression) : Expression = { (e) match { - case (e:Ref) => { + case (e:Reference) => { if (repl.contains(e.name)) { val vt = repl(e.name) g match { @@ -1364,13 +1351,13 @@ object RemoveCHIRRTL extends Pass { } } else e } - case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.exp),remove_chirrtl_e(MALE)(e.index),e.tpe) + case (e:SubAccess) => SubAccess(remove_chirrtl_e(g)(e.expr),remove_chirrtl_e(MALE)(e.index),e.tpe) case (e) => e map (remove_chirrtl_e(g)) } } def get_mask (e:Expression) : Expression = { (e map (get_mask)) match { - case (e:Ref) => { + case (e:Reference) => { if (repl.contains(e.name)) { val vt = repl(e.name) val t = create_mask(e.tpe) @@ -1382,8 +1369,8 @@ object RemoveCHIRRTL extends Pass { } (s) match { case (s:Connect) => { - val stmts = ArrayBuffer[Stmt]() - val rocx = remove_chirrtl_e(MALE)(s.exp) + val stmts = ArrayBuffer[Statement]() + val rocx = remove_chirrtl_e(MALE)(s.expr) val locx = remove_chirrtl_e(FEMALE)(s.loc) stmts += Connect(s.info,locx,rocx) if (has_write_mport) { @@ -1399,13 +1386,13 @@ object RemoveCHIRRTL extends Pass { if (stmts.size > 1) Begin(stmts) else stmts(0) } - case (s:BulkConnect) => { - val stmts = ArrayBuffer[Stmt]() + case (s:PartialConnect) => { + val stmts = ArrayBuffer[Statement]() val locx = remove_chirrtl_e(FEMALE)(s.loc) - val rocx = remove_chirrtl_e(MALE)(s.exp) - stmts += BulkConnect(s.info,locx,rocx) + val rocx = remove_chirrtl_e(MALE)(s.expr) + stmts += PartialConnect(s.info,locx,rocx) if (has_write_mport != false) { - val ls = get_valid_points(tpe(s.loc),tpe(s.exp),DEFAULT,DEFAULT) + val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) val locs = create_exps(get_mask(s.loc)) for (x <- ls ) { val locx = locs(x._1) @@ -1424,12 +1411,12 @@ object RemoveCHIRRTL extends Pass { } collect_mports(m.body) val sx = collect_refs(m.body) - InModule(m.info,m.name, m.ports, remove_chirrtl_s(sx)) + Module(m.info,m.name, m.ports, remove_chirrtl_s(sx)) } val modulesx = c.modules.map{ m => { (m) match { - case (m:InModule) => remove_chirrtl_m(m) - case (m:ExModule) => m + case (m:Module) => remove_chirrtl_m(m) + case (m:ExtModule) => m }}} Circuit(c.info,modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/RemoveValidIf.scala b/src/main/scala/firrtl/passes/RemoveValidIf.scala index 4bc6162a..a534cc50 100644 --- a/src/main/scala/firrtl/passes/RemoveValidIf.scala +++ b/src/main/scala/firrtl/passes/RemoveValidIf.scala @@ -1,6 +1,7 @@ package firrtl package passes import firrtl.Mappers.{ExpMap, StmtMap} +import firrtl.ir._ // Removes ValidIf as an optimization object RemoveValidIf extends Pass { @@ -13,12 +14,12 @@ object RemoveValidIf extends Pass { } } // Recursive. - private def onStmt(s: Stmt): Stmt = s map onStmt map onExp + private def onStmt(s: Statement): Statement = s map onStmt map onExp - private def onModule(m: Module): Module = { + private def onModule(m: DefModule): DefModule = { m match { - case m:InModule => InModule(m.info, m.name, m.ports, onStmt(m.body)) - case m:ExModule => m + case m: Module => Module(m.info, m.name, m.ports, onStmt(m.body)) + case m: ExtModule => m } } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index a66f7152..973e1be9 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -3,6 +3,7 @@ package passes import firrtl.Mappers.{ExpMap, StmtMap} import firrtl.Utils.{tpe, kind, gender, info} +import firrtl.ir._ import scala.collection.mutable @@ -10,10 +11,10 @@ import scala.collection.mutable // and named intermediate nodes object SplitExpressions extends Pass { def name = "Split Expressions" - private def onModule(m: InModule): InModule = { + private def onModule(m: Module): Module = { val namespace = Namespace(m) - def onStmt(s: Stmt): Stmt = { - val v = mutable.ArrayBuffer[Stmt]() + def onStmt(s: Statement): Statement = { + val v = mutable.ArrayBuffer[Statement]() // Splits current expression if needed // Adds named temporaries to v def split(e: Expression): Expression = e match { @@ -45,7 +46,7 @@ object SplitExpressions extends Pass { val x = s map onExp x match { case x: Begin => x map onStmt - case x: Empty => x + case EmptyStmt => x case x => { v += x if (v.size > 1) Begin(v.toVector) @@ -53,12 +54,12 @@ object SplitExpressions extends Pass { } } } - InModule(m.info, m.name, m.ports, onStmt(m.body)) + Module(m.info, m.name, m.ports, onStmt(m.body)) } def run(c: Circuit): Circuit = { val modulesx = c.modules.map( _ match { - case (m:InModule) => onModule(m) - case (m:ExModule) => m + case m: Module => onModule(m) + case m: ExtModule => m }) Circuit(c.info, modulesx, c.main) } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index 6cec0f1d..aa2c1d5d 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -31,12 +31,13 @@ import com.typesafe.scalalogging.LazyLogging import scala.annotation.tailrec import firrtl._ +import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ /** Resolve name collisions that would occur in [[LowerTypes]] * - * @note Must be run after [[InferTypes]] because [[DefNode]]s need type + * @note Must be run after [[InferTypes]] because [[ir.DefNode]]s need type * @example * {{{ * wire a = { b, c }[2] @@ -192,7 +193,7 @@ object Uniquify extends Pass { val (subExp, subMap) = rec(e.exp, m) val index = uniquifyNamesExp(e.index, map) (WSubAccess(subExp, index, e.tpe, e.gender), subMap) - case (_: UIntValue | _: SIntValue) => (exp, m) + case (_: UIntLiteral | _: SIntLiteral) => (exp, m) case (_: Mux | _: ValidIf | _: DoPrim) => (exp map ((e: Expression) => uniquifyNamesExp(e, map)), m) } @@ -220,28 +221,28 @@ object Uniquify extends Pass { } // Creates a Bundle Type from a Stmt - def stmtToType(s: Stmt)(implicit sinfo: Info, mname: String): BundleType = { + def stmtToType(s: Statement)(implicit sinfo: Info, mname: String): BundleType = { // Recursive helper - def recStmtToType(s: Stmt): Seq[Field] = s match { - case s: DefWire => Seq(Field(s.name, DEFAULT, s.tpe)) - case s: DefRegister => Seq(Field(s.name, DEFAULT, s.tpe)) - case s: WDefInstance => Seq(Field(s.name, DEFAULT, s.tpe)) - case s: DefMemory => s.data_type match { + def recStmtToType(s: Statement): Seq[Field] = s match { + case s: DefWire => Seq(Field(s.name, Default, s.tpe)) + case s: DefRegister => Seq(Field(s.name, Default, s.tpe)) + case s: WDefInstance => Seq(Field(s.name, Default, s.tpe)) + case s: DefMemory => s.dataType match { case (_: UIntType | _: SIntType) => - Seq(Field(s.name, DEFAULT, get_type(s))) + Seq(Field(s.name, Default, get_type(s))) case tpe: BundleType => val newFields = tpe.fields map ( f => - DefMemory(s.info, f.name, f.tpe, s.depth, s.write_latency, - s.read_latency, s.readers, s.writers, s.readwriters) + DefMemory(s.info, f.name, f.tpe, s.depth, s.writeLatency, + s.readLatency, s.readers, s.writers, s.readwriters) ) flatMap (recStmtToType) - Seq(Field(s.name, DEFAULT, BundleType(newFields))) + Seq(Field(s.name, Default, BundleType(newFields))) case tpe: VectorType => val newFields = (0 until tpe.size) map ( i => - s.copy(name = i.toString, data_type = tpe.tpe) + s.copy(name = i.toString, dataType = tpe.tpe) ) flatMap (recStmtToType) - Seq(Field(s.name, DEFAULT, BundleType(newFields))) + Seq(Field(s.name, Default, BundleType(newFields))) } - case s: DefNode => Seq(Field(s.name, DEFAULT, get_type(s))) + case s: DefNode => Seq(Field(s.name, Default, get_type(s))) case s: Conditionally => recStmtToType(s.conseq) ++ recStmtToType(s.alt) case s: Begin => (s.stmts map (recStmtToType)).flatten case s => Seq() @@ -258,7 +259,7 @@ object Uniquify extends Pass { val portNameMap = collection.mutable.HashMap[String, Map[String, NameMapNode]]() val portTypeMap = collection.mutable.HashMap[String, Type]() - def uniquifyModule(m: Module): Module = { + def uniquifyModule(m: DefModule): DefModule = { val namespace = collection.mutable.HashSet[String]() val nameMap = collection.mutable.HashMap[String, NameMapNode]() @@ -267,11 +268,11 @@ object Uniquify extends Pass { uniquifyNamesExp(e, nameMap.toMap) case e: Mux => e map (uniquifyExp) case e: ValidIf => e map (uniquifyExp) - case (_: UIntValue | _: SIntValue) => e + case (_: UIntLiteral | _: SIntLiteral) => e case e: DoPrim => e map (uniquifyExp) } - def uniquifyStmt(s: Stmt): Stmt = { + def uniquifyStmt(s: Statement): Statement = { s map uniquifyStmt map uniquifyExp match { case s: DefWire => sinfo = s.info @@ -302,8 +303,8 @@ object Uniquify extends Pass { sinfo = s.info if (nameMap.contains(s.name)) { val node = nameMap(s.name) - val dataType = uniquifyNamesType(s.data_type, node.elts) - val mem = s.copy(name = node.name, data_type = dataType) + val dataType = uniquifyNamesType(s.dataType, node.elts) + val mem = s.copy(name = node.name, dataType = dataType) // Create new mapping to handle references to memory data fields val uniqueMemMap = createNameMapping(get_type(s), get_type(mem)) nameMap(s.name) = NameMapNode(node.name, node.elts ++ uniqueMemMap) @@ -323,7 +324,7 @@ object Uniquify extends Pass { } } - def uniquifyBody(s: Stmt): Stmt = { + def uniquifyBody(s: Statement): Statement = { val bodyType = stmtToType(s) val uniqueBodyType = uniquifyNames(bodyType, namespace) val localMap = createNameMapping(bodyType, uniqueBodyType) @@ -336,8 +337,8 @@ object Uniquify extends Pass { sinfo = m.info mname = m.name m match { - case m: ExModule => m - case m: InModule => + case m: ExtModule => m + case m: Module => // Adds port names to namespace and namemap nameMap ++= portNameMap(m.name) namespace ++= create_exps("", portTypeMap(m.name)) map @@ -346,7 +347,7 @@ object Uniquify extends Pass { } } - def uniquifyPorts(m: Module): Module = { + def uniquifyPorts(m: DefModule): DefModule = { def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { val portsType = BundleType(ports map (_.toField)) val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet()) @@ -362,8 +363,8 @@ object Uniquify extends Pass { sinfo = m.info mname = m.name m match { - case m: ExModule => m.copy(ports = uniquifyPorts(m.ports)) - case m: InModule => m.copy(ports = uniquifyPorts(m.ports)) + case m: ExtModule => m.copy(ports = uniquifyPorts(m.ports)) + case m: Module => m.copy(ports = uniquifyPorts(m.ports)) } } diff --git a/src/test/scala/firrtlTests/AnnotationTests.scala b/src/test/scala/firrtlTests/AnnotationTests.scala index 81a74b54..e04b4e14 100644 --- a/src/test/scala/firrtlTests/AnnotationTests.scala +++ b/src/test/scala/firrtlTests/AnnotationTests.scala @@ -6,8 +6,9 @@ import org.scalatest.FlatSpec import org.scalatest.Matchers import org.scalatest.junit.JUnitRunner -import firrtl.{Parser,Circuit} +import firrtl.ir.Circuit import firrtl.{ + Parser, Named, ModuleName, ComponentName, diff --git a/src/test/scala/firrtlTests/CheckInitializationSpec.scala b/src/test/scala/firrtlTests/CheckInitializationSpec.scala index 49d5bc08..515bbfc8 100644 --- a/src/test/scala/firrtlTests/CheckInitializationSpec.scala +++ b/src/test/scala/firrtlTests/CheckInitializationSpec.scala @@ -30,7 +30,8 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl._ +import firrtl.Parser +import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo import firrtl.passes._ diff --git a/src/test/scala/firrtlTests/CheckSpec.scala b/src/test/scala/firrtlTests/CheckSpec.scala index 5c1b1a67..69645ddc 100644 --- a/src/test/scala/firrtlTests/CheckSpec.scala +++ b/src/test/scala/firrtlTests/CheckSpec.scala @@ -3,7 +3,8 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.{Parser,Circuit} +import firrtl.Parser +import firrtl.ir.Circuit import firrtl.passes.{Pass,ToWorkingIR,CheckHighForm,ResolveKinds,InferTypes,CheckTypes,PassExceptions} class CheckSpec extends FlatSpec with Matchers { diff --git a/src/test/scala/firrtlTests/ChirrtlSpec.scala b/src/test/scala/firrtlTests/ChirrtlSpec.scala index 858d43b6..d3e02ff1 100644 --- a/src/test/scala/firrtlTests/ChirrtlSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlSpec.scala @@ -30,7 +30,8 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.{Parser,Circuit} +import firrtl.Parser +import firrtl.ir.Circuit import firrtl.passes._ class ChirrtlSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/CompilerTests.scala b/src/test/scala/firrtlTests/CompilerTests.scala index f66b39e6..ce70a992 100644 --- a/src/test/scala/firrtlTests/CompilerTests.scala +++ b/src/test/scala/firrtlTests/CompilerTests.scala @@ -6,12 +6,13 @@ import org.scalatest.FlatSpec import org.scalatest.Matchers import org.scalatest.junit.JUnitRunner -import firrtl.{Parser,Circuit} +import firrtl.ir.Circuit import firrtl.{ HighFirrtlCompiler, LowFirrtlCompiler, VerilogCompiler, - Compiler + Compiler, + Parser } /** diff --git a/src/test/scala/firrtlTests/ConstantPropagationTests.scala b/src/test/scala/firrtlTests/ConstantPropagationTests.scala index cfcb7f45..bfe58a2c 100644 --- a/src/test/scala/firrtlTests/ConstantPropagationTests.scala +++ b/src/test/scala/firrtlTests/ConstantPropagationTests.scala @@ -2,8 +2,9 @@ package firrtlTests import org.scalatest.Matchers import java.io.{StringWriter,Writer} -import firrtl._ +import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo +import firrtl.Parser import firrtl.passes._ // Tests the following cases for constant propagation: diff --git a/src/test/scala/firrtlTests/InlineInstancesTests.scala b/src/test/scala/firrtlTests/InlineInstancesTests.scala index 52c01dc4..4a9f21bc 100644 --- a/src/test/scala/firrtlTests/InlineInstancesTests.scala +++ b/src/test/scala/firrtlTests/InlineInstancesTests.scala @@ -6,9 +6,10 @@ import org.scalatest.FlatSpec import org.scalatest.Matchers import org.scalatest.junit.JUnitRunner -import firrtl.{Parser,Circuit} +import firrtl.ir.Circuit import firrtl.passes.{PassExceptions,InlineCAKind} import firrtl.{ + Parser, Named, ModuleName, ComponentName, diff --git a/src/test/scala/firrtlTests/LowerTypesSpec.scala b/src/test/scala/firrtlTests/LowerTypesSpec.scala index 736849f5..e9096139 100644 --- a/src/test/scala/firrtlTests/LowerTypesSpec.scala +++ b/src/test/scala/firrtlTests/LowerTypesSpec.scala @@ -4,7 +4,8 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.{Parser,Circuit} +import firrtl.Parser +import firrtl.ir.Circuit import firrtl.passes._ class LowerTypesSpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/PassTests.scala b/src/test/scala/firrtlTests/PassTests.scala index 38ecc7c3..efe7438c 100644 --- a/src/test/scala/firrtlTests/PassTests.scala +++ b/src/test/scala/firrtlTests/PassTests.scala @@ -4,7 +4,8 @@ import com.typesafe.scalalogging.LazyLogging import java.io.{StringWriter,Writer} import org.scalatest.{FlatSpec, Matchers} import org.scalatest.junit.JUnitRunner -import firrtl.{Parser,Circuit,FIRRTLEmitter} +import firrtl.{Parser,FIRRTLEmitter} +import firrtl.ir.Circuit import firrtl.Parser.IgnoreInfo import firrtl.passes.{Pass, PassExceptions} import firrtl.{ diff --git a/src/test/scala/firrtlTests/UniquifySpec.scala b/src/test/scala/firrtlTests/UniquifySpec.scala index 71f41074..6cc83c9c 100644 --- a/src/test/scala/firrtlTests/UniquifySpec.scala +++ b/src/test/scala/firrtlTests/UniquifySpec.scala @@ -30,7 +30,8 @@ package firrtlTests import java.io._ import org.scalatest._ import org.scalatest.prop._ -import firrtl.{Parser, Circuit} +import firrtl.Parser +import firrtl.ir.Circuit import firrtl.passes._ class UniquifySpec extends FirrtlFlatSpec { diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 98693c61..ead55755 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -31,6 +31,7 @@ import java.io._ import org.scalatest._ import org.scalatest.prop._ import firrtl._ +import firrtl.ir.Circuit import firrtl.passes._ import firrtl.Parser.IgnoreInfo |
