diff options
| author | Donggyu Kim | 2016-08-30 16:29:18 -0700 |
|---|---|---|
| committer | Donggyu Kim | 2016-09-07 11:57:36 -0700 |
| commit | b1b977407d12878fb5d8ea92950888002beb258b (patch) | |
| tree | 429f7acf1f95b0c1e3e9b9b1f2d528c49761356b /src | |
| parent | 8bb62b613956cff472cc89b28013b3f4af254224 (diff) | |
clean up Utils.scala
remove unnecessary functions & change spaces
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 13 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 1109 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 33 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 74 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 11 |
9 files changed, 458 insertions, 801 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index e6314b72..6c658257 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -68,9 +68,9 @@ class VerilogEmitter extends Emitter { var mname = "" def wref (n:String,t:Type) = WRef(n,t,ExpKind(),UNKNOWNGENDER) def remove_root (ex:Expression) : Expression = { - (ex.as[WSubField].get.exp) match { + (ex.asInstanceOf[WSubField].exp) match { case (e:WSubField) => remove_root(e) - case (e:WRef) => WRef(ex.as[WSubField].get.name,ex.tpe,InstanceKind(),UNKNOWNGENDER) + case (e:WRef) => WRef(ex.asInstanceOf[WSubField].name,ex.tpe,InstanceKind(),UNKNOWNGENDER) } } def not_empty (s:ArrayBuffer[_]) : Boolean = if (s.size == 0) false else true @@ -120,7 +120,7 @@ class VerilogEmitter extends Emitter { case (i:Long) => w.get.write(i.toString) case (t:VIndent) => w.get.write(" ") case (s:Seq[Any]) => { - s.foreach((x:Any) => emit2(x.as[Any].get, top + 1)) + s.foreach((x:Any) => emit2(x, top + 1)) if (top == 0) w.get.write("\n") } } @@ -142,7 +142,10 @@ class VerilogEmitter extends Emitter { } def op_stream (doprim:DoPrim) : Seq[Any] = { def cast_if (e:Expression) : Any = { - val signed = doprim.args.find(x => x.tpe.typeof[SIntType]) + val signed = doprim.args.find(x => x.tpe match { + case _: SIntType => true + case _ => false + }) if (signed == None) e else e.tpe match { case (t:SIntType) => Seq("$signed(",e,")") @@ -478,7 +481,7 @@ class VerilogEmitter extends Emitter { initialize(e) } case (s:IsInvalid) => { - val wref = netlist(s.expr).as[WRef].get + val wref = netlist(s.expr).asInstanceOf[WRef] declare("wire",wref.name,s.expr.tpe) invalidAssign(wref) } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 5854fb95..9404e5e2 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -36,16 +36,14 @@ MODIFICATIONS. package firrtl -import scala.collection.mutable.StringBuilder +import firrtl.ir._ +import firrtl.PrimOps._ +import firrtl.Mappers._ +import firrtl.WrappedExpression._ +import firrtl.WrappedType._ +import scala.collection.mutable.{StringBuilder, ArrayBuffer, LinkedHashMap, HashMap, HashSet} import java.io.PrintWriter import com.typesafe.scalalogging.LazyLogging -import WrappedExpression._ -import firrtl.WrappedType._ -import firrtl.Mappers._ -import firrtl.PrimOps._ -import firrtl.ir._ -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.LinkedHashMap //import scala.reflect.runtime.universe._ class FIRRTLException(str: String) extends Exception(str) @@ -82,91 +80,73 @@ object Utils extends LazyLogging { if (bi < BigInt(0)) "\"h" + bi.toString(16).substring(1) + "\"" else "\"h" + bi.toString(16) + "\"" - implicit class WithAs[T](x: T) { - import scala.reflect._ - def as[O: ClassTag]: Option[O] = x match { - case o: O => Some(o) - case _ => None } - def typeof[O: ClassTag]: Boolean = x match { - case o: O => true - case _ => false } - } - implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) - def ceil_log2(x: BigInt): BigInt = (x-1).bitLength - def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt - def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b - def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a - def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 - val gen_names = Map[String,Int]() - val delin = "_" - val BoolType = UIntType(IntWidth(1)) - val one = UIntLiteral(BigInt(1),IntWidth(1)) - val zero = UIntLiteral(BigInt(0),IntWidth(1)) - def uint (i:Int) : UIntLiteral = { - val num_bits = req_num_bits(i) - val w = IntWidth(scala.math.max(1,num_bits - 1)) - UIntLiteral(BigInt(i),w) - } - def req_num_bits (i: Int) : Int = { - val ix = if (i < 0) ((-1 * i) - 1) else i - ceil_log2(ix + 1) + 1 - } - def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = { - if (e1 == e2) e1.e1 - else if ((e1 == we(zero)) | (e2 == we(zero))) zero - else if (e1 == we(one)) e2.e1 - else if (e2 == we(one)) e1.e1 - else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) - } - - def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { - if (e1 == e2) e1.e1 - else if ((e1 == we(one)) | (e2 == we(one))) one - else if (e1 == we(zero)) e2.e1 - else if (e2 == we(zero)) e1.e1 - else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) - } - def EQV (e1:Expression,e2:Expression) : Expression = { DoPrim(Eq,Seq(e1,e2),Seq(),e1.tpe) } - def NOT (e1:WrappedExpression) : Expression = { - if (e1 == we(one)) zero - else if (e1 == we(zero)) one - else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) - } + implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) + def ceil_log2(x: BigInt): BigInt = (x-1).bitLength + def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt + def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b + def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a + def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 + val BoolType = UIntType(IntWidth(1)) + val one = UIntLiteral(BigInt(1),IntWidth(1)) + val zero = UIntLiteral(BigInt(0),IntWidth(1)) + def uint (i:Int) : UIntLiteral = { + val num_bits = req_num_bits(i) + val w = IntWidth(scala.math.max(1,num_bits - 1)) + UIntLiteral(BigInt(i),w) + } + def req_num_bits (i: Int) : Int = { + val ix = if (i < 0) ((-1 * i) - 1) else i + ceil_log2(ix + 1) + 1 + } + def EQV (e1:Expression,e2:Expression) : Expression = + DoPrim(Eq, Seq(e1, e2), Nil, e1.tpe) + // TODO: these should be fixed + def AND (e1:WrappedExpression,e2:WrappedExpression) : Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(zero)) | (e2 == we(zero))) zero + else if (e1 == we(one)) e2.e1 + else if (e2 == we(one)) e1.e1 + else DoPrim(And,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) + } + def OR (e1:WrappedExpression,e2:WrappedExpression) : Expression = { + if (e1 == e2) e1.e1 + else if ((e1 == we(one)) | (e2 == we(one))) one + else if (e1 == we(zero)) e2.e1 + else if (e2 == we(zero)) e1.e1 + else DoPrim(Or,Seq(e1.e1,e2.e1),Seq(),UIntType(IntWidth(1))) + } + def NOT (e1:WrappedExpression) : Expression = { + if (e1 == we(one)) zero + else if (e1 == we(zero)) one + else DoPrim(Eq,Seq(e1.e1,zero),Seq(),UIntType(IntWidth(1))) + } - - //def MUX (p:Expression,e1:Expression,e2:Expression) : Expression = { - // Mux(p,e1,e2,mux_type(e1.tpe,e2.tpe)) - //} + def create_mask(dt: Type): Type = dt match { + case t: VectorType => VectorType(create_mask(t.tpe),t.size) + case t: BundleType => BundleType(t.fields.map (f => f.copy(tpe=create_mask(f.tpe)))) + case t: UIntType => BoolType + case t: SIntType => BoolType + } - def create_mask (dt:Type) : Type = { - dt match { - case t:VectorType => VectorType(create_mask(t.tpe),t.size) - case t:BundleType => { - val fieldss = t.fields.map { f => Field(f.name,f.flip,create_mask(f.tpe)) } - BundleType(fieldss) - } - case t:UIntType => BoolType - case t:SIntType => BoolType + def create_exps(n: String, t: Type): Seq[Expression] = + create_exps(WRef(n, t, ExpKind(), UNKNOWNGENDER)) + def create_exps(e: Expression): Seq[Expression] = e match { + case (e: Mux) => + val e1s = create_exps(e.tval) + val e2s = create_exps(e.fval) + e1s zip e2s map {case (e1, e2) => + Mux(e.cond, e1, e2, mux_type_and_widths(e1,e2)) } - } - def create_exps (n:String, t:Type) : Seq[Expression] = - create_exps(WRef(n,t,ExpKind(),UNKNOWNGENDER)) - def create_exps (e:Expression) : Seq[Expression] = e match { - case (e:Mux) => - val e1s = create_exps(e.tval) - val e2s = create_exps(e.fval) - (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type_and_widths(e1,e2))) - case (e:ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe)) - case (e) => e.tpe match { - case (_:GroundType) => Seq(e) - case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => - exps ++ create_exps(WSubField(e,f.name,f.tpe,times(gender(e), f.flip)))) - case (t:VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => - exps ++ create_exps(WSubIndex(e,i,t.tpe,gender(e)))) - } - } - - def get_flip (t:Type, i:Int, f:Orientation) : Orientation = { + case (e: ValidIf) => create_exps(e.value) map (e1 => ValidIf(e.cond, e1, e1.tpe)) + case (e) => e.tpe match { + case (_: GroundType) => Seq(e) + case (t: BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => + exps ++ create_exps(WSubField(e, f.name, f.tpe,times(gender(e), f.flip)))) + case (t: VectorType) => ((0 until t.size) foldLeft Seq[Expression]())((exps, i) => + exps ++ create_exps(WSubIndex(e, i, t.tpe,gender(e)))) + } + } + def get_flip(t: Type, i: Int, f: Orientation): Orientation = { if (i >= get_size(t)) error("Shouldn't be here") t match { case (_: GroundType) => f @@ -196,364 +176,265 @@ object Utils extends LazyLogging { case (e: WSubField) => e.exp.tpe match {case b: BundleType => (b.fields takeWhile (_.name != e.name) foldLeft 0)( (point, f) => point + get_size(f.tpe)) - } - case (e: WSubIndex) => e.value * get_size(e.tpe) - case (e: WSubAccess) => get_point(e.exp) - } + } + case (e: WSubIndex) => e.value * get_size(e.tpe) + case (e: WSubAccess) => get_point(e.exp) + } - /** Returns true if t, or any subtype, contains a flipped field - * @param t [[firrtl.ir.Type]] - * @return if t contains [[firrtl.ir.Flip]] - */ - def hasFlip(t: Type): Boolean = t match { - case t: BundleType => - (t.fields exists (_.flip == Flip)) || - (t.fields exists (f => hasFlip(f.tpe))) - case t: VectorType => hasFlip(t.tpe) - case _ => false - } + /** Returns true if t, or any subtype, contains a flipped field + * @param t [[firrtl.ir.Type]] + * @return if t contains [[firrtl.ir.Flip]] + */ + def hasFlip(t: Type): Boolean = t match { + case t: BundleType => + (t.fields exists (_.flip == Flip)) || + (t.fields exists (f => hasFlip(f.tpe))) + case t: VectorType => hasFlip(t.tpe) + case _ => false + } //============== TYPES ================ - def mux_type (e1:Expression,e2:Expression) : Type = mux_type(e1.tpe,e2.tpe) - def mux_type (t1:Type,t2:Type) : Type = { - if (wt(t1) == wt(t2)) { - (t1,t2) match { - case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth) - case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth) - case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size) - case (t1:BundleType,t2:BundleType) => - BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => { - Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe)) - })) - } - } else UnknownType - } - def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(e1.tpe,e2.tpe) - def mux_type_and_widths (t1:Type,t2:Type) : Type = { - def wmax (w1:Width,w2:Width) : Width = { - (w1,w2) match { - case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width.max(w2.width)) - case (w1,w2) => MaxWidth(Seq(w1,w2)) - } - } - val wt1 = new WrappedType(t1) - val wt2 = new WrappedType(t2) - if (wt1 == wt2) { - (t1,t2) match { - case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width)) - case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width)) - case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size) - case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))}) - } - } else UnknownType - } - def module_type (m:DefModule) : Type = { - BundleType(m.ports.map(p => p.toField)) - } - def sub_type (v:Type) : Type = { - v match { - case v:VectorType => v.tpe - case v => UnknownType - } - } - def field_type (v:Type,s:String) : Type = { - v match { - case v:BundleType => { - val ft = v.fields.find(p => p.name == s) - ft match { - case ft:Some[Field] => ft.get.tpe - case ft => UnknownType - } - } - case v => UnknownType - } - } + def mux_type (e1:Expression, e2:Expression) : Type = mux_type(e1.tpe, e2.tpe) + def mux_type (t1:Type, t2:Type) : Type = (t1,t2) match { + case (t1:UIntType, t2:UIntType) => UIntType(UnknownWidth) + case (t1:SIntType, t2:SIntType) => SIntType(UnknownWidth) + case (t1:VectorType, t2:VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) + case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map { + case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) + }) + case _ => UnknownType + } + def mux_type_and_widths (e1:Expression,e2:Expression) : Type = + mux_type_and_widths(e1.tpe, e2.tpe) + def mux_type_and_widths (t1:Type, t2:Type) : Type = { + def wmax (w1:Width, w2:Width) : Width = (w1,w2) match { + case (w1:IntWidth, w2:IntWidth) => IntWidth(w1.width max w2.width) + case (w1, w2) => MaxWidth(Seq(w1, w2)) + } + (t1, t2) match { + case (t1:UIntType, t2:UIntType) => UIntType(wmax(t1.width, t2.width)) + case (t1:SIntType, t2:SIntType) => SIntType(wmax(t1.width, t2.width)) + case (t1:VectorType, t2:VectorType) => VectorType( + mux_type_and_widths(t1.tpe, t2.tpe), t1.size) + case (t1:BundleType, t2:BundleType) => BundleType((t1.fields zip t2.fields) map { + case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe, f2.tpe)) + }) + case _ => UnknownType + } + } + + def module_type(m: DefModule): Type = BundleType(m.ports map { + case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) + }) + def sub_type(v: Type): Type = v match { + case v: VectorType => v.tpe + case v => UnknownType + } + def field_type(v:Type, s: String) : Type = v match { + case v: BundleType => v.fields find (_.name == s) match { + case Some(f) => f.tpe + case None => UnknownType + } + case v => UnknownType + } ////===================================== - def widthBANG (t:Type) : Width = { - t match { - case g: GroundType => g.width - case t => error("No width!") - } - } - def long_BANG (t:Type) : Long = { - (t) match { - case g: GroundType => - g.width match { - case IntWidth(x) => x.toLong - case _ => throw new FIRRTLException(s"Expecting IntWidth, got: ${g.width}") - } - case (t:BundleType) => { - var w = 0 - for (f <- t.fields) { w = w + long_BANG(f.tpe).toInt } - w - } - case (t:VectorType) => t.size * long_BANG(t.tpe) - } - } -// ================================= - def error(str:String) = throw new FIRRTLException(str) + def widthBANG (t:Type) : Width = t match { + case g: GroundType => g.width + case t => error("No width!") + } + def long_BANG(t: Type): Long = t match { + case (g: GroundType) => g.width match { + case IntWidth(x) => x.toLong + case _ => error(s"Expecting IntWidth, got: ${g.width}") + } + case (t: BundleType) => (t.fields foldLeft 0)((w, f) => + w + long_BANG(f.tpe).toInt) + case (t: VectorType) => t.size * long_BANG(t.tpe) + } - implicit class FirrtlNodeUtils(node: FirrtlNode) { - def getType(): Type = - node match { - case e: Expression => e.getType - case s: Statement => s.getType - //case f: Field => f.getType - case t: Type => t.getType - case p: Port => p.getType - case _ => UnknownType - } - } +// ================================= + def error(str: String) = throw new FIRRTLException(str) //// =============== EXPANSION FUNCTIONS ================ - def get_size(t: Type) : Int = t match { - case (t: BundleType) => (t.fields foldLeft 0)( - (sum, f) => sum + get_size(f.tpe)) - case (t: VectorType) => t.size * get_size(t.tpe) - case (t) => 1 - } + def get_size(t: Type): Int = t match { + case (t: BundleType) => (t.fields foldLeft 0)( + (sum, f) => sum + get_size(f.tpe)) + case (t: VectorType) => t.size * get_size(t.tpe) + case (t) => 1 + } - def get_valid_points (t1: Type, t2: Type, flip1: Orientation, flip2: Orientation) : Seq[(Int,Int)] = { - //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) - (t1, t2) match { - case (t1: UIntType, t2: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil - case (t1: SIntType, t2: SIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil - case (t1: BundleType, t2: BundleType) => - def emptyMap = Map[String, (Type, Orientation, Int)]() - val t1_fields = ((t1.fields foldLeft (emptyMap, 0)){case ((map, ilen), f1) => - (map + (f1.name -> (f1.tpe, f1.flip, ilen)), ilen + get_size(f1.tpe))})._1 - ((t2.fields foldLeft (Seq[(Int, Int)](), 0)){case ((points, jlen), f2) => - t1_fields get f2.name match { - case None => (points, jlen + get_size(f2.tpe)) - case Some((f1_tpe, f1_flip, ilen))=> - val f1_times = times(flip1, f1_flip) - val f2_times = times(flip2, f2.flip) - val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times) - (points ++ (ls map {case (x, y) => (x + ilen, y + jlen)}), jlen + get_size(f2.tpe)) - } - })._1 - case (t1: VectorType, t2: VectorType) => - val size = math.min(t1.size, t2.size) - (((0 until size) foldLeft (Seq[(Int, Int)](), 0, 0)){case ((points, ilen, jlen), _) => - val ls = get_valid_points(t1.tpe, t2.tpe, flip1, flip2) - (points ++ (ls map {case (x, y) => ((x + ilen), (y + jlen))}), - ilen + get_size(t1.tpe), jlen + get_size(t2.tpe)) - })._1 - case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil - case _ => error("shouldn't be here") - } - } + def get_valid_points(t1: Type, t2: Type, flip1: Orientation, flip2: Orientation): Seq[(Int,Int)] = { + //;println_all(["Inside with t1:" t1 ",t2:" t2 ",f1:" flip1 ",f2:" flip2]) + (t1, t2) match { + case (t1: UIntType, t2: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case (t1: SIntType, t2: SIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case (t1: BundleType, t2: BundleType) => + def emptyMap = Map[String, (Type, Orientation, Int)]() + val t1_fields = ((t1.fields foldLeft (emptyMap, 0)){case ((map, ilen), f1) => + (map + (f1.name -> (f1.tpe, f1.flip, ilen)), ilen + get_size(f1.tpe))})._1 + ((t2.fields foldLeft (Seq[(Int, Int)](), 0)){case ((points, jlen), f2) => + t1_fields get f2.name match { + case None => (points, jlen + get_size(f2.tpe)) + case Some((f1_tpe, f1_flip, ilen))=> + val f1_times = times(flip1, f1_flip) + val f2_times = times(flip2, f2.flip) + val ls = get_valid_points(f1_tpe, f2.tpe, f1_times, f2_times) + (points ++ (ls map {case (x, y) => (x + ilen, y + jlen)}), jlen + get_size(f2.tpe)) + } + })._1 + case (t1: VectorType, t2: VectorType) => + val size = math.min(t1.size, t2.size) + (((0 until size) foldLeft (Seq[(Int, Int)](), 0, 0)){case ((points, ilen, jlen), _) => + val ls = get_valid_points(t1.tpe, t2.tpe, flip1, flip2) + (points ++ (ls map {case (x, y) => ((x + ilen), (y + jlen))}), + ilen + get_size(t1.tpe), jlen + get_size(t2.tpe)) + })._1 + case (ClockType, ClockType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case _ => error("shouldn't be here") + } + } // =========== GENDER/FLIP UTILS ============ - def swap (g:Gender) : Gender = { - g match { - case UNKNOWNGENDER => UNKNOWNGENDER - case MALE => FEMALE - case FEMALE => MALE - case BIGENDER => BIGENDER - } - } - def swap (d:Direction) : Direction = { - d match { - case Output => Input - case Input => Output - } - } - def swap (f:Orientation) : Orientation = { - f match { - case Default => Flip - case Flip => Default - } - } - def to_dir (g:Gender) : Direction = { - g match { - case MALE => Input - case FEMALE => Output - } - } - def to_gender (d:Direction) : Gender = { - d match { - case Input => MALE - case Output => FEMALE - } - } - def toGender(f: Orientation): Gender = f match { - case Default => FEMALE - case Flip => MALE + def swap(g: Gender) : Gender = g match { + case UNKNOWNGENDER => UNKNOWNGENDER + case MALE => FEMALE + case FEMALE => MALE + case BIGENDER => BIGENDER + } + def swap(d: Direction) : Direction = d match { + case Output => Input + case Input => Output + } + def swap(f: Orientation) : Orientation = f match { + case Default => Flip + case Flip => Default + } + def to_dir(g: Gender): Direction = g match { + case MALE => Input + case FEMALE => Output } - def toFlip(g: Gender): Orientation = g match { + def to_gender(d: Direction): Gender = d match { + case Input => MALE + case Output => FEMALE + } + def to_flip(d: Direction): Orientation = d match { + case Input => Flip + case Output => Default + } + def to_flip(g: Gender): Orientation = g match { case MALE => Flip case FEMALE => Default } - def field_flip (v:Type,s:String) : Orientation = { - v match { - case v:BundleType => { - val ft = v.fields.find {p => p.name == s} - ft match { - case ft:Some[Field] => ft.get.flip - case ft => Default - } - } - case v => Default - } - } - def get_field (v:Type,s:String) : Field = { - v match { - case v:BundleType => { - val ft = v.fields.find {p => p.name == s} - ft match { - case ft:Some[Field] => ft.get - case ft => error("Shouldn't be here"); Field("blah",Default,UnknownType) - } - } - case v => error("Shouldn't be here"); Field("blah",Default,UnknownType) - } - } - def times (flip:Orientation, d:Direction) : Direction = times(flip, d) - def times (d:Direction,flip:Orientation) : Direction = { - flip match { - case Default => d - case Flip => swap(d) - } - } - def times (g: Gender, d: Direction): Direction = times(d, g) - def times (d: Direction, g: Gender): Direction = g match { - case FEMALE => d - case MALE => swap(d) // MALE == INPUT == REVERSE - } + def field_flip(v:Type, s:String) : Orientation = v match { + case (v:BundleType) => v.fields find (_.name == s) match { + case Some(ft) => ft.flip + case None => Default + } + case v => Default + } + def get_field(v:Type, s:String) : Field = v match { + case (v:BundleType) => v.fields find (_.name == s) match { + case Some(ft) => ft + case None => error("Shouldn't be here") + } + case v => error("Shouldn't be here") + } - def times (g:Gender,flip:Orientation) : Gender = times(flip, g) - def times (flip:Orientation, g:Gender) : Gender = { - flip match { - case Default => g - case Flip => swap(g) - } - } - def times (f1:Orientation, f2:Orientation) : Orientation = { - f2 match { - case Default => f1 - case Flip => swap(f1) - } - } + def times(flip: Orientation, d: Direction): Direction = times(flip, d) + def times(d: Direction,flip: Orientation): Direction = flip match { + case Default => d + case Flip => swap(d) + } + def times(g: Gender, d: Direction): Direction = times(d, g) + def times(d: Direction, g: Gender): Direction = g match { + case FEMALE => d + case MALE => swap(d) // MALE == INPUT == REVERSE + } + def times(g: Gender,flip: Orientation): Gender = times(flip, g) + def times(flip: Orientation, g: Gender): Gender = flip match { + case Default => g + case Flip => swap(g) + } + def times(f1: Orientation, f2: Orientation): Orientation = f2 match { + case Default => f1 + case Flip => swap(f1) + } // =========== ACCESSORS ========= - def info (s:Statement) : Info = { - s match { - case s:DefWire => s.info - case s:DefRegister => s.info - case s:DefInstance => s.info - case s:WDefInstance => s.info - case s:DefMemory => s.info - case s:DefNode => s.info - case s:Conditionally => s.info - case s:PartialConnect => s.info - case s:Connect => s.info - case s:IsInvalid => s.info - case s:Stop => s.info - case s:Print => s.info - case s:Block => NoInfo - case EmptyStmt => NoInfo - } - } - def gender (e:Expression) : Gender = { - e match { - case e:WRef => e.gender - case e:WSubField => e.gender - case e:WSubIndex => e.gender - case e:WSubAccess => e.gender - case e:DoPrim => MALE - case e:UIntLiteral => MALE - case e:SIntLiteral => MALE - case e:Mux => MALE - case e:ValidIf => MALE - case e:WInvalid => MALE - case e => println(e); error("Shouldn't be here") - }} - def get_gender (s:Statement) : Gender = - s match { - case s:DefWire => BIGENDER - case s:DefRegister => BIGENDER - case s:WDefInstance => MALE - case s:DefNode => MALE - case s:DefInstance => MALE - case s:DefMemory => MALE - case s:Block => UNKNOWNGENDER - case s:Connect => UNKNOWNGENDER - case s:PartialConnect => UNKNOWNGENDER - case s:Stop => UNKNOWNGENDER - case s:Print => UNKNOWNGENDER - case EmptyStmt => UNKNOWNGENDER - case s:IsInvalid => UNKNOWNGENDER - } - def get_gender (p:Port) : Gender = - if (p.direction == Input) MALE else FEMALE - def kind (e:Expression) : Kind = - e match { - case e:WRef => e.kind - case e:WSubField => kind(e.exp) - case e:WSubIndex => kind(e.exp) - case e => ExpKind() - } - def get_type (s:Statement) : Type = { - s match { - case s:DefWire => s.tpe - case s:DefRegister => s.tpe - case s:DefNode => s.value.tpe - case s:DefMemory => { - val depth = s.depth - val addr = Field("addr",Default,UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) - val en = Field("en",Default,BoolType) - val clk = Field("clk",Default,ClockType) - val def_data = Field("data",Default,s.dataType) - val rev_data = Field("data",Flip,s.dataType) - val mask = Field("mask",Default,create_mask(s.dataType)) - val wmode = Field("wmode",Default,UIntType(IntWidth(1))) - val rdata = Field("rdata",Flip,s.dataType) - val wdata = Field("wdata",Default,s.dataType) - val wmask = Field("wmask",Default,create_mask(s.dataType)) - val read_type = BundleType(Seq(rev_data,addr,en,clk)) - val write_type = BundleType(Seq(def_data,mask,addr,en,clk)) - val readwrite_type = BundleType(Seq(wmode,rdata,wdata,wmask,addr,en,clk)) - - val mem_fields = ArrayBuffer[Field]() - s.readers.foreach {x => mem_fields += Field(x,Flip,read_type)} - s.writers.foreach {x => mem_fields += Field(x,Flip,write_type)} - s.readwriters.foreach {x => mem_fields += Field(x,Flip,readwrite_type)} - BundleType(mem_fields) - } - case s:DefInstance => UnknownType - case s:WDefInstance => s.tpe - case _ => UnknownType - }} - def get_name (s:Statement) : String = { - s match { - case s:DefWire => s.name - case s:DefRegister => s.name - case s:DefNode => s.name - case s:DefMemory => s.name - case s:DefInstance => s.name - case s:WDefInstance => s.name - case _ => error("Shouldn't be here"); "blah" - }} - def get_info (s:Statement) : Info = { - s match { - case s:DefWire => s.info - case s:DefRegister => s.info - case s:DefInstance => s.info - case s:WDefInstance => s.info - case s:DefMemory => s.info - case s:DefNode => s.info - case s:Conditionally => s.info - case s:PartialConnect => s.info - case s:Connect => s.info - case s:IsInvalid => s.info - case s:Stop => s.info - case s:Print => s.info - case _ => NoInfo - }} + def kind(e: Expression): Kind = e match { + case e: WRef => e.kind + case e: WSubField => kind(e.exp) + case e: WSubIndex => kind(e.exp) + case e: WSubAccess => kind(e.exp) + case e => ExpKind() + } + def gender (e: Expression): Gender = e match { + case e: WRef => e.gender + case e: WSubField => e.gender + case e: WSubIndex => e.gender + case e: WSubAccess => e.gender + case e: DoPrim => MALE + case e: UIntLiteral => MALE + case e: SIntLiteral => MALE + case e: Mux => MALE + case e: ValidIf => MALE + case e: WInvalid => MALE + case e => println(e); error("Shouldn't be here") + } + def get_gender(s:Statement): Gender = s match { + case s: DefWire => BIGENDER + case s: DefRegister => BIGENDER + case s: WDefInstance => MALE + case s: DefNode => MALE + case s: DefInstance => MALE + case s: DefMemory => MALE + case s: Block => UNKNOWNGENDER + case s: Connect => UNKNOWNGENDER + case s: PartialConnect => UNKNOWNGENDER + case s: Stop => UNKNOWNGENDER + case s: Print => UNKNOWNGENDER + case s: IsInvalid => UNKNOWNGENDER + case EmptyStmt => UNKNOWNGENDER + } + def get_gender(p: Port): Gender = if (p.direction == Input) MALE else FEMALE + def get_type(s: Statement): Type = s match { + case s: DefWire => s.tpe + case s: DefRegister => s.tpe + case s: DefNode => s.value.tpe + case s: DefMemory => + val depth = s.depth + val addr = Field("addr", Default, UIntType(IntWidth(scala.math.max(ceil_log2(depth), 1)))) + val en = Field("en", Default, BoolType) + val clk = Field("clk", Default, ClockType) + val def_data = Field("data", Default, s.dataType) + val rev_data = Field("data", Flip, s.dataType) + val mask = Field("mask", Default, create_mask(s.dataType)) + val wmode = Field("wmode", Default, UIntType(IntWidth(1))) + val rdata = Field("rdata", Flip, s.dataType) + val wdata = Field("wdata", Default, s.dataType) + val wmask = Field("wmask", Default, create_mask(s.dataType)) + val read_type = BundleType(Seq(rev_data, addr, en, clk)) + val write_type = BundleType(Seq(def_data, mask, addr, en, clk)) + val readwrite_type = BundleType(Seq(wmode, rdata, wdata, wmask, addr, en, clk)) + BundleType( + (s.readers map (Field(_, Flip, read_type))) ++ + (s.writers map (Field(_, Flip, write_type))) ++ + (s.readwriters map (Field(_, Flip, readwrite_type))) + ) + case s: WDefInstance => s.tpe + case _ => UnknownType + } + def get_name(s: Statement): String = s match { + case s: HasName => s.name + case _ => error("Shouldn't be here") + } + def get_info(s: Statement): Info = s match { + case s: HasInfo => s.info + case _ => NoInfo + } /** Splits an Expression into root Ref and tail * @@ -623,11 +504,12 @@ object Utils extends LazyLogging { case None => getRootDecl(root.name)(m.body) match { case Some(decl) => decl - case None => throw new DeclarationNotFoundException(s"[module ${m.name}] Reference ${expr.serialize} not declared!") + case None => throw new DeclarationNotFoundException( + s"[module ${m.name}] Reference ${expr.serialize} not declared!") } } rootDecl - case e => throw new FIRRTLException(s"getDeclaration does not support Expressions of type ${e.getClass}") + case e => error(s"getDeclaration does not support Expressions of type ${e.getClass}") } } @@ -692,322 +574,79 @@ object Utils extends LazyLogging { // to-stmt(body(m)) // map(to-port,ports(m)) // sym-hash - implicit class StmtUtils(stmt: Statement) { - def getType(): Type = - stmt match { - case s: DefWire => s.tpe - case s: DefRegister => s.tpe - case s: DefMemory => s.dataType - case _ => UnknownType - } + val v_keywords = Set( + "alias", "always", "always_comb", "always_ff", "always_latch", + "and", "assert", "assign", "assume", "attribute", "automatic", - def getInfo: Info = - stmt match { - case s: DefWire => s.info - case s: DefRegister => s.info - case s: DefInstance => s.info - case s: DefMemory => s.info - case s: DefNode => s.info - case s: Conditionally => s.info - case s: PartialConnect => s.info - case s: Connect => s.info - case s: IsInvalid => s.info - case s: Stop => s.info - case s: Print => s.info - case _ => NoInfo - } - } + "before", "begin", "bind", "bins", "binsof", "bit", "break", + "buf", "bufif0", "bufif1", "byte", - implicit class FlipUtils(f: Orientation) { - def flip(): Orientation = { - f match { - case Flip => Default - case Default => Flip - } - } - - def toDirection(): Direction = { - f match { - case Default => Output - case Flip => Input - } - } - } + "case", "casex", "casez", "cell", "chandle", "class", "clocking", + "cmos", "config", "const", "constraint", "context", "continue", + "cover", "covergroup", "coverpoint", "cross", - implicit class FieldUtils(field: Field) { - def flip(): Field = Field(field.name, field.flip.flip, field.tpe) + "deassign", "default", "defparam", "design", "disable", "dist", "do", - def getType(): Type = field.tpe - def toPort(info: Info = NoInfo): Port = - Port(info, field.name, field.flip.toDirection, field.tpe) - } + "edge", "else", "end", "endattribute", "endcase", "endclass", + "endclocking", "endconfig", "endfunction", "endgenerate", + "endgroup", "endinterface", "endmodule", "endpackage", + "endprimitive", "endprogram", "endproperty", "endspecify", + "endsequence", "endtable", "endtask", + "enum", "event", "expect", "export", "extends", "extern", - implicit class TypeUtils(t: Type) { - def isGround: Boolean = t match { - case (_: UIntType | _: SIntType | ClockType) => true - case (_: BundleType | _: VectorType) => false - } - def isAggregate: Boolean = !t.isGround + "final", "first_match", "for", "force", "foreach", "forever", + "fork", "forkjoin", "function", - def getType(): Type = - t match { - case v: VectorType => v.tpe - case tpe: Type => UnknownType - } + "generate", "genvar", - def wipeWidth(): Type = - t match { - case t: UIntType => UIntType(UnknownWidth) - case t: SIntType => SIntType(UnknownWidth) - case _ => t - } - } + "highz0", "highz1", - implicit class DirectionUtils(d: Direction) { - def toFlip(): Orientation = { - d match { - case Input => Flip - case Output => Default - } - } - } - - implicit class PortUtils(p: Port) { - def getType(): Type = p.tpe - def toField(): Field = Field(p.name, p.direction.toFlip, p.tpe) - } + "if", "iff", "ifnone", "ignore_bins", "illegal_bins", "import", + "incdir", "include", "initial", "initvar", "inout", "input", + "inside", "instance", "int", "integer", "interconnect", + "interface", "intersect", + + "join", "join_any", "join_none", "large", "liblist", "library", + "local", "localparam", "logic", "longint", + + "macromodule", "matches", "medium", "modport", "module", + + "nand", "negedge", "new", "nmos", "nor", "noshowcancelled", + "not", "notif0", "notif1", "null", + + "or", "output", + + "package", "packed", "parameter", "pmos", "posedge", + "primitive", "priority", "program", "property", "protected", + "pull0", "pull1", "pulldown", "pullup", + "pulsestyle_onevent", "pulsestyle_ondetect", "pure", + + "rand", "randc", "randcase", "randsequence", "rcmos", + "real", "realtime", "ref", "reg", "release", "repeat", + "return", "rnmos", "rpmos", "rtran", "rtranif0", "rtranif1", + + "scalared", "sequence", "shortint", "shortreal", "showcancelled", + "signed", "small", "solve", "specify", "specparam", "static", + "strength", "string", "strong0", "strong1", "struct", "super", + "supply0", "supply1", + + "table", "tagged", "task", "this", "throughout", "time", "timeprecision", + "timeunit", "tran", "tranif0", "tranif1", "tri", "tri0", "tri1", "triand", + "trior", "trireg", "type","typedef", + + "union", "unique", "unsigned", "use", + + "var", "vectored", "virtual", "void", + + "wait", "wait_order", "wand", "weak0", "weak1", "while", + "wildcard", "wire", "with", "within", "wor", + "xnor", "xor", - val v_keywords = Map[String,Boolean]() + - ("alias" -> true) + - ("always" -> true) + - ("always_comb" -> true) + - ("always_ff" -> true) + - ("always_latch" -> true) + - ("and" -> true) + - ("assert" -> true) + - ("assign" -> true) + - ("assume" -> true) + - ("attribute" -> true) + - ("automatic" -> true) + - ("before" -> true) + - ("begin" -> true) + - ("bind" -> true) + - ("bins" -> true) + - ("binsof" -> true) + - ("bit" -> true) + - ("break" -> true) + - ("buf" -> true) + - ("bufif0" -> true) + - ("bufif1" -> true) + - ("byte" -> true) + - ("case" -> true) + - ("casex" -> true) + - ("casez" -> true) + - ("cell" -> true) + - ("chandle" -> true) + - ("class" -> true) + - ("clocking" -> true) + - ("cmos" -> true) + - ("config" -> true) + - ("const" -> true) + - ("constraint" -> true) + - ("context" -> true) + - ("continue" -> true) + - ("cover" -> true) + - ("covergroup" -> true) + - ("coverpoint" -> true) + - ("cross" -> true) + - ("deassign" -> true) + - ("default" -> true) + - ("defparam" -> true) + - ("design" -> true) + - ("disable" -> true) + - ("dist" -> true) + - ("do" -> true) + - ("edge" -> true) + - ("else" -> true) + - ("end" -> true) + - ("endattribute" -> true) + - ("endcase" -> true) + - ("endclass" -> true) + - ("endclocking" -> true) + - ("endconfig" -> true) + - ("endfunction" -> true) + - ("endgenerate" -> true) + - ("endgroup" -> true) + - ("endinterface" -> true) + - ("endmodule" -> true) + - ("endpackage" -> true) + - ("endprimitive" -> true) + - ("endprogram" -> true) + - ("endproperty" -> true) + - ("endspecify" -> true) + - ("endsequence" -> true) + - ("endtable" -> true) + - ("endtask" -> true) + - ("enum" -> true) + - ("event" -> true) + - ("expect" -> true) + - ("export" -> true) + - ("extends" -> true) + - ("extern" -> true) + - ("final" -> true) + - ("first_match" -> true) + - ("for" -> true) + - ("force" -> true) + - ("foreach" -> true) + - ("forever" -> true) + - ("fork" -> true) + - ("forkjoin" -> true) + - ("function" -> true) + - ("generate" -> true) + - ("genvar" -> true) + - ("highz0" -> true) + - ("highz1" -> true) + - ("if" -> true) + - ("iff" -> true) + - ("ifnone" -> true) + - ("ignore_bins" -> true) + - ("illegal_bins" -> true) + - ("import" -> true) + - ("incdir" -> true) + - ("include" -> true) + - ("initial" -> true) + - ("initvar" -> true) + - ("inout" -> true) + - ("input" -> true) + - ("inside" -> true) + - ("instance" -> true) + - ("int" -> true) + - ("integer" -> true) + - ("interconnect" -> true) + - ("interface" -> true) + - ("intersect" -> true) + - ("join" -> true) + - ("join_any" -> true) + - ("join_none" -> true) + - ("large" -> true) + - ("liblist" -> true) + - ("library" -> true) + - ("local" -> true) + - ("localparam" -> true) + - ("logic" -> true) + - ("longint" -> true) + - ("macromodule" -> true) + - ("matches" -> true) + - ("medium" -> true) + - ("modport" -> true) + - ("module" -> true) + - ("nand" -> true) + - ("negedge" -> true) + - ("new" -> true) + - ("nmos" -> true) + - ("nor" -> true) + - ("noshowcancelled" -> true) + - ("not" -> true) + - ("notif0" -> true) + - ("notif1" -> true) + - ("null" -> true) + - ("or" -> true) + - ("output" -> true) + - ("package" -> true) + - ("packed" -> true) + - ("parameter" -> true) + - ("pmos" -> true) + - ("posedge" -> true) + - ("primitive" -> true) + - ("priority" -> true) + - ("program" -> true) + - ("property" -> true) + - ("protected" -> true) + - ("pull0" -> true) + - ("pull1" -> true) + - ("pulldown" -> true) + - ("pullup" -> true) + - ("pulsestyle_onevent" -> true) + - ("pulsestyle_ondetect" -> true) + - ("pure" -> true) + - ("rand" -> true) + - ("randc" -> true) + - ("randcase" -> true) + - ("randsequence" -> true) + - ("rcmos" -> true) + - ("real" -> true) + - ("realtime" -> true) + - ("ref" -> true) + - ("reg" -> true) + - ("release" -> true) + - ("repeat" -> true) + - ("return" -> true) + - ("rnmos" -> true) + - ("rpmos" -> true) + - ("rtran" -> true) + - ("rtranif0" -> true) + - ("rtranif1" -> true) + - ("scalared" -> true) + - ("sequence" -> true) + - ("shortint" -> true) + - ("shortreal" -> true) + - ("showcancelled" -> true) + - ("signed" -> true) + - ("small" -> true) + - ("solve" -> true) + - ("specify" -> true) + - ("specparam" -> true) + - ("static" -> true) + - ("strength" -> true) + - ("string" -> true) + - ("strong0" -> true) + - ("strong1" -> true) + - ("struct" -> true) + - ("super" -> true) + - ("supply0" -> true) + - ("supply1" -> true) + - ("table" -> true) + - ("tagged" -> true) + - ("task" -> true) + - ("this" -> true) + - ("throughout" -> true) + - ("time" -> true) + - ("timeprecision" -> true) + - ("timeunit" -> true) + - ("tran" -> true) + - ("tranif0" -> true) + - ("tranif1" -> true) + - ("tri" -> true) + - ("tri0" -> true) + - ("tri1" -> true) + - ("triand" -> true) + - ("trior" -> true) + - ("trireg" -> true) + - ("type" -> true) + - ("typedef" -> true) + - ("union" -> true) + - ("unique" -> true) + - ("unsigned" -> true) + - ("use" -> true) + - ("var" -> true) + - ("vectored" -> true) + - ("virtual" -> true) + - ("void" -> true) + - ("wait" -> true) + - ("wait_order" -> true) + - ("wand" -> true) + - ("weak0" -> true) + - ("weak1" -> true) + - ("while" -> true) + - ("wildcard" -> true) + - ("wire" -> true) + - ("with" -> true) + - ("within" -> true) + - ("wor" -> true) + - ("xnor" -> true) + - ("xor" -> true) + - ("SYNTHESIS" -> true) + - ("PRINTF_COND" -> true) + - ("VCS" -> true) + "SYNTHESIS", + "PRINTF_COND", + "VCS") } object MemoizedHash { diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 60a49bac..e0e7c57a 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -105,7 +105,7 @@ object CheckChirrtl extends Pass with LazyLogging { e } def checkChirrtlS(s: Statement): Statement = { - sinfo = s.getInfo + sinfo = get_info(s) def checkName(name: String): String = { if (names.contains(name)) errors.append(new NotUniqueException(name)) else names(name) = true @@ -138,7 +138,7 @@ object CheckChirrtl extends Pass with LazyLogging { for (p <- m.ports) { sinfo = p.info names(p.name) = true - val tpe = p.getType + val tpe = p.tpe tpe map (checkChirrtlT) tpe map (checkChirrtlW) } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index e61c55ea..6e49ce93 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -241,7 +241,7 @@ object CheckHighForm extends Pass with LazyLogging { else names(name) = true name } - sinfo = s.getInfo + sinfo = get_info(s) s map (checkName) s map (checkHighFormT) @@ -276,7 +276,7 @@ object CheckHighForm extends Pass with LazyLogging { for (p <- m.ports) { // FIXME should we set sinfo here? names(p.name) = true - val tpe = p.getType + val tpe = p.tpe tpe map (checkHighFormT) tpe map (checkHighFormW) } @@ -343,20 +343,29 @@ object CheckTypes extends Pass with LazyLogging { def all_ground (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls ) { - if (!(x.tpe.typeof[UIntType] || x.tpe.typeof[SIntType])) error = true + x.tpe match { + case _: UIntType | _: SIntType => + case _ => error = true + } } if (error) errors.append(new OpNotGround(info,e.op.serialize)) } def all_uint (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls ) { - if (!(x.tpe.typeof[UIntType])) error = true + x.tpe match { + case _: UIntType => + case _ => error = true + } } if (error) errors.append(new OpNotAllUInt(info,e.op.serialize)) } def is_uint (x:Expression) : Unit = { var error = false - if (!(x.tpe.typeof[UIntType])) error = true + x.tpe match { + case _: UIntType => + case _ => error = true + } if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize)) } @@ -447,11 +456,17 @@ object CheckTypes extends Pass with LazyLogging { case (e:Mux) => { if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info)) if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info)) - if (!(e.cond.tpe.typeof[UIntType])) errors.append(new MuxCondUInt(info)) + e.cond.tpe match { + case _: UIntType => + case _ => errors.append(new MuxCondUInt(info)) + } } case (e:ValidIf) => { if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info)) - if (!(e.cond.tpe.typeof[UIntType])) errors.append(new ValidIfCondUInt(info)) + e.cond.tpe match { + case _: UIntType => + case _ => errors.append(new ValidIfCondUInt(info)) + } } case (_:UIntLiteral | _:SIntLiteral) => false } @@ -597,7 +612,7 @@ object CheckGenders extends Pass { (e) match { case (e:WRef) => genders(e.name) case (e:WSubField) => - val f = e.exp.tpe.as[BundleType].get.fields.find(f => f.name == e.name).get + val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get times(get_gender(e.exp,genders),f.flip) case (e:WSubIndex) => get_gender(e.exp,genders) case (e:WSubAccess) => get_gender(e.exp,genders) @@ -735,7 +750,7 @@ object CheckWidths extends Pass { } def check_width_s (s:Statement) : Statement = { s map (check_width_s) map (check_width_e(get_info(s))) - def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t) + def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t) s map (tm) } diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 7793c85c..a8fda1bf 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -5,7 +5,6 @@ package passes import scala.collection.mutable import firrtl.Mappers.{ExpMap,StmtMap} -import firrtl.Utils.WithAs import firrtl.ir._ import firrtl.passes.{PassException,PassExceptions} import Annotations.{Loose, Unstable, Annotation, TransID, Named, ModuleName, ComponentName, CircuitName, AnnotationMap} diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index e5661fae..a4c584ed 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -105,15 +105,15 @@ object LowerTypes extends Pass { require(tail.isEmpty) // there can't be a tail for these val memType = memDataTypeMap(mem.name) - if (memType.isGround) { - Seq(e) - } else { - val exps = create_exps(mem.name, memType) - exps map { e => - val loMemName = loweredName(e) - val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - mergeRef(loMem, mergeRef(port, field)) - } + memType match { + case _: GroundType => Seq(e) + case _ => + val exps = create_exps(mem.name, memType) + exps map { e => + val loMemName = loweredName(e) + val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + mergeRef(loMem, mergeRef(port, field)) + } } // Fields that need not be replicated for each // eg. mem.reader.data[0].a @@ -158,26 +158,26 @@ object LowerTypes extends Pass { s map lowerTypesStmt match { case s: DefWire => sinfo = s.info - if (s.tpe.isGround) { - s - } else { - val exps = create_exps(s.name, s.tpe) - val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe)) - Block(stmts) + s.tpe match { + case _: GroundType => s + case _ => + val exps = create_exps(s.name, s.tpe) + val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe)) + Block(stmts) } case s: DefRegister => sinfo = s.info - if (s.tpe.isGround) { - s map lowerTypesExp - } else { - val es = create_exps(s.name, s.tpe) - val inits = create_exps(s.init) map (lowerTypesExp) - val clock = lowerTypesExp(s.clock) - val reset = lowerTypesExp(s.reset) - val stmts = es zip inits map { case (e, i) => - DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) - } - Block(stmts) + s.tpe match { + case _: GroundType => s map lowerTypesExp + case _ => + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map (lowerTypesExp) + val clock = lowerTypesExp(s.clock) + val reset = lowerTypesExp(s.reset) + val stmts = es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) + } + Block(stmts) } // Could instead just save the type of each Module as it gets processed case s: WDefInstance => @@ -188,7 +188,7 @@ object LowerTypes extends Pass { val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) exps map ( e => // Flip because inst genders are reversed from Module type - Field(loweredName(e), toFlip(gender(e)).flip, e.tpe) + Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) ) } WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) @@ -197,16 +197,16 @@ object LowerTypes extends Pass { case s: DefMemory => sinfo = s.info memDataTypeMap += (s.name -> s.dataType) - if (s.dataType.isGround) { - s - } else { - val exps = create_exps(s.name, s.dataType) - val stmts = exps map { e => - DefMemory(s.info, loweredName(e), e.tpe, s.depth, - s.writeLatency, s.readLatency, s.readers, s.writers, - s.readwriters) - } - Block(stmts) + s.dataType match { + case _: GroundType => s + case _ => + val exps = create_exps(s.name, s.dataType) + val stmts = exps map { e => + DefMemory(s.info, loweredName(e), e.tpe, s.depth, + s.writeLatency, s.readLatency, s.readers, s.writers, + s.readwriters) + } + Block(stmts) } // wire foo : { a , b } // node x = foo diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index a3ce49f7..880d6b1c 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -76,7 +76,7 @@ object RemoveAccesses extends Pass { def onStmt(s: Statement): Statement = { def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) + (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } /** Replaces a subaccess in a given male expression @@ -94,9 +94,9 @@ object RemoveAccesses extends Pass { stmts += wire rs.zipWithIndex foreach { case (x, i) if i < temps.size => - stmts += Connect(info(s),getTemp(i),x.base) + stmts += Connect(get_info(s),getTemp(i),x.base) case (x, i) => - stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) + stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) } temp } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 12d9982b..3b6021ed 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -2,7 +2,7 @@ package firrtl package passes import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{kind, gender, info} +import firrtl.Utils.{kind, gender, get_info} import firrtl.ir._ import scala.collection.mutable @@ -20,17 +20,17 @@ object SplitExpressions extends Pass { def split(e: Expression): Expression = e match { case e: DoPrim => { val name = namespace.newTemp - v += DefNode(info(s), name, e) + v += DefNode(get_info(s), name, e) WRef(name, e.tpe, kind(e), gender(e)) } case e: Mux => { val name = namespace.newTemp - v += DefNode(info(s), name, e) + v += DefNode(get_info(s), name, e) WRef(name, e.tpe, kind(e), gender(e)) } case e: ValidIf => { val name = namespace.newTemp - v += DefNode(info(s), name, e) + v += DefNode(get_info(s), name, e) WRef(name, e.tpe, kind(e), gender(e)) } case e => e diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index b1a20fdd..d034719a 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -109,8 +109,9 @@ object Uniquify extends Pass { val newName = findValidPrefix(f.name, Seq(""), namespace) namespace += newName Field(newName, f.flip, f.tpe) - } map { f => - if (f.tpe.isAggregate) { + } map { f => f.tpe match { + case _: GroundType => f + case _ => val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet()) val elts = enumerateNames(tpe) // Need leading _ for findValidPrefix, it doesn't add _ for checks @@ -123,8 +124,6 @@ object Uniquify extends Pass { } namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e))) Field(prefix, f.flip, tpe) - } else { - f } } BundleType(newFields) @@ -349,7 +348,9 @@ object Uniquify extends Pass { def uniquifyPorts(m: DefModule): DefModule = { def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { - val portsType = BundleType(ports map (_.toField)) + val portsType = BundleType(ports map { + case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) + }) val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet()) val localMap = createNameMapping(portsType, uniquePortsType) portNameMap += (m.name -> localMap) |
