diff options
| author | azidar | 2016-01-30 01:02:48 -0800 |
|---|---|---|
| committer | azidar | 2016-02-09 18:55:26 -0800 |
| commit | f6917276250258091e98a51719b35cf5935ceabf (patch) | |
| tree | b9b3db517d4c69563c4adfca8f07a21e88c3d5d6 /src | |
| parent | 0181686fe4bdf24f9e22f406c43dbeb98789cb8b (diff) | |
WIP. Finished to working ir, resolve kinds, and infer types
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/IR.scala | 72 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Passes.scala | 246 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 422 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 954 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 6 |
7 files changed, 1140 insertions, 573 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 2e525f7e..7ea0286d 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -52,7 +52,7 @@ object DriverPasses { } def optimize(passes: Seq[DriverPass]): Seq[DriverPass] = { - aggregateStanzaPasses(passes) + aggregateStanzaPasses(aggregateStanzaPasses(passes)) } } @@ -84,10 +84,11 @@ object Driver extends LazyLogging { // ===================================== StanzaPass("high-form-check"), // ===================================== - ScalaPass(toWorkingIr), + ScalaPass(resolve), // ===================================== StanzaPass("to-working-ir"), // ===================================== + StanzaPass("resolve-kinds"), StanzaPass("infer-types"), StanzaPass("check-types"), StanzaPass("resolve-genders"), diff --git a/src/main/scala/firrtl/IR.scala b/src/main/scala/firrtl/IR.scala index 6b93c763..c86be37a 100644 --- a/src/main/scala/firrtl/IR.scala +++ b/src/main/scala/firrtl/IR.scala @@ -19,38 +19,38 @@ case class FIRRTLException(str:String) extends Exception trait AST trait PrimOp extends AST -case object AddOp extends PrimOp -case object SubOp extends PrimOp -case object MulOp extends PrimOp -case object DivOp extends PrimOp -case object RemOp extends PrimOp -case object LessOp extends PrimOp -case object LessEqOp extends PrimOp -case object GreaterOp extends PrimOp -case object GreaterEqOp extends PrimOp -case object EqualOp extends PrimOp -case object NEqualOp extends PrimOp -case object PadOp extends PrimOp -case object AsUIntOp extends PrimOp -case object AsSIntOp extends PrimOp -case object AsClockOp extends PrimOp -case object ShiftLeftOp extends PrimOp -case object ShiftRightOp extends PrimOp -case object DynShiftLeftOp extends PrimOp -case object DynShiftRightOp extends PrimOp -case object ConvertOp extends PrimOp -case object NegOp extends PrimOp -case object BitNotOp extends PrimOp -case object BitAndOp extends PrimOp -case object BitOrOp extends PrimOp -case object BitXorOp extends PrimOp -case object BitAndReduceOp extends PrimOp -case object BitOrReduceOp extends PrimOp -case object BitXorReduceOp extends PrimOp -case object ConcatOp extends PrimOp -case object BitsSelectOp extends PrimOp -case object HeadOp extends PrimOp -case object TailOp extends PrimOp +case object ADD_OP extends PrimOp +case object SUB_OP extends PrimOp +case object MUL_OP extends PrimOp +case object DIV_OP extends PrimOp +case object REM_OP extends PrimOp +case object LESS_OP extends PrimOp +case object LESS_EQ_OP extends PrimOp +case object GREATER_OP extends PrimOp +case object GREATER_EQ_OP extends PrimOp +case object EQUAL_OP extends PrimOp +case object NEQUAL_OP extends PrimOp +case object PAD_OP extends PrimOp +case object AS_UINT_OP extends PrimOp +case object AS_SINT_OP extends PrimOp +case object AS_CLOCK_OP extends PrimOp +case object SHIFT_LEFT_OP extends PrimOp +case object SHIFT_RIGHT_OP extends PrimOp +case object DYN_SHIFT_LEFT_OP extends PrimOp +case object DYN_SHIFT_RIGHT_OP extends PrimOp +case object CONVERT_OP extends PrimOp +case object NEG_OP extends PrimOp +case object NOT_OP extends PrimOp +case object AND_OP extends PrimOp +case object OR_OP extends PrimOp +case object XOR_OP extends PrimOp +case object AND_REDUCE_OP extends PrimOp +case object OR_REDUCE_OP extends PrimOp +case object XOR_REDUCE_OP extends PrimOp +case object CONCAT_OP extends PrimOp +case object BITS_SELECT_OP extends PrimOp +case object HEAD_OP extends PrimOp +case object TAIL_OP extends PrimOp trait Expression extends AST case class Ref(name: String, tpe: Type) extends Expression @@ -68,8 +68,8 @@ case class DefWire(info: Info, name: String, tpe: Type) extends Stmt case class DefPoison(info: Info, name: String, tpe: Type) extends Stmt case class DefRegister(info: Info, name: String, tpe: Type, clock: Expression, reset: Expression, init: Expression) extends Stmt case class DefInstance(info: Info, name: String, module: String) extends Stmt -case class DefMemory(info: Info, name: String, dataType: Type, depth: Int, writeLatency: Int, - readLatency: Int, readers: Seq[String], writers: Seq[String], readwriters: Seq[String]) extends Stmt +case class DefMemory(info: Info, name: String, data_type: Type, depth: Int, write_latency: Int, + read_latency: Int, readers: Seq[String], writers: Seq[String], readwriters: Seq[String]) extends Stmt case class DefNode(info: Info, name: String, value: Expression) extends Stmt case class Conditionally(info: Info, pred: Expression, conseq: Stmt, alt: Stmt) extends Stmt case class Begin(stmts: Seq[Stmt]) extends Stmt @@ -82,7 +82,7 @@ case class Empty() extends Stmt trait Width extends AST case class IntWidth(width: BigInt) extends Width -case object UnknownWidth extends Width +case class UnknownWidth() extends Width trait Flip extends AST case object Default extends Flip @@ -102,7 +102,7 @@ trait Direction extends AST case object Input extends Direction case object Output extends Direction -case class Port(info: Info, name: String, dir: Direction, tpe: Type) extends AST +case class Port(info: Info, name: String, direction: Direction, tpe: Type) extends AST trait Module extends AST { val info : Info diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala index e28205f9..29b42d54 100644 --- a/src/main/scala/firrtl/Passes.scala +++ b/src/main/scala/firrtl/Passes.scala @@ -2,6 +2,7 @@ package firrtl import com.typesafe.scalalogging.LazyLogging +import scala.collection.mutable.HashMap import Utils._ import DebugUtils._ @@ -9,29 +10,50 @@ import PrimOps._ object Passes extends LazyLogging { - // TODO Perhaps we should get rid of Logger since this map would be nice - ////private val defaultLogger = Logger() - //private def mapNameToPass = Map[String, Circuit => Circuit] ( - // "infer-types" -> inferTypes - //) - def nameToPass(name: String): Circuit => Circuit = { - //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) - name match { - case "to-working-ir" => toWorkingIr - //case "infer-types" => inferTypes - // errrrrrrrrrr... - //case "renameall" => renameall(Map()) - } - } - - private def toField(p: Port): Field = { - logger.debug(s"toField called on port ${p.serialize}") - p.dir match { - case Input => Field(p.name, Reverse, p.tpe) - case Output => Field(p.name, Default, p.tpe) - } - } - + // TODO Perhaps we should get rid of Logger since this map would be nice + ////private val defaultLogger = Logger() + //private def mapNameToPass = Map[String, Circuit => Circuit] ( + // "infer-types" -> inferTypes + //) + def nameToPass(name: String): Circuit => Circuit = { + //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) + name match { + case "to-working-ir" => toWorkingIr + //case "infer-types" => inferTypes + // errrrrrrrrrr... + //case "renameall" => renameall(Map()) + } + } + + private def toField(p: Port): Field = { + logger.debug(s"toField called on port ${p.serialize}") + p.direction match { + case Input => Field(p.name, Reverse, p.tpe) + case Output => Field(p.name, Default, p.tpe) + } + } + // ============== RESOLVE ALL =================== + def resolve (c:Circuit) = { + val passes = Seq( + toWorkingIr _, + resolveKinds _, + inferTypes _) + val names = Seq( + "To Working IR", + "Resolve Kinds", + "Infer Types") + var c_BANG = c + (names, passes).zipped.foreach { + (n,p) => { + println("Starting " + n) + c_BANG = p(c_BANG) + println("Finished " + n) + } + } + c_BANG + } + + // ============== TO WORKING IR ================== def toWorkingIr (c:Circuit) = { def toExp (e:Expression) : Expression = { @@ -55,10 +77,186 @@ object Passes extends LazyLogging { case m:ExModule => m } } - Circuit(c.info,modulesx,c.main) + println("Before To Working IR") + println(c.serialize()) + val x = Circuit(c.info,modulesx,c.main) + println("After To Working IR") + println(x.serialize()) + x + } + + // =============================================== + + // ============== RESOLVE KINDS ================== + def resolveKinds (c:Circuit) = { + def resolve_kinds (m:Module, c:Circuit):Module = { + val kinds = HashMap[String,Kind]() + def resolve (body:Stmt) = { + def resolve_expr (e:Expression):Expression = { + e match { + case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) + case e => eMap(resolve_expr,e) + } + } + def resolve_stmt (s:Stmt):Stmt = eMap(resolve_expr,sMap(resolve_stmt,s)) + resolve_stmt(body) + } + + def find (m:Module) = { + def find_stmt (s:Stmt):Stmt = { + s match { + case s:DefWire => kinds += (s.name -> WireKind()) + case s:DefPoison => kinds += (s.name -> PoisonKind()) + case s:DefNode => kinds += (s.name -> NodeKind()) + case s:DefRegister => kinds += (s.name -> RegKind()) + case s:WDefInstance => kinds += (s.name -> InstanceKind()) + case s:DefMemory => kinds += (s.name -> MemKind(s.readers ++ s.writers ++ s.readwriters)) + case s => false + } + sMap(find_stmt,s) + } + m.ports.foreach { p => kinds += (p.name -> PortKind()) } + println(kinds) + m match { + case m:InModule => find_stmt(m.body) + case m:ExModule => false + } + } + + find(m) + m match { + case m:InModule => { + val bodyx = resolve(m.body) + InModule(m.info,m.name,m.ports,bodyx) + } + case m:ExModule => ExModule(m.info,m.name,m.ports) + } + } + val modulesx = c.modules.map(m => resolve_kinds(m,c)) + println("Before Resolve Kinds") + println(c.serialize()) + val x = Circuit(c.info,modulesx,c.main) + println("After Resolve Kinds") + println(x.serialize()) + x } // =============================================== + // ============== INFER TYPES ================== + + // ------------------ Utils ------------------------- + + val width_name_hash = Map[String,Int]() + def set_type (s:Stmt,t:Type) : Stmt = { + s match { + case s:DefWire => DefWire(s.info,s.name,t) + case s:DefRegister => DefRegister(s.info,s.name,t,s.clock,s.reset,s.init) + case s:DefMemory => DefMemory(s.info,s.name,t,s.depth,s.write_latency,s.read_latency,s.readers,s.writers,s.readwriters) + case s:DefNode => s + case s:DefPoison => DefPoison(s.info,s.name,t) + } + } + def remove_unknowns_w (w:Width):Width = { + w match { + case w:UnknownWidth => VarWidth(firrtl_gensym("w",width_name_hash)) + case w => w + } + } + def remove_unknowns (t:Type): Type = mapr(remove_unknowns_w _,t) + def mapr (f: Width => Width, t:Type) : Type = { + def apply_t (t:Type) : Type = { + wMap(f,tMap(apply_t _,t)) + } + apply_t(t) + } + + + + // ------------------ Pass ------------------------- + + def inferTypes (c:Circuit) : Circuit = { + val module_types = HashMap[String,Type]() + def infer_types (m:Module) : Module = { + val types = HashMap[String,Type]() + def infer_types_e (e:Expression) : Expression = { + eMap(infer_types_e _,e) match { + case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) + case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) + case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender) + case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(tpe(e.exp)),e.gender) + case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender) + case e:DoPrim => set_primop_type(e) + case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) + case e:UIntValue => e + case e:SIntValue => e + } + } + def infer_types_s (s:Stmt) : Stmt = { + s match { + case s:DefRegister => { + val t = remove_unknowns(get_type(s)) + types += (s.name -> t) + eMap(infer_types_e _,set_type(s,t)) + } + case s:DefWire => { + val sx = eMap(infer_types_e _,s) + val t = remove_unknowns(get_type(sx)) + types += (s.name -> t) + set_type(sx,t) + } + case s:DefPoison => { + val sx = eMap(infer_types_e _,s) + val t = remove_unknowns(get_type(sx)) + types += (s.name -> t) + set_type(sx,t) + } + case s:DefNode => { + val sx = eMap(infer_types_e _,s) + val t = remove_unknowns(get_type(sx)) + types += (s.name -> t) + set_type(sx,t) + } + case s:DefMemory => { + val t = remove_unknowns(get_type(s)) + types += (s.name -> t) + val dt = remove_unknowns(s.data_type) + set_type(s,dt) + } + case s:WDefInstance => { + types += (s.name -> module_types(s.module)) + WDefInstance(s.info,s.name,s.module,module_types(s.module)) + } + case s => eMap(infer_types_e _,sMap(infer_types_s,s)) + } + } + + m.ports.foreach(p => types += (p.name -> p.tpe)) + m match { + case m:InModule => InModule(m.info,m.name,m.ports,infer_types_s(m.body)) + case m:ExModule => m + } + } + + + // MAIN + val modulesx = c.modules.map { + m => { + val portsx = m.ports.map(p => Port(p.info,p.name,p.direction,remove_unknowns(p.tpe))) + m match { + case m:InModule => InModule(m.info,m.name,portsx,m.body) + case m:ExModule => ExModule(m.info,m.name,portsx) + } + } + } + + modulesx.foreach(m => module_types += (m.name -> module_type(m))) + println("Before Infer Types") + println(c.serialize()) + val x = Circuit(c.info,modulesx.map(m => infer_types(m)) , c.main ) + println("After Infer Types") + println(x.serialize()) + x + } /** INFER TYPES * diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 4732e756..56d1053c 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -9,38 +9,38 @@ import DebugUtils._ object PrimOps extends LazyLogging { private val mapPrimOp2String = Map[PrimOp, String]( - AddOp -> "add", - SubOp -> "sub", - MulOp -> "mul", - DivOp -> "div", - RemOp -> "rem", - LessOp -> "lt", - LessEqOp -> "leq", - GreaterOp -> "gt", - GreaterEqOp -> "geq", - EqualOp -> "eq", - NEqualOp -> "neq", - PadOp -> "pad", - AsUIntOp -> "asUInt", - AsSIntOp -> "asSInt", - AsClockOp -> "asClock", - ShiftLeftOp -> "shl", - ShiftRightOp -> "shr", - DynShiftLeftOp -> "dshl", - DynShiftRightOp -> "dshr", - ConvertOp -> "cvt", - NegOp -> "neg", - BitNotOp -> "not", - BitAndOp -> "and", - BitOrOp -> "or", - BitXorOp -> "xor", - BitAndReduceOp -> "andr", - BitOrReduceOp -> "orr", - BitXorReduceOp -> "xorr", - ConcatOp -> "cat", - BitsSelectOp -> "bits", - HeadOp -> "head", - TailOp -> "tail" + ADD_OP -> "add", + SUB_OP -> "sub", + MUL_OP -> "mul", + DIV_OP -> "div", + REM_OP -> "rem", + LESS_OP -> "lt", + LESS_EQ_OP -> "leq", + GREATER_OP -> "gt", + GREATER_EQ_OP -> "geq", + EQUAL_OP -> "eq", + NEQUAL_OP -> "neq", + PAD_OP -> "pad", + AS_UINT_OP -> "asUInt", + AS_SINT_OP -> "asSInt", + AS_CLOCK_OP -> "asClock", + SHIFT_LEFT_OP -> "shl", + SHIFT_RIGHT_OP -> "shr", + DYN_SHIFT_LEFT_OP -> "dshl", + DYN_SHIFT_RIGHT_OP -> "dshr", + NEG_OP -> "neg", + CONVERT_OP -> "cvt", + NOT_OP -> "not", + AND_OP -> "and", + OR_OP -> "or", + XOR_OP -> "xor", + AND_REDUCE_OP -> "andr", + OR_REDUCE_OP -> "orr", + XOR_REDUCE_OP -> "xorr", + CONCAT_OP -> "cat", + BITS_SELECT_OP -> "bits", + HEAD_OP -> "head", + TAIL_OP -> "tail" ) private val mapString2PrimOp = mapPrimOp2String.map(_.swap) def fromString(op: String): PrimOp = mapString2PrimOp(op) @@ -50,67 +50,303 @@ object PrimOps extends LazyLogging { } // Borrowed from Stanza implementation - def lowerAndTypePrimOp(e: DoPrim): DoPrim = { - def uAnd(op1: Expression, op2: Expression): Type = { - (op1.getType, op2.getType) match { - case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) - case (t1: SIntType, t2) => SIntType(UnknownWidth) - case (t1, t2: SIntType) => SIntType(UnknownWidth) - case _ => UnknownType() - } - } - def ofType(op: Expression): Type = { - op.getType match { - case t: UIntType => UIntType(UnknownWidth) - case t: SIntType => SIntType(UnknownWidth) - case _ => UnknownType() - } - } + def set_primop_type (e:DoPrim) : DoPrim = { + //println-all(["Inferencing primop type: " e]) + def PLUS (w1:Width,w2:Width) : Width = PlusWidth(w1,w2) + def MAX (w1:Width,w2:Width) : Width = MaxWidth(Seq(w1,w2)) + def MINUS (w1:Width,w2:Width) : Width = MinusWidth(w1,w2) + def POW (w1:Width) : Width = ExpWidth(w1) + def MIN (w1:Width,w2:Width) : Width = MinWidth(Seq(w1,w2)) + val o = e.op + val a = e.args + val c = e.consts + def t1 () = tpe(a(0)) + def t2 () = tpe(a(1)) + def t3 () = tpe(a(2)) + def w1 () = widthBANG(tpe(a(0))) + def w2 () = widthBANG(tpe(a(1))) + def w3 () = widthBANG(tpe(a(2))) + def c1 () = IntWidth(c(0)) + def c2 () = IntWidth(c(1)) + o match { + case ADD_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case SUB_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),ONE)) + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case MUL_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(PLUS(w1(),w2())) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(w1(),w2())) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case DIV_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(w1()) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(w1()) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),ONE)) + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case REM_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(MIN(w1(),w2())) + case (t1:UIntType, t2:SIntType) => UIntType(MIN(w1(),w2())) + case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),ONE))) + case (t1:SIntType, t2:SIntType) => SIntType(MIN(w1(),w2())) + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case LESS_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => BoolType() + case (t1:SIntType, t2:UIntType) => BoolType() + case (t1:UIntType, t2:SIntType) => BoolType() + case (t1:SIntType, t2:SIntType) => BoolType() + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case LESS_EQ_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => BoolType() + case (t1:SIntType, t2:UIntType) => BoolType() + case (t1:UIntType, t2:SIntType) => BoolType() + case (t1:SIntType, t2:SIntType) => BoolType() + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case GREATER_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => BoolType() + case (t1:SIntType, t2:UIntType) => BoolType() + case (t1:UIntType, t2:SIntType) => BoolType() + case (t1:SIntType, t2:SIntType) => BoolType() + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case GREATER_EQ_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => BoolType() + case (t1:SIntType, t2:UIntType) => BoolType() + case (t1:UIntType, t2:SIntType) => BoolType() + case (t1:SIntType, t2:SIntType) => BoolType() + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case EQUAL_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => BoolType() + case (t1:SIntType, t2:UIntType) => BoolType() + case (t1:UIntType, t2:SIntType) => BoolType() + case (t1:SIntType, t2:SIntType) => BoolType() + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case NEQUAL_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => BoolType() + case (t1:SIntType, t2:UIntType) => BoolType() + case (t1:UIntType, t2:SIntType) => BoolType() + case (t1:SIntType, t2:SIntType) => BoolType() + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case PAD_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(MAX(w1(),c1())) + case (t1:SIntType) => SIntType(MAX(w1(),c1())) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case AS_UINT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(w1()) + case (t1:SIntType) => UIntType(w1()) + case (t1:ClockType) => UIntType(ONE) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case AS_SINT_OP => { + val t = (t1()) match { + case (t1:UIntType) => SIntType(w1()) + case (t1:SIntType) => SIntType(w1()) + case (t1:ClockType) => SIntType(ONE) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case AS_CLOCK_OP => { + val t = (t1()) match { + case (t1:UIntType) => ClockType() + case (t1:SIntType) => ClockType() + case (t1:ClockType) => ClockType() + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case SHIFT_LEFT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(PLUS(w1(),c1())) + case (t1:SIntType) => SIntType(PLUS(w1(),c1())) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case SHIFT_RIGHT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(MINUS(w1(),c1())) + case (t1:SIntType) => SIntType(MINUS(w1(),c1())) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case DYN_SHIFT_LEFT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(PLUS(w1(),POW(w2()))) + case (t1:SIntType) => SIntType(PLUS(w1(),POW(w2()))) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case DYN_SHIFT_RIGHT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(w1()) + case (t1:SIntType) => SIntType(w1()) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case CONVERT_OP => { + val t = (t1()) match { + case (t1:UIntType) => SIntType(PLUS(w1(),ONE)) + case (t1:SIntType) => SIntType(w1()) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case NEG_OP => { + val t = (t1()) match { + case (t1:UIntType) => SIntType(PLUS(w1(),ONE)) + case (t1:SIntType) => SIntType(PLUS(w1(),ONE)) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case NOT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(w1()) + case (t1:SIntType) => UIntType(w1()) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case AND_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:SIntType) => UIntType(MAX(w1(),w2())) + case (t1,t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case OR_OP => { + val t = (t1(),t2()) match { + case (t1:SIntType, t2:SIntType) => UIntType(MAX(w1(),w2())) + case (t1,t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case XOR_OP => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(MAX(w1(),w2())) + case (t1:SIntType, t2:UIntType) => UIntType(MAX(w1(),w2())) + case (t1:UIntType, t2:SIntType) => UIntType(MAX(w1(),w2())) + case (t1:SIntType, t2:SIntType) => UIntType(MAX(w1(),w2())) + case (t1,t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case AND_REDUCE_OP => { + val t = (t1()) match { + case (t1:UIntType) => BoolType() + case (t1:SIntType) => BoolType() + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case OR_REDUCE_OP => { + val t = (t1()) match { + case (t1:UIntType) => BoolType() + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case XOR_REDUCE_OP => { + val t = (t1()) match { + case (t1:SIntType) => BoolType() + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case CONCAT_OP => { + val t = (t1(),t2()) match { + case (t1:SIntType, t2:SIntType) => UIntType(PLUS(w1(),w2())) + case (t1, t2) => UnknownType() + } + DoPrim(o,a,c,t) + } + case BITS_SELECT_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(PLUS(MINUS(c1(),c2()),ONE)) + case (t1:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),ONE)) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case HEAD_OP => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(c1()) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } + case TAIL_OP => { + val t = (t1()) match { + case (t1:SIntType) => UIntType(MINUS(w1(),c1())) + case (t1) => UnknownType() + } + DoPrim(o,a,c,t) + } - logger.debug(s"lowerAndTypePrimOp on ${e.op.getClass.getSimpleName}") - // TODO fix this - val tpe = UIntType(UnknownWidth) - //val tpe = e.op match { - // case Add => uAnd(e.args(0), e.args(1)) - // case Sub => SIntType(UnknownWidth) - // case Addw => uAnd(e.args(0), e.args(1)) - // case Subw => uAnd(e.args(0), e.args(1)) - // case Mul => uAnd(e.args(0), e.args(1)) - // case Div => uAnd(e.args(0), e.args(1)) - // case Mod => ofType(e.args(0)) - // case Quo => uAnd(e.args(0), e.args(1)) - // case Rem => ofType(e.args(1)) - // case Lt => UIntType(UnknownWidth) - // case Leq => UIntType(UnknownWidth) - // case Gt => UIntType(UnknownWidth) - // case Geq => UIntType(UnknownWidth) - // case Eq => UIntType(UnknownWidth) - // case Neq => UIntType(UnknownWidth) - // case Eqv => UIntType(UnknownWidth) - // case Neqv => UIntType(UnknownWidth) - // case Mux => ofType(e.args(1)) - // case Pad => ofType(e.args(0)) - // case AsUInt => UIntType(UnknownWidth) - // case AsSInt => SIntType(UnknownWidth) - // case Shl => ofType(e.args(0)) - // case Shr => ofType(e.args(0)) - // case Dshl => ofType(e.args(0)) - // case Dshr => ofType(e.args(0)) - // case Cvt => SIntType(UnknownWidth) - // case Neg => SIntType(UnknownWidth) - // case Not => ofType(e.args(0)) - // case And => ofType(e.args(0)) - // case Or => ofType(e.args(0)) - // case Xor => ofType(e.args(0)) - // case Andr => UIntType(UnknownWidth) - // case Orr => UIntType(UnknownWidth) - // case Xorr => UIntType(UnknownWidth) - // case Cat => UIntType(UnknownWidth) - // case Bit => UIntType(UnknownWidth) - // case Bits => UIntType(UnknownWidth) - // case _ => ??? - //} - DoPrim(e.op, e.args, e.consts, tpe) - } + } + } } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 647fb9c2..f029d410 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -18,434 +18,560 @@ import PrimOps._ object Utils { - // Is there a more elegant way to do this? - private type FlagMap = Map[String, Boolean] - private val FlagMap = Map[String, Boolean]().withDefaultValue(false) - - val lnOf2 = scala.math.log(2) // natural log of 2 - def ceil_log2(x: BigInt): BigInt = (x-1).bitLength - - def create_mask (dt:Type) : Type = { - dt match { - case t:VectorType => VectorType(create_mask(t.tpe),t.size) - case t:BundleType => { - val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } - BundleType(fieldss) - } - case t:UIntType => BoolType() - case t:SIntType => BoolType() - } - } - - def error(str:String) = throw new FIRRTLException(str) - def debug(node: AST)(implicit flags: FlagMap): String = { - if (!flags.isEmpty) { - var str = "" - if (flags("types")) { - val tpe = node.getType - tpe match { - case t:UnknownType => str += s"@<t:${tpe.wipeWidth.serialize}>" - } + // Is there a more elegant way to do this? + private type FlagMap = Map[String, Boolean] + private val FlagMap = Map[String, Boolean]().withDefaultValue(false) + + val lnOf2 = scala.math.log(2) // natural log of 2 + def ceil_log2(x: BigInt): BigInt = (x-1).bitLength + val gen_names = Map[String,Int]() + val delin = "_" + def firrtl_gensym (s:String):String = { + firrtl_gensym(s,Map[String,Int]()) + } + def firrtl_gensym (sym_hash:Map[String,Int]):String = { + firrtl_gensym("gen",sym_hash) + } + def firrtl_gensym (s:String,sym_hash:Map[String,Int]):String = { + if (sym_hash contains s) { + val num = sym_hash(s) + 1 + sym_hash + (s -> num) + (s + delin + num) + } else { + sym_hash + (s -> 0) + (s + delin + 0) } - str - } - else { - "" - } - } - - implicit class BigIntUtils(bi: BigInt){ - def serialize(implicit flags: FlagMap = FlagMap): String = - "\"h" + bi.toString(16) + "\"" - } - - implicit class ASTUtils(ast: AST) { - def getType(): Type = - ast match { - case e: Expression => e.getType - case s: Stmt => s.getType - //case f: Field => f.getType - case t: Type => t.getType - case p: Port => p.getType - case _ => UnknownType() + } + + def create_mask (dt:Type) : Type = { + dt match { + case t:VectorType => VectorType(create_mask(t.tpe),t.size) + case t:BundleType => { + val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } + BundleType(fieldss) + } + case t:UIntType => BoolType() + case t:SIntType => BoolType() } - } - - implicit class PrimOpUtils(op: PrimOp) { - def serialize(implicit flags: FlagMap = FlagMap): String = op.getString - } - - -// ACCESSORS ========= - def gender (e:Expression) : Gender = { - e match { - case e:WRef => gender(e) - case e:WSubField => gender(e) - case e:WSubIndex => gender(e) - case e:WSubAccess => gender(e) - case e:PrimOp => MALE - case e:UIntValue => MALE - case e:SIntValue => MALE - case e:Mux => MALE - case e:ValidIf => MALE - case _ => error("Shouldn't be here") - }} - def get_gender (s:Stmt) : Gender = - s match { - case s:DefWire => BIGENDER - case s:DefRegister => BIGENDER - case s:WDefInstance => MALE - case s:DefNode => MALE - case s:DefInstance => MALE - case s:DefPoison => UNKNOWNGENDER - case s:DefMemory => MALE - case s:Begin => UNKNOWNGENDER - case s:Connect => UNKNOWNGENDER - case s:BulkConnect => UNKNOWNGENDER - case s:Stop => UNKNOWNGENDER - case s:Print => UNKNOWNGENDER - case s:Empty => UNKNOWNGENDER - case s:IsInvalid => UNKNOWNGENDER - } - def get_gender (p:Port) : Gender = - if (p.dir == Input) MALE else FEMALE - def kind (e:Expression) : Kind = - e match { - case e:WRef => e.kind - case e:WSubField => kind(e.exp) - case e:WSubIndex => kind(e.exp) - case e => ExpKind() - } - def tpe (e:Expression) : Type = - e match { - case e:WRef => e.tpe - case e:WSubField => e.tpe - case e:WSubIndex => e.tpe - case e:UIntValue => UIntType(e.width) - case e:SIntValue => SIntType(e.width) - case e:WVoid => UnknownType() - case e:WInvalid => UnknownType() - } - def get_type (s:Stmt) : Type = { - s match { - case s:DefWire => s.tpe - case s:DefPoison => s.tpe - case s:DefRegister => s.tpe - case s:DefNode => tpe(s.value) - case s:DefMemory => { - val depth = s.depth - val addr = Field("addr",Default,UIntType(IntWidth(ceil_log2(depth)))) - val en = Field("en",Default,BoolType()) - val clk = Field("clk",Default,ClockType()) - val def_data = Field("data",Default,s.dataType) - val rev_data = Field("data",Reverse,s.dataType) - val mask = Field("mask",Default,create_mask(s.dataType)) - val wmode = Field("wmode",Default,UIntType(IntWidth(1))) - val rdata = Field("rdata",Reverse,s.dataType) - val read_type = BundleType(Seq(rev_data,addr,en,clk)) - val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) - val readwrite_type = BundleType(Seq(wmode,rdata,def_data,mask,addr,en,clk)) - - val mem_fields = Vector() - s.readers.foreach {x => mem_fields :+ Field(x,Reverse,read_type)} - s.writers.foreach {x => mem_fields :+ Field(x,Reverse,write_type)} - s.readwriters.foreach {x => mem_fields :+ Field(x,Reverse,readwrite_type)} - BundleType(mem_fields) + } + +//============== TYPES ================ + def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2)) + def mux_type_and_widths (t1:Type,t2:Type) : Type = { + def wmax (w1:Width,w2:Width) : Width = { + (w1,w2) match { + case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width.max(w2.width)) + case (w1,w2) => MaxWidth(Seq(w1,w2)) + } + } + if (equals(t1,t2)) { + (t1,t2) match { + case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width)) + case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width)) + case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size) + case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))}) + } + } else UnknownType() + } + def module_type (m:Module) : Type = { + BundleType(m.ports.map(p => p.toField)) + } + def sub_type (v:Type) : Type = { + v match { + case v:VectorType => v.tpe + case v => UnknownType() } - case s:DefInstance => UnknownType() - case _ => UnknownType() - }} - - def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt = - stmt match { - case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) - case b: Begin => Begin(b.stmts.map(f)) - case s: Stmt => s + } + def field_type (v:Type,s:String) : Type = { + v match { + case v:BundleType => { + val ft = v.fields.find(p => p.name == s) + ft match { + case ft:Some[Field] => ft.get.tpe + case ft => UnknownType() + } + } + case v => UnknownType() + } + } + +//===================================== + def widthBANG (t:Type) : Width = { + t match { + case t:UIntType => t.width + case t:SIntType => t.width + case t:ClockType => IntWidth(1) + case t => error("No width!") + } + } +// ================================= + def error(str:String) = throw new FIRRTLException(str) + def debug(node: AST)(implicit flags: FlagMap): String = { + if (!flags.isEmpty) { + var str = "" + if (flags("types")) { + val tpe = node.getType + tpe match { + case t:UnknownType => str += s"@<t:${tpe.wipeWidth.serialize}>" + } + } + str } - def eMap(f:Expression => Expression, stmt:Stmt) : Stmt = - stmt match { - case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) - case n: DefNode => DefNode(n.info, n.name, f(n.value)) - case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) - case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) - case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) - case i: IsInvalid => IsInvalid(i.info, f(i.exp)) - case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) - case s: Stmt => s + else { + "" } - def eMap(f: Expression => Expression, exp:Expression): Expression = - exp match { - case s: SubField => SubField(f(s.exp), s.name, s.tpe) - case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) - case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) - case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) - case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) - case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) - case s: WSubField => WSubField(f(s.exp), s.name, s.tpe, s.gender) - case s: WSubIndex => WSubIndex(f(s.exp), s.value, s.tpe, s.gender) - case s: WSubAccess => WSubAccess(f(s.exp), f(s.index), s.tpe, s.gender) - case e: Expression => e + } + + implicit class BigIntUtils(bi: BigInt){ + def serialize(implicit flags: FlagMap = FlagMap): String = + "\"h" + bi.toString(16) + "\"" + } + + implicit class ASTUtils(ast: AST) { + def getType(): Type = + ast match { + case e: Expression => e.getType + case s: Stmt => s.getType + //case f: Field => f.getType + case t: Type => t.getType + case p: Port => p.getType + case _ => UnknownType() + } + } + + implicit class PrimOpUtils(op: PrimOp) { + def serialize(implicit flags: FlagMap = FlagMap): String = op.getString + } + + +// ACCESSORS ========= + def gender (e:Expression) : Gender = { + e match { + case e:WRef => gender(e) + case e:WSubField => gender(e) + case e:WSubIndex => gender(e) + case e:WSubAccess => gender(e) + case e:PrimOp => MALE + case e:UIntValue => MALE + case e:SIntValue => MALE + case e:Mux => MALE + case e:ValidIf => MALE + case _ => error("Shouldn't be here") + }} + def get_gender (s:Stmt) : Gender = + s match { + case s:DefWire => BIGENDER + case s:DefRegister => BIGENDER + case s:WDefInstance => MALE + case s:DefNode => MALE + case s:DefInstance => MALE + case s:DefPoison => UNKNOWNGENDER + case s:DefMemory => MALE + case s:Begin => UNKNOWNGENDER + case s:Connect => UNKNOWNGENDER + case s:BulkConnect => UNKNOWNGENDER + case s:Stop => UNKNOWNGENDER + case s:Print => UNKNOWNGENDER + case s:Empty => UNKNOWNGENDER + case s:IsInvalid => UNKNOWNGENDER } - //private trait StmtMagnet { - // def map(stmt: Stmt): Stmt - //} - //private object StmtMagnet { - // implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { - // override def map(stmt: Stmt): Stmt = - // stmt match { - // case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) - // case b: Begin => Begin(b.stmts.map(f)) - // case s: Stmt => s - // } - // } - // implicit def forExp(f: Expression => Expression) = new StmtMagnet { - // override def map(stmt: Stmt): Stmt = - // stmt match { - // case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) - // case n: DefNode => DefNode(n.info, n.name, f(n.value)) - // case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) - // case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) - // case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) - // case i: IsInvalid => IsInvalid(i.info, f(i.exp)) - // case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) - // case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) - // case s: Stmt => s - // } - // } - //} - implicit class ExpUtils(exp: Expression) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - val ret = exp match { - case v: UIntValue => s"UInt${v.width.serialize}(${v.value.serialize})" - case v: SIntValue => s"SInt${v.width.serialize}(${v.value.serialize})" - case r: Ref => r.name - case s: SubField => s"${s.exp.serialize}.${s.name}" - case s: SubIndex => s"${s.exp.serialize}[${s.value}]" - case s: SubAccess => s"${s.exp.serialize}[${s.index.serialize}]" - case m: Mux => s"mux(${m.cond.serialize}, ${m.tval.serialize}, ${m.fval.serialize})" - case v: ValidIf => s"validif(${v.cond.serialize}, ${v.value.serialize})" - case p: DoPrim => - s"${p.op.serialize}(" + (p.args.map(_.serialize) ++ p.consts.map(_.toString)).mkString(", ") + ")" - case r: WRef => r.name - case s: WSubField => s"${s.exp.serialize}.${s.name}" - case s: WSubIndex => s"${s.exp.serialize}[${s.value}]" - case s: WSubAccess => s"${s.exp.serialize}[${s.index.serialize}]" - } - ret + debug(exp) - } - } - - // def map(f: Expression => Expression): Expression = - // exp match { - // case s: SubField => SubField(f(s.exp), s.name, s.tpe) - // case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) - // case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) - // case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) - // case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) - // case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) - // case s: WSubField => SubField(f(s.exp), s.name, s.tpe, s.gender) - // case s: WSubIndex => SubIndex(f(s.exp), s.value, s.tpe, s.gender) - // case s: WSubAccess => SubAccess(f(s.exp), f(s.index), s.tpe, s.gender) - // case e: Expression => e - // } - //} - - implicit class StmtUtils(stmt: Stmt) { - def serialize(implicit flags: FlagMap = FlagMap): String = - { - var ret = stmt match { - case w: DefWire => s"wire ${w.name} : ${w.tpe.serialize}" - case r: DefRegister => - val str = new StringBuilder(s"reg ${r.name} : ${r.tpe.serialize}, ${r.clock.serialize} with : ") - withIndent { - str ++= newline + s"reset => (${r.reset.serialize}, ${r.init.serialize})" - } - str - case i: DefInstance => s"inst ${i.name} of ${i.module}" - case i: WDefInstance => s"inst ${i.name} of ${i.module}" - case m: DefMemory => { - val str = new StringBuilder(s"mem ${m.name} : " + newline) - withIndent { - str ++= s"data-type => ${m.dataType.serialize}" + newline + - s"depth => ${m.depth}" + newline + - s"read-latency => ${m.readLatency}" + newline + - s"write-latency => ${m.writeLatency}" + newline + - (if (m.readers.nonEmpty) m.readers.map(r => s"reader => ${r}").mkString(newline) + newline - else "") + - (if (m.writers.nonEmpty) m.writers.map(w => s"writer => ${w}").mkString(newline) + newline - else "") + - (if (m.readwriters.nonEmpty) m.readwriters.map(rw => s"readwriter => ${rw}").mkString(newline) + newline - else "") + - s"read-under-write => undefined" - } - str.result - } - case n: DefNode => s"node ${n.name} = ${n.value.serialize}" - case c: Connect => s"${c.loc.serialize} <= ${c.exp.serialize}" - case b: BulkConnect => s"${b.loc.serialize} <- ${b.exp.serialize}" - case w: Conditionally => { - var str = new StringBuilder(s"when ${w.pred.serialize} : ") - withIndent { str ++= w.conseq.serialize } - w.alt match { - case s:Empty => str.result - case s => { - str ++= newline + "else :" - withIndent { str ++= w.alt.serialize } - str.result - } - } - } - case b: Begin => { - val s = new StringBuilder - b.stmts.foreach { s ++= newline ++ _.serialize } - s.result + debug(b) - } - case i: IsInvalid => s"${i.exp.serialize} is invalid" - case s: Stop => s"stop(${s.clk.serialize}, ${s.en.serialize}, ${s.ret})" - case p: Print => s"printf(${p.clk.serialize}, ${p.en.serialize}, ${p.string}" + - (if (p.args.nonEmpty) p.args.map(_.serialize).mkString(", ", ", ", "") else "") + ")" - case s:Empty => "skip" - } - ret + debug(stmt) - } - - // Using implicit types to allow overloading of function type to map, see StmtMagnet above - //def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) - - def getType(): Type = + def get_gender (p:Port) : Gender = + if (p.direction == Input) MALE else FEMALE + def kind (e:Expression) : Kind = + e match { + case e:WRef => e.kind + case e:WSubField => kind(e.exp) + case e:WSubIndex => kind(e.exp) + case e => ExpKind() + } + def tpe (e:Expression) : Type = + e match { + case e:WRef => e.tpe + case e:WSubField => e.tpe + case e:WSubIndex => e.tpe + case e:WSubAccess => e.tpe + case e:DoPrim => e.tpe + case e:Mux => e.tpe + case e:ValidIf => e.tpe + case e:UIntValue => UIntType(e.width) + case e:SIntValue => SIntType(e.width) + case e:WVoid => UnknownType() + case e:WInvalid => UnknownType() + } + def get_type (s:Stmt) : Type = { + s match { + case s:DefWire => s.tpe + case s:DefPoison => s.tpe + case s:DefRegister => s.tpe + case s:DefNode => tpe(s.value) + case s:DefMemory => { + val depth = s.depth + val addr = Field("addr",Default,UIntType(IntWidth(ceil_log2(depth)))) + val en = Field("en",Default,BoolType()) + val clk = Field("clk",Default,ClockType()) + val def_data = Field("data",Default,s.data_type) + val rev_data = Field("data",Reverse,s.data_type) + val mask = Field("mask",Default,create_mask(s.data_type)) + val wmode = Field("wmode",Default,UIntType(IntWidth(1))) + val rdata = Field("rdata",Reverse,s.data_type) + val read_type = BundleType(Seq(rev_data,addr,en,clk)) + val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) + val readwrite_type = BundleType(Seq(wmode,rdata,def_data,mask,addr,en,clk)) + + val mem_fields = Vector() + s.readers.foreach {x => mem_fields :+ Field(x,Reverse,read_type)} + s.writers.foreach {x => mem_fields :+ Field(x,Reverse,write_type)} + s.readwriters.foreach {x => mem_fields :+ Field(x,Reverse,readwrite_type)} + BundleType(mem_fields) + } + case s:DefInstance => UnknownType() + case _ => UnknownType() + }} + +// =============== MAPPERS =================== + def sMap(f:Stmt => Stmt, stmt: Stmt): Stmt = stmt match { - case s: DefWire => s.tpe - case s: DefRegister => s.tpe - case s: DefMemory => s.dataType - case _ => UnknownType() + case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) + case b: Begin => Begin(b.stmts.map(f)) + case s: Stmt => s } - } - - implicit class WidthUtils(w: Width) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - val s = w match { - case UnknownWidth => "" //"?" - case w: IntWidth => s"<${w.width.toString}>" - } - s + debug(w) - } - } - - implicit class FlipUtils(f: Flip) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - val s = f match { - case Reverse => "flip " - case Default => "" - } - s + debug(f) - } - def flip(): Flip = { - f match { - case Reverse => Default - case Default => Reverse + def eMap(f:Expression => Expression, stmt:Stmt) : Stmt = + stmt match { + case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) + case n: DefNode => DefNode(n.info, n.name, f(n.value)) + case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) + case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) + case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) + case i: IsInvalid => IsInvalid(i.info, f(i.exp)) + case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + case s: Stmt => s } - } - - def toDirection(): Direction = { - f match { - case Default => Output - case Reverse => Input + def eMap(f: Expression => Expression, exp:Expression): Expression = + exp match { + case s: SubField => SubField(f(s.exp), s.name, s.tpe) + case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) + case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) + case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) + case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) + case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) + case s: WSubField => WSubField(f(s.exp), s.name, s.tpe, s.gender) + case s: WSubIndex => WSubIndex(f(s.exp), s.value, s.tpe, s.gender) + case s: WSubAccess => WSubAccess(f(s.exp), f(s.index), s.tpe, s.gender) + case e: Expression => e } - } - } - - implicit class FieldUtils(field: Field) { - def serialize(implicit flags: FlagMap = FlagMap): String = - s"${field.flip.serialize}${field.name} : ${field.tpe.serialize}" + debug(field) - def flip(): Field = Field(field.name, field.flip.flip, field.tpe) - - def getType(): Type = field.tpe - def toPort(): Port = Port(NoInfo, field.name, field.flip.toDirection, field.tpe) - } - - implicit class TypeUtils(t: Type) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - val commas = ", " // for mkString in BundleType - val s = t match { - case c:ClockType => "Clock" - //case UnknownType => "UnknownType" - case u:UnknownType => "?" - case t: UIntType => s"UInt${t.width.serialize}" - case t: SIntType => s"SInt${t.width.serialize}" - case t: BundleType => s"{ ${t.fields.map(_.serialize).mkString(commas)}}" - case t: VectorType => s"${t.tpe.serialize}[${t.size}]" - } - s + debug(t) - } - - def getType(): Type = + def tMap (f: Type => Type, t:Type):Type = { t match { - case v: VectorType => v.tpe - case tpe: Type => UnknownType() + case t:BundleType => BundleType(t.fields.map(p => Field(p.name, p.flip, f(p.tpe)))) + case t:VectorType => VectorType(f(t.tpe), t.size) + case t => t } - - def wipeWidth(): Type = - t match { - case t: UIntType => UIntType(UnknownWidth) - case t: SIntType => SIntType(UnknownWidth) - case _ => t + } + def tMap (f: Type => Type, c:Expression) : Expression = { + c match { + case c:DoPrim => DoPrim(c.op,c.args,c.consts,f(c.tpe)) + case c:Mux => Mux(c.cond,c.tval,c.fval,f(c.tpe)) + case c:ValidIf => ValidIf(c.cond,c.value,f(c.tpe)) + case c:WRef => WRef(c.name,f(c.tpe),c.kind,c.gender) + case c:WSubField => WSubField(c.exp,c.name,f(c.tpe),c.gender) + case c:WSubIndex => WSubIndex(c.exp,c.value,f(c.tpe),c.gender) + case c:WSubAccess => WSubAccess(c.exp,c.index,f(c.tpe),c.gender) + case c => c } - } - - implicit class DirectionUtils(d: Direction) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - val s = d match { - case Input => "input" - case Output => "output" - } - s + debug(d) - } - def toFlip(): Flip = { - d match { - case Input => Reverse - case Output => Default + } + def tMap (f: Type => Type, c:Stmt) : Stmt = { + c match { + case c:DefPoison => DefPoison(c.info,c.name,f(c.tpe)) + case c:DefWire => DefWire(c.info,c.name,f(c.tpe)) + case c:DefRegister => DefRegister(c.info,c.name,f(c.tpe),c.clock,c.reset,c.init) + case c:DefMemory => DefMemory(c.info,c.name, f(c.data_type), c.depth, c.write_latency, c.read_latency, c.readers, c.writers, c.readwriters) + case c => c } - } - } - - implicit class PortUtils(p: Port) { - def serialize(implicit flags: FlagMap = FlagMap): String = - s"${p.dir.serialize} ${p.name} : ${p.tpe.serialize}" + debug(p) - def getType(): Type = p.tpe - def toField(): Field = Field(p.name, p.dir.toFlip, p.tpe) - } - - implicit class ModuleUtils(m: Module) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - m match { - case m:InModule => { - var s = new StringBuilder(s"module ${m.name} : ") - withIndent { - s ++= m.ports.map(newline ++ _.serialize).mkString - s ++= m.body.serialize - } - s ++= debug(m) - s.toString - } + } + def wMap (f: Width => Width, c:Expression) : Expression = { + c match { + case c:UIntValue => UIntValue(c.value,f(c.width)) + case c:SIntValue => SIntValue(c.value,f(c.width)) + case c => c + } + } + def wMap (f: Width => Width, c:Type) : Type = { + c match { + case c:UIntType => UIntType(f(c.width)) + case c:SIntType => SIntType(f(c.width)) + case c => c + } + } + def wMap (f: Width => Width, w:Width) : Width = { + w match { + case w:MaxWidth => MaxWidth(w.args.map(f)) + case w:MinWidth => MinWidth(w.args.map(f)) + case w:PlusWidth => PlusWidth(f(w.arg1),f(w.arg2)) + case w:MinusWidth => MinusWidth(f(w.arg1),f(w.arg2)) + case w:ExpWidth => ExpWidth(f(w.arg1)) + case w => w } - } - } - - implicit class CircuitUtils(c: Circuit) { - def serialize(implicit flags: FlagMap = FlagMap): String = { - var s = new StringBuilder(s"circuit ${c.main} : ") - withIndent { s ++= newline ++ c.modules.map(_.serialize).mkString(newline + newline) } - s ++= newline ++ newline - s ++= debug(c) - s.toString - } - } - - private var indentLevel = 0 - private def newline = "\n" + (" " * indentLevel) - private def indent(): Unit = indentLevel += 1 - private def unindent() { require(indentLevel > 0); indentLevel -= 1 } - private def withIndent(f: => Unit) { indent(); f; unindent() } + } + val ONE = IntWidth(1) + //private trait StmtMagnet { + // def map(stmt: Stmt): Stmt + //} + //private object StmtMagnet { + // implicit def forStmt(f: Stmt => Stmt) = new StmtMagnet { + // override def map(stmt: Stmt): Stmt = + // stmt match { + // case w: Conditionally => Conditionally(w.info, w.pred, f(w.conseq), f(w.alt)) + // case b: Begin => Begin(b.stmts.map(f)) + // case s: Stmt => s + // } + // } + // implicit def forExp(f: Expression => Expression) = new StmtMagnet { + // override def map(stmt: Stmt): Stmt = + // stmt match { + // case r: DefRegister => DefRegister(r.info, r.name, r.tpe, f(r.clock), f(r.reset), f(r.init)) + // case n: DefNode => DefNode(n.info, n.name, f(n.value)) + // case c: Connect => Connect(c.info, f(c.loc), f(c.exp)) + // case b: BulkConnect => BulkConnect(b.info, f(b.loc), f(b.exp)) + // case w: Conditionally => Conditionally(w.info, f(w.pred), w.conseq, w.alt) + // case i: IsInvalid => IsInvalid(i.info, f(i.exp)) + // case s: Stop => Stop(s.info, s.ret, f(s.clk), f(s.en)) + // case p: Print => Print(p.info, p.string, p.args.map(f), f(p.clk), f(p.en)) + // case s: Stmt => s + // } + // } + //} + implicit class ExpUtils(exp: Expression) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + val ret = exp match { + case v: UIntValue => s"UInt${v.width.serialize}(${v.value.serialize})" + case v: SIntValue => s"SInt${v.width.serialize}(${v.value.serialize})" + case r: Ref => r.name + case s: SubField => s"${s.exp.serialize}.${s.name}" + case s: SubIndex => s"${s.exp.serialize}[${s.value}]" + case s: SubAccess => s"${s.exp.serialize}[${s.index.serialize}]" + case m: Mux => s"mux(${m.cond.serialize}, ${m.tval.serialize}, ${m.fval.serialize})" + case v: ValidIf => s"validif(${v.cond.serialize}, ${v.value.serialize})" + case p: DoPrim => + s"${p.op.serialize}(" + (p.args.map(_.serialize) ++ p.consts.map(_.toString)).mkString(", ") + ")" + case r: WRef => r.name + case s: WSubField => s"w${s.exp.serialize}.${s.name}" + case s: WSubIndex => s"w${s.exp.serialize}[${s.value}]" + case s: WSubAccess => s"w${s.exp.serialize}[${s.index.serialize}]" + } + ret + debug(exp) + } + } + + // def map(f: Expression => Expression): Expression = + // exp match { + // case s: SubField => SubField(f(s.exp), s.name, s.tpe) + // case s: SubIndex => SubIndex(f(s.exp), s.value, s.tpe) + // case s: SubAccess => SubAccess(f(s.exp), f(s.index), s.tpe) + // case m: Mux => Mux(f(m.cond), f(m.tval), f(m.fval), m.tpe) + // case v: ValidIf => ValidIf(f(v.cond), f(v.value), v.tpe) + // case p: DoPrim => DoPrim(p.op, p.args.map(f), p.consts, p.tpe) + // case s: WSubField => SubField(f(s.exp), s.name, s.tpe, s.gender) + // case s: WSubIndex => SubIndex(f(s.exp), s.value, s.tpe, s.gender) + // case s: WSubAccess => SubAccess(f(s.exp), f(s.index), s.tpe, s.gender) + // case e: Expression => e + // } + //} + + implicit class StmtUtils(stmt: Stmt) { + def serialize(implicit flags: FlagMap = FlagMap): String = + { + var ret = stmt match { + case w: DefWire => s"wire ${w.name} : ${w.tpe.serialize}" + case r: DefRegister => + val str = new StringBuilder(s"reg ${r.name} : ${r.tpe.serialize}, ${r.clock.serialize} with : ") + withIndent { + str ++= newline + s"reset => (${r.reset.serialize}, ${r.init.serialize})" + } + str + case i: DefInstance => s"inst ${i.name} of ${i.module}" + case i: WDefInstance => s"inst ${i.name} of ${i.module}" + case m: DefMemory => { + val str = new StringBuilder(s"mem ${m.name} : " + newline) + withIndent { + str ++= s"data-type => ${m.data_type}" + newline + + s"depth => ${m.depth}" + newline + + s"read-latency => ${m.read_latency}" + newline + + s"write-latency => ${m.write_latency}" + newline + + (if (m.readers.nonEmpty) m.readers.map(r => s"reader => ${r}").mkString(newline) + newline + else "") + + (if (m.writers.nonEmpty) m.writers.map(w => s"writer => ${w}").mkString(newline) + newline + else "") + + (if (m.readwriters.nonEmpty) m.readwriters.map(rw => s"readwriter => ${rw}").mkString(newline) + newline + else "") + + s"read-under-write => undefined" + } + str.result + } + case n: DefNode => s"node ${n.name} = ${n.value.serialize}" + case c: Connect => s"${c.loc.serialize} <= ${c.exp.serialize}" + case b: BulkConnect => s"${b.loc.serialize} <- ${b.exp.serialize}" + case w: Conditionally => { + var str = new StringBuilder(s"when ${w.pred.serialize} : ") + withIndent { str ++= w.conseq.serialize } + w.alt match { + case s:Empty => str.result + case s => { + str ++= newline + "else :" + withIndent { str ++= w.alt.serialize } + str.result + } + } + } + case b: Begin => { + val s = new StringBuilder + b.stmts.foreach { s ++= newline ++ _.serialize } + s.result + debug(b) + } + case i: IsInvalid => s"${i.exp.serialize} is invalid" + case s: Stop => s"stop(${s.clk.serialize}, ${s.en.serialize}, ${s.ret})" + case p: Print => s"printf(${p.clk.serialize}, ${p.en.serialize}, ${p.string}" + + (if (p.args.nonEmpty) p.args.map(_.serialize).mkString(", ", ", ", "") else "") + ")" + case s:Empty => "skip" + } + ret + debug(stmt) + } + + // Using implicit types to allow overloading of function type to map, see StmtMagnet above + //def map[T](f: T => T)(implicit magnet: (T => T) => StmtMagnet): Stmt = magnet(f).map(stmt) + + def getType(): Type = + stmt match { + case s: DefWire => s.tpe + case s: DefRegister => s.tpe + case s: DefMemory => s.data_type + case _ => UnknownType() + } + } + + implicit class WidthUtils(w: Width) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + val s = w match { + case w:UnknownWidth => "" //"?" + case w: IntWidth => s"<${w.width.toString}>" + } + s + debug(w) + } + } + + implicit class FlipUtils(f: Flip) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + val s = f match { + case Reverse => "flip " + case Default => "" + } + s + debug(f) + } + def flip(): Flip = { + f match { + case Reverse => Default + case Default => Reverse + } + } + + def toDirection(): Direction = { + f match { + case Default => Output + case Reverse => Input + } + } + } + + implicit class FieldUtils(field: Field) { + def serialize(implicit flags: FlagMap = FlagMap): String = + s"${field.flip.serialize}${field.name} : ${field.tpe.serialize}" + debug(field) + def flip(): Field = Field(field.name, field.flip.flip, field.tpe) + + def getType(): Type = field.tpe + def toPort(): Port = Port(NoInfo, field.name, field.flip.toDirection, field.tpe) + } + + implicit class TypeUtils(t: Type) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + val commas = ", " // for mkString in BundleType + val s = t match { + case c:ClockType => "Clock" + //case UnknownType => "UnknownType" + case u:UnknownType => "?" + case t: UIntType => s"UInt${t.width.serialize}" + case t: SIntType => s"SInt${t.width.serialize}" + case t: BundleType => s"{ ${t.fields.map(_.serialize).mkString(commas)}}" + case t: VectorType => s"${t.tpe.serialize}[${t.size}]" + } + s + debug(t) + } + + def getType(): Type = + t match { + case v: VectorType => v.tpe + case tpe: Type => UnknownType() + } + + def wipeWidth(): Type = + t match { + case t: UIntType => UIntType(UnknownWidth()) + case t: SIntType => SIntType(UnknownWidth()) + case _ => t + } + } + + implicit class DirectionUtils(d: Direction) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + val s = d match { + case Input => "input" + case Output => "output" + } + s + debug(d) + } + def toFlip(): Flip = { + d match { + case Input => Reverse + case Output => Default + } + } + } + + implicit class PortUtils(p: Port) { + def serialize(implicit flags: FlagMap = FlagMap): String = + s"${p.direction.serialize} ${p.name} : ${p.tpe.serialize}" + debug(p) + def getType(): Type = p.tpe + def toField(): Field = Field(p.name, p.direction.toFlip, p.tpe) + } + + implicit class ModuleUtils(m: Module) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + m match { + case m:InModule => { + var s = new StringBuilder(s"module ${m.name} : ") + withIndent { + s ++= m.ports.map(newline ++ _.serialize).mkString + s ++= m.body.serialize + } + s ++= debug(m) + s.toString + } + } + } + } + + implicit class CircuitUtils(c: Circuit) { + def serialize(implicit flags: FlagMap = FlagMap): String = { + var s = new StringBuilder(s"circuit ${c.main} : ") + withIndent { s ++= newline ++ c.modules.map(_.serialize).mkString(newline + newline) } + s ++= newline ++ newline + s ++= debug(c) + s.toString + } + } + + private var indentLevel = 0 + private def newline = "\n" + (" " * indentLevel) + private def indent(): Unit = indentLevel += 1 + private def unindent() { require(indentLevel > 0); indentLevel -= 1 } + private def withIndent(f: => Unit) { indent(); f; unindent() } } diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index 96d77485..42de6348 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -61,9 +61,9 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] private def visitType[AST](ctx: FIRRTLParser.TypeContext): Type = { ctx.getChild(0).getText match { case "UInt" => if (ctx.getChildCount > 1) UIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) - else UIntType( UnknownWidth ) + else UIntType( UnknownWidth() ) case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) - else SIntType( UnknownWidth ) + else SIntType( UnknownWidth() ) case "Clock" => ClockType() case "{" => BundleType(ctx.field.map(visitField)) case _ => new VectorType( visitType(ctx.`type`), string2BigInt(ctx.IntLit.getText) ) @@ -164,14 +164,14 @@ class Visitor(val fullFilename: String) extends FIRRTLBaseVisitor[AST] val (width, value) = if (ctx.getChildCount > 4) (IntWidth(string2BigInt(ctx.IntLit(0).getText)), string2BigInt(ctx.IntLit(1).getText)) - else (UnknownWidth, string2BigInt(ctx.IntLit(0).getText)) + else (UnknownWidth(), string2BigInt(ctx.IntLit(0).getText)) UIntValue(value, width) } case "SInt" => { val (width, value) = if (ctx.getChildCount > 4) (IntWidth(string2BigInt(ctx.IntLit(0).getText)), string2BigInt(ctx.IntLit(1).getText)) - else (UnknownWidth, string2BigInt(ctx.IntLit(0).getText)) + else (UnknownWidth(), string2BigInt(ctx.IntLit(0).getText)) SIntValue(value, width) } case "validif(" => ValidIf(visitExp(ctx.exp(0)), visitExp(ctx.exp(1)), UnknownType()) diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 27ec5516..46c66c82 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -31,6 +31,12 @@ case class WInvalid() extends Expression case class WDefInstance(info:Info,name:String,module:String,tpe:Type) extends Stmt +case class VarWidth(name:String) extends Width +case class PlusWidth(arg1:Width,arg2:Width) extends Width +case class MinusWidth(arg1:Width,arg2:Width) extends Width +case class MaxWidth(args:Seq[Width]) extends Width +case class MinWidth(args:Seq[Width]) extends Width +case class ExpWidth(arg1:Width) extends Width //case class IntWidth(width: BigInt) extends Width //case object UnknownWidth extends Width |
