diff options
| author | Adam Izraelevitz | 2016-09-07 13:14:22 -0700 |
|---|---|---|
| committer | GitHub | 2016-09-07 13:14:22 -0700 |
| commit | 8647a25fec8c5e18d766ff3e3602d3345cd8549c (patch) | |
| tree | 429f7acf1f95b0c1e3e9b9b1f2d528c49761356b /src/main/scala/firrtl/passes | |
| parent | 0c6db9ef0669e3fb92fcc0bda2085f934d065f0b (diff) | |
| parent | b1b977407d12878fb5d8ea92950888002beb258b (diff) | |
Merge pull request #271 from ucb-bar/cleanup_utils
Clean up Utils
Diffstat (limited to 'src/main/scala/firrtl/passes')
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 71 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ConstProp.scala | 10 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ExpandWhens.scala | 4 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Inline.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/LowerTypes.scala | 82 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/PadWidths.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 72 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/SplitExpressions.scala | 14 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Uniquify.scala | 11 |
11 files changed, 148 insertions, 133 deletions
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 60a49bac..e0e7c57a 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -105,7 +105,7 @@ object CheckChirrtl extends Pass with LazyLogging { e } def checkChirrtlS(s: Statement): Statement = { - sinfo = s.getInfo + sinfo = get_info(s) def checkName(name: String): String = { if (names.contains(name)) errors.append(new NotUniqueException(name)) else names(name) = true @@ -138,7 +138,7 @@ object CheckChirrtl extends Pass with LazyLogging { for (p <- m.ports) { sinfo = p.info names(p.name) = true - val tpe = p.getType + val tpe = p.tpe tpe map (checkChirrtlT) tpe map (checkChirrtlW) } diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 9ee20c0a..6e49ce93 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -241,7 +241,7 @@ object CheckHighForm extends Pass with LazyLogging { else names(name) = true name } - sinfo = s.getInfo + sinfo = get_info(s) s map (checkName) s map (checkHighFormT) @@ -276,7 +276,7 @@ object CheckHighForm extends Pass with LazyLogging { for (p <- m.ports) { // FIXME should we set sinfo here? names(p.name) = true - val tpe = p.getType + val tpe = p.tpe tpe map (checkHighFormT) tpe map (checkHighFormW) } @@ -336,27 +336,36 @@ object CheckTypes extends Pass with LazyLogging { def all_same_type (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls) { - if (wt(tpe(ls.head)) != wt(tpe(x))) error = true + if (wt(ls.head.tpe) != wt(x.tpe)) error = true } if (error) errors.append(new OpNotAllSameType(info,e.op.serialize)) } def all_ground (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls ) { - if (!(tpe(x).typeof[UIntType] || tpe(x).typeof[SIntType])) error = true + x.tpe match { + case _: UIntType | _: SIntType => + case _ => error = true + } } if (error) errors.append(new OpNotGround(info,e.op.serialize)) } def all_uint (ls:Seq[Expression]) : Unit = { var error = false for (x <- ls ) { - if (!(tpe(x).typeof[UIntType])) error = true + x.tpe match { + case _: UIntType => + case _ => error = true + } } if (error) errors.append(new OpNotAllUInt(info,e.op.serialize)) } def is_uint (x:Expression) : Unit = { var error = false - if (!(tpe(x).typeof[UIntType])) error = true + x.tpe match { + case _: UIntType => + case _ => error = true + } if (error) errors.append(new OpNotUInt(info,e.op.serialize,x.serialize)) } @@ -417,7 +426,7 @@ object CheckTypes extends Pass with LazyLogging { (e map (check_types_e(info))) match { case (e:WRef) => e case (e:WSubField) => { - (tpe(e.exp)) match { + (e.exp.tpe) match { case (t:BundleType) => { val ft = t.fields.find(p => p.name == e.name) if (ft == None) errors.append(new SubfieldNotInBundle(info,e.name)) @@ -426,7 +435,7 @@ object CheckTypes extends Pass with LazyLogging { } } case (e:WSubIndex) => { - (tpe(e.exp)) match { + (e.exp.tpe) match { case (t:VectorType) => { if (e.value >= t.size) errors.append(new IndexTooLarge(info,e.value)) } @@ -434,24 +443,30 @@ object CheckTypes extends Pass with LazyLogging { } } case (e:WSubAccess) => { - (tpe(e.exp)) match { + (e.exp.tpe) match { case (t:VectorType) => false case (t) => errors.append(new IndexOnNonVector(info)) } - (tpe(e.index)) match { + (e.index.tpe) match { case (t:UIntType) => false case (t) => errors.append(new AccessIndexNotUInt(info)) } } case (e:DoPrim) => check_types_primop(e,errors,info) case (e:Mux) => { - if (wt(tpe(e.tval)) != wt(tpe(e.fval))) errors.append(new MuxSameType(info)) - if (!passive(tpe(e))) errors.append(new MuxPassiveTypes(info)) - if (!(tpe(e.cond).typeof[UIntType])) errors.append(new MuxCondUInt(info)) + if (wt(e.tval.tpe) != wt(e.fval.tpe)) errors.append(new MuxSameType(info)) + if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info)) + e.cond.tpe match { + case _: UIntType => + case _ => errors.append(new MuxCondUInt(info)) + } } case (e:ValidIf) => { - if (!passive(tpe(e))) errors.append(new ValidIfPassiveTypes(info)) - if (!(tpe(e.cond).typeof[UIntType])) errors.append(new ValidIfCondUInt(info)) + if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info)) + e.cond.tpe match { + case _: UIntType => + case _ => errors.append(new ValidIfCondUInt(info)) + } } case (_:UIntLiteral | _:SIntLiteral) => false } @@ -484,22 +499,22 @@ object CheckTypes extends Pass with LazyLogging { def check_types_s (s:Statement) : Statement = { s map (check_types_e(get_info(s))) match { - case (s:Connect) => if (wt(tpe(s.loc)) != wt(tpe(s.expr))) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) - case (s:DefRegister) => if (wt(s.tpe) != wt(tpe(s.init))) errors.append(new InvalidRegInit(s.info)) - case (s:PartialConnect) => if (!bulk_equals(tpe(s.loc),tpe(s.expr),Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) + case (s:Connect) => if (wt(s.loc.tpe) != wt(s.expr.tpe)) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) + case (s:DefRegister) => if (wt(s.tpe) != wt(s.init.tpe)) errors.append(new InvalidRegInit(s.info)) + case (s:PartialConnect) => if (!bulk_equals(s.loc.tpe,s.expr.tpe,Default,Default) ) errors.append(new InvalidConnect(s.info, s.loc.serialize, s.expr.serialize)) case (s:Stop) => { - if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info)) - if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) + if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info)) + if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) } case (s:Print)=> { for (x <- s.args ) { - if (wt(tpe(x)) != wt(ut()) && wt(tpe(x)) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info)) + if (wt(x.tpe) != wt(ut()) && wt(x.tpe) != wt(st()) ) errors.append(new PrintfArgNotGround(s.info)) } - if (wt(tpe(s.clk)) != wt(ClockType) ) errors.append(new ReqClk(s.info)) - if (wt(tpe(s.en)) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) + if (wt(s.clk.tpe) != wt(ClockType) ) errors.append(new ReqClk(s.info)) + if (wt(s.en.tpe) != wt(ut()) ) errors.append(new EnNotUInt(s.info)) } - case (s:Conditionally) => if (wt(tpe(s.pred)) != wt(ut()) ) errors.append(new PredNotUInt(s.info)) - case (s:DefNode) => if (!passive(tpe(s.value)) ) errors.append(new NodePassiveType(s.info)) + case (s:Conditionally) => if (wt(s.pred.tpe) != wt(ut()) ) errors.append(new PredNotUInt(s.info)) + case (s:DefNode) => if (!passive(s.value.tpe) ) errors.append(new NodePassiveType(s.info)) case (s) => false } s map (check_types_s) @@ -571,7 +586,7 @@ object CheckGenders extends Pass { fQ } - val has_flipQ = flipQ(tpe(e)) + val has_flipQ = flipQ(e.tpe) //println(e) //println(gender) //println(desired) @@ -597,7 +612,7 @@ object CheckGenders extends Pass { (e) match { case (e:WRef) => genders(e.name) case (e:WSubField) => - val f = tpe(e.exp).as[BundleType].get.fields.find(f => f.name == e.name).get + val f = e.exp.tpe.asInstanceOf[BundleType].fields.find(f => f.name == e.name).get times(get_gender(e.exp,genders),f.flip) case (e:WSubIndex) => get_gender(e.exp,genders) case (e:WSubAccess) => get_gender(e.exp,genders) @@ -735,7 +750,7 @@ object CheckWidths extends Pass { } def check_width_s (s:Statement) : Statement = { s map (check_width_s) map (check_width_e(get_info(s))) - def tm (t:Type) : Type = mapr(check_width_w(info(s)) _,t) + def tm (t:Type) : Type = mapr(check_width_w(get_info(s)) _,t) s map (tm) } diff --git a/src/main/scala/firrtl/passes/ConstProp.scala b/src/main/scala/firrtl/passes/ConstProp.scala index 57782a3c..2e8b53f3 100644 --- a/src/main/scala/firrtl/passes/ConstProp.scala +++ b/src/main/scala/firrtl/passes/ConstProp.scala @@ -129,7 +129,7 @@ object ConstProp extends Pass { private def foldComparison(e: DoPrim) = { def foldIfZeroedArg(x: Expression): Expression = { - def isUInt(e: Expression): Boolean = tpe(e) match { + def isUInt(e: Expression): Boolean = e.tpe match { case UIntType(_) => true case _ => false } @@ -163,7 +163,7 @@ object ConstProp extends Pass { def range(e: Expression): Range = e match { case UIntLiteral(value, _) => Range(value, value) case SIntLiteral(value, _) => Range(value, value) - case _ => tpe(e) match { + case _ => e.tpe match { case SIntType(IntWidth(width)) => Range( min = BigInt(0) - BigInt(2).pow(width.toInt - 1), max = BigInt(2).pow(width.toInt - 1) - BigInt(1) @@ -226,7 +226,7 @@ object ConstProp extends Pass { case Pad => e.args(0) match { case UIntLiteral(v, _) => UIntLiteral(v, IntWidth(e.consts(0))) case SIntLiteral(v, _) => SIntLiteral(v, IntWidth(e.consts(0))) - case _ if long_BANG(tpe(e.args(0))) == e.consts(0) => e.args(0) + case _ if long_BANG(e.args(0).tpe) == e.consts(0) => e.args(0) case _ => e } case Bits => e.args(0) match { @@ -234,9 +234,9 @@ object ConstProp extends Pass { val hi = e.consts(0).toInt val lo = e.consts(1).toInt require(hi >= lo) - UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(tpe(e))) + UIntLiteral((lit.value >> lo) & ((BigInt(1) << (hi - lo + 1)) - 1), widthBANG(e.tpe)) } - case x if long_BANG(tpe(e)) == long_BANG(tpe(x)) => tpe(x) match { + case x if long_BANG(e.tpe) == long_BANG(x.tpe) => x.tpe match { case t: UIntType => x case _ => asUInt(x, e.tpe) } diff --git a/src/main/scala/firrtl/passes/ExpandWhens.scala b/src/main/scala/firrtl/passes/ExpandWhens.scala index 921693c7..3d26298a 100644 --- a/src/main/scala/firrtl/passes/ExpandWhens.scala +++ b/src/main/scala/firrtl/passes/ExpandWhens.scala @@ -131,8 +131,8 @@ object ExpandWhens extends Pass { val falseValue = altNetlist.getOrElse(lvalue, defaultValue) (trueValue, falseValue) match { case (WInvalid(), WInvalid()) => WInvalid() - case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, tpe(fv)) - case (tv, WInvalid()) => ValidIf(s.pred, tv, tpe(tv)) + case (WInvalid(), fv) => ValidIf(NOT(s.pred), fv, fv.tpe) + case (tv, WInvalid()) => ValidIf(s.pred, tv, tv.tpe) case (tv, fv) => Mux(s.pred, tv, fv, mux_type_and_widths(tv, fv)) } case None => diff --git a/src/main/scala/firrtl/passes/Inline.scala b/src/main/scala/firrtl/passes/Inline.scala index 7793c85c..a8fda1bf 100644 --- a/src/main/scala/firrtl/passes/Inline.scala +++ b/src/main/scala/firrtl/passes/Inline.scala @@ -5,7 +5,6 @@ package passes import scala.collection.mutable import firrtl.Mappers.{ExpMap,StmtMap} -import firrtl.Utils.WithAs import firrtl.ir._ import firrtl.passes.{PassException,PassExceptions} import Annotations.{Loose, Unstable, Annotation, TransID, Named, ModuleName, ComponentName, CircuitName, AnnotationMap} diff --git a/src/main/scala/firrtl/passes/LowerTypes.scala b/src/main/scala/firrtl/passes/LowerTypes.scala index 585598a8..a4c584ed 100644 --- a/src/main/scala/firrtl/passes/LowerTypes.scala +++ b/src/main/scala/firrtl/passes/LowerTypes.scala @@ -105,15 +105,15 @@ object LowerTypes extends Pass { require(tail.isEmpty) // there can't be a tail for these val memType = memDataTypeMap(mem.name) - if (memType.isGround) { - Seq(e) - } else { - val exps = create_exps(mem.name, memType) - exps map { e => - val loMemName = loweredName(e) - val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) - mergeRef(loMem, mergeRef(port, field)) - } + memType match { + case _: GroundType => Seq(e) + case _ => + val exps = create_exps(mem.name, memType) + exps map { e => + val loMemName = loweredName(e) + val loMem = WRef(loMemName, UnknownType, kind(mem), UNKNOWNGENDER) + mergeRef(loMem, mergeRef(port, field)) + } } // Fields that need not be replicated for each // eg. mem.reader.data[0].a @@ -138,7 +138,7 @@ object LowerTypes extends Pass { case k: InstanceKind => val (root, tail) = splitRef(e) val name = loweredName(tail) - WSubField(root, name, tpe(e), gender(e)) + WSubField(root, name, e.tpe, gender(e)) case k: MemKind => val exps = lowerTypesMemExp(e) if (exps.length > 1) @@ -146,7 +146,7 @@ object LowerTypes extends Pass { " to be expanded!") exps(0) case k => - WRef(loweredName(e), tpe(e), kind(e), gender(e)) + WRef(loweredName(e), e.tpe, kind(e), gender(e)) } case e: Mux => e map (lowerTypesExp) case e: ValidIf => e map (lowerTypesExp) @@ -158,26 +158,26 @@ object LowerTypes extends Pass { s map lowerTypesStmt match { case s: DefWire => sinfo = s.info - if (s.tpe.isGround) { - s - } else { - val exps = create_exps(s.name, s.tpe) - val stmts = exps map (e => DefWire(s.info, loweredName(e), tpe(e))) - Block(stmts) + s.tpe match { + case _: GroundType => s + case _ => + val exps = create_exps(s.name, s.tpe) + val stmts = exps map (e => DefWire(s.info, loweredName(e), e.tpe)) + Block(stmts) } case s: DefRegister => sinfo = s.info - if (s.tpe.isGround) { - s map lowerTypesExp - } else { - val es = create_exps(s.name, s.tpe) - val inits = create_exps(s.init) map (lowerTypesExp) - val clock = lowerTypesExp(s.clock) - val reset = lowerTypesExp(s.reset) - val stmts = es zip inits map { case (e, i) => - DefRegister(s.info, loweredName(e), tpe(e), clock, reset, i) - } - Block(stmts) + s.tpe match { + case _: GroundType => s map lowerTypesExp + case _ => + val es = create_exps(s.name, s.tpe) + val inits = create_exps(s.init) map (lowerTypesExp) + val clock = lowerTypesExp(s.clock) + val reset = lowerTypesExp(s.reset) + val stmts = es zip inits map { case (e, i) => + DefRegister(s.info, loweredName(e), e.tpe, clock, reset, i) + } + Block(stmts) } // Could instead just save the type of each Module as it gets processed case s: WDefInstance => @@ -188,7 +188,7 @@ object LowerTypes extends Pass { val exps = create_exps(WRef(f.name, f.tpe, ExpKind(), times(f.flip, MALE))) exps map ( e => // Flip because inst genders are reversed from Module type - Field(loweredName(e), toFlip(gender(e)).flip, tpe(e)) + Field(loweredName(e), swap(to_flip(gender(e))), e.tpe) ) } WDefInstance(s.info, s.name, s.module, BundleType(fieldsx)) @@ -197,16 +197,16 @@ object LowerTypes extends Pass { case s: DefMemory => sinfo = s.info memDataTypeMap += (s.name -> s.dataType) - if (s.dataType.isGround) { - s - } else { - val exps = create_exps(s.name, s.dataType) - val stmts = exps map { e => - DefMemory(s.info, loweredName(e), tpe(e), s.depth, - s.writeLatency, s.readLatency, s.readers, s.writers, - s.readwriters) - } - Block(stmts) + s.dataType match { + case _: GroundType => s + case _ => + val exps = create_exps(s.name, s.dataType) + val stmts = exps map { e => + DefMemory(s.info, loweredName(e), e.tpe, s.depth, + s.writeLatency, s.readLatency, s.readers, s.writers, + s.readwriters) + } + Block(stmts) } // wire foo : { a , b } // node x = foo @@ -217,7 +217,7 @@ object LowerTypes extends Pass { // node y = x_a case s: DefNode => sinfo = s.info - val names = create_exps(s.name, tpe(s.value)) map (lowerTypesExp) + val names = create_exps(s.name, s.value.tpe) map (lowerTypesExp) val exps = create_exps(s.value) map (lowerTypesExp) val stmts = names zip exps map { case (n, e) => DefNode(s.info, loweredName(n), e) @@ -249,7 +249,7 @@ object LowerTypes extends Pass { // Lower Ports val portsx = m.ports flatMap { p => val exps = create_exps(WRef(p.name, p.tpe, PortKind(), to_gender(p.direction))) - exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), tpe(e)) ) + exps map ( e => Port(p.info, loweredName(e), to_dir(gender(e)), e.tpe) ) } m match { case m: ExtModule => m.copy(ports = portsx) diff --git a/src/main/scala/firrtl/passes/PadWidths.scala b/src/main/scala/firrtl/passes/PadWidths.scala index 0cabc293..f2117761 100644 --- a/src/main/scala/firrtl/passes/PadWidths.scala +++ b/src/main/scala/firrtl/passes/PadWidths.scala @@ -2,7 +2,7 @@ package firrtl package passes import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{tpe, long_BANG} +import firrtl.Utils.long_BANG import firrtl.PrimOps._ import firrtl.ir._ @@ -10,10 +10,10 @@ import firrtl.ir._ object PadWidths extends Pass { def name = "Pad Widths" private def width(t: Type): Int = long_BANG(t).toInt - private def width(e: Expression): Int = width(tpe(e)) + private def width(e: Expression): Int = width(e.tpe) // Returns an expression with the correct integer width private def fixup(i: Int)(e: Expression) = { - def tx = tpe(e) match { + def tx = e.tpe match { case t: UIntType => UIntType(IntWidth(i)) case t: SIntType => SIntType(IntWidth(i)) // default case should never be reached diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 7b4f9aa2..6b6dc811 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -103,7 +103,7 @@ object ResolveKinds extends Pass { def resolve (body:Statement) = { def resolve_expr (e:Expression):Expression = { e match { - case e:WRef => WRef(e.name,tpe(e),kinds(e.name),e.gender) + case e:WRef => WRef(e.name,e.tpe,kinds(e.name),e.gender) case e => e map (resolve_expr) } } @@ -170,11 +170,11 @@ object InferTypes extends Pass { val types = LinkedHashMap[String,Type]() def infer_types_e (e:Expression) : Expression = { e map (infer_types_e) match { - case e:ValidIf => ValidIf(e.cond,e.value,tpe(e.value)) + case e:ValidIf => ValidIf(e.cond,e.value,e.value.tpe) case e:WRef => WRef(e.name, types(e.name),e.kind,e.gender) - case e:WSubField => WSubField(e.exp,e.name,field_type(tpe(e.exp),e.name),e.gender) - case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(tpe(e.exp)),e.gender) - case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(tpe(e.exp)),e.gender) + case e:WSubField => WSubField(e.exp,e.name,field_type(e.exp.tpe,e.name),e.gender) + case e:WSubIndex => WSubIndex(e.exp,e.value,sub_type(e.exp.tpe),e.gender) + case e:WSubAccess => WSubAccess(e.exp,e.index,sub_type(e.exp.tpe),e.gender) case e:DoPrim => set_primop_type(e) case e:Mux => Mux(e.cond,e.tval,e.fval,mux_type_and_widths(e.tval,e.fval)) case e:UIntLiteral => e @@ -246,7 +246,7 @@ object ResolveGenders extends Pass { case e:WRef => WRef(e.name,e.tpe,e.kind,g) case e:WSubField => { val expx = - field_flip(tpe(e.exp),e.name) match { + field_flip(e.exp.tpe,e.name) match { case Default => resolve_e(g)(e.exp) case Flip => resolve_e(swap(g))(e.exp) } @@ -474,7 +474,7 @@ object InferWidths extends Pass { case (t:SIntType) => t.width case ClockType => IntWidth(1) case (t) => error("No width!"); IntWidth(-1) } } - def width_BANG (e:Expression) : Width = width_BANG(tpe(e)) + def width_BANG (e:Expression) : Width = width_BANG(e.tpe) def reduce_var_widths(c: Circuit, h: LinkedHashMap[String,Width]): Circuit = { def evaluate(w: Width): Width = { @@ -549,40 +549,40 @@ object InferWidths extends Pass { def get_constraints_e (e:Expression) : Expression = { (e map (get_constraints_e)) match { case (e:Mux) => { - constrain(width_BANG(e.cond),ONE) - constrain(ONE,width_BANG(e.cond)) + constrain(width_BANG(e.cond),IntWidth(1)) + constrain(IntWidth(1),width_BANG(e.cond)) e } case (e) => e }} def get_constraints (s:Statement) : Statement = { (s map (get_constraints_e)) match { case (s:Connect) => { - val n = get_size(tpe(s.loc)) + val n = get_size(s.loc.tpe) val ce_loc = create_exps(s.loc) val ce_exp = create_exps(s.expr) for (i <- 0 until n) { val locx = ce_loc(i) val expx = ce_exp(i) - get_flip(tpe(s.loc),i,Default) match { + get_flip(s.loc.tpe,i,Default) match { case Default => constrain(width_BANG(locx),width_BANG(expx)) case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} s } case (s:PartialConnect) => { - val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) for (x <- ls) { val locx = create_exps(s.loc)(x._1) val expx = create_exps(s.expr)(x._2) - get_flip(tpe(s.loc),x._1,Default) match { + get_flip(s.loc.tpe,x._1,Default) match { case Default => constrain(width_BANG(locx),width_BANG(expx)) case Flip => constrain(width_BANG(expx),width_BANG(locx)) }} s } case (s:DefRegister) => { - constrain(width_BANG(s.reset),ONE) - constrain(ONE,width_BANG(s.reset)) - get_constraints_t(s.tpe,tpe(s.init),Default) + constrain(width_BANG(s.reset),IntWidth(1)) + constrain(IntWidth(1),width_BANG(s.reset)) + get_constraints_t(s.tpe,s.init.tpe,Default) s } case (s:Conditionally) => { - v += WGeq(width_BANG(s.pred),ONE) - v += WGeq(ONE,width_BANG(s.pred)) + v += WGeq(width_BANG(s.pred),IntWidth(1)) + v += WGeq(IntWidth(1),width_BANG(s.pred)) s map (get_constraints) } case (s) => s map (get_constraints) }} @@ -661,7 +661,7 @@ object ExpandConnects extends Pass { e map (set_gender) match { case (e:WRef) => WRef(e.name,e.tpe,e.kind,genders(e.name)) case (e:WSubField) => { - val f = get_field(tpe(e.exp),e.name) + val f = get_field(e.exp.tpe,e.name) val genderx = times(gender(e.exp),f.flip) WSubField(e.exp,e.name,e.tpe,genderx) } @@ -677,7 +677,7 @@ object ExpandConnects extends Pass { case (s:DefMemory) => { genders(s.name) = MALE; s } case (s:DefNode) => { genders(s.name) = MALE; s } case (s:IsInvalid) => { - val n = get_size(tpe(s.expr)) + val n = get_size(s.expr.tpe) val invalids = ArrayBuffer[Statement]() val exps = create_exps(s.expr) for (i <- 0 until n) { @@ -696,14 +696,14 @@ object ExpandConnects extends Pass { } else Block(invalids) } case (s:Connect) => { - val n = get_size(tpe(s.loc)) + val n = get_size(s.loc.tpe) val connects = ArrayBuffer[Statement]() val locs = create_exps(s.loc) val exps = create_exps(s.expr) for (i <- 0 until n) { val locx = locs(i) val expx = exps(i) - val sx = get_flip(tpe(s.loc),i,Default) match { + val sx = get_flip(s.loc.tpe,i,Default) match { case Default => Connect(s.info,locx,expx) case Flip => Connect(s.info,expx,locx) } @@ -712,14 +712,14 @@ object ExpandConnects extends Pass { Block(connects) } case (s:PartialConnect) => { - val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) val connects = ArrayBuffer[Statement]() val locs = create_exps(s.loc) val exps = create_exps(s.expr) ls.foreach { x => { val locx = locs(x._1) val expx = exps(x._2) - val sx = get_flip(tpe(s.loc),x._1,Default) match { + val sx = get_flip(s.loc.tpe,x._1,Default) match { case Default => Connect(s.info,locx,expx) case Flip => Connect(s.info,expx,locx) } @@ -755,7 +755,7 @@ object Legalize extends Pass { def legalizeShiftRight (e: DoPrim): Expression = e.op match { case Shr => { val amount = e.consts(0).toInt - val width = long_BANG(tpe(e.args(0))) + val width = long_BANG(e.args(0).tpe) lazy val msb = width - 1 if (amount >= width) { e.tpe match { @@ -771,9 +771,9 @@ object Legalize extends Pass { case _ => e } def legalizeConnect(c: Connect): Statement = { - val t = tpe(c.loc) + val t = c.loc.tpe val w = long_BANG(t) - if (w >= long_BANG(tpe(c.expr))) c + if (w >= long_BANG(c.expr.tpe)) c else { val newType = t match { case _: UIntType => UIntType(IntWidth(w)) @@ -811,8 +811,8 @@ object VerilogWrap extends Pass { if (e.op == Tail) { (a0()) match { case (e0:DoPrim) => { - if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),tpe(e)) - else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),tpe(e)) + if (e0.op == Add) DoPrim(Addw,e0.args,Seq(),e.tpe) + else if (e0.op == Sub) DoPrim(Subw,e0.args,Seq(),e.tpe) else e } case (e0) => e @@ -913,12 +913,12 @@ object CInferTypes extends Pass { def infer_types_e (e:Expression) : Expression = { e map infer_types_e match { case (e:Reference) => Reference(e.name, types.getOrElse(e.name,UnknownType)) - case (e:SubField) => SubField(e.expr,e.name,field_type(tpe(e.expr),e.name)) - case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(tpe(e.expr))) - case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(tpe(e.expr))) + case (e:SubField) => SubField(e.expr,e.name,field_type(e.expr.tpe,e.name)) + case (e:SubIndex) => SubIndex(e.expr,e.value,sub_type(e.expr.tpe)) + case (e:SubAccess) => SubAccess(e.expr,e.index,sub_type(e.expr.tpe)) case (e:DoPrim) => set_primop_type(e) case (e:Mux) => Mux(e.cond,e.tval,e.fval,mux_type(e.tval,e.tval)) - case (e:ValidIf) => ValidIf(e.cond,e.value,tpe(e.value)) + case (e:ValidIf) => ValidIf(e.cond,e.value,e.value.tpe) case (_:UIntLiteral | _:SIntLiteral) => e } } @@ -1067,8 +1067,8 @@ object RemoveCHIRRTL extends Pass { val e2s = create_exps(e.fval) (e1s,e2s).zipped map ((e1,e2) => Mux(e.cond,e1,e2,mux_type(e1,e2))) case (e:ValidIf) => - create_exps(e.value) map (e1 => ValidIf(e.cond,e1,tpe(e1))) - case (e) => (tpe(e)) match { + create_exps(e.value) map (e1 => ValidIf(e.cond,e1,e1.tpe)) + case (e) => (e.tpe) match { case (_:GroundType) => Seq(e) case (t:BundleType) => (t.fields foldLeft Seq[Expression]())((exps, f) => exps ++ create_exps(SubField(e,f.name,f.tpe))) @@ -1276,7 +1276,7 @@ object RemoveCHIRRTL extends Pass { case Some(en) => stmts += Connect(s.info,en,one) } if (has_write_mport) { - val ls = get_valid_points(tpe(s.loc),tpe(s.expr),Default,Default) + val ls = get_valid_points(s.loc.tpe,s.expr.tpe,Default,Default) val locs = create_exps(get_mask(s.loc)) for (x <- ls ) { val locx = locs(x._1) diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala index a3ce49f7..880d6b1c 100644 --- a/src/main/scala/firrtl/passes/RemoveAccesses.scala +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -76,7 +76,7 @@ object RemoveAccesses extends Pass { def onStmt(s: Statement): Statement = { def create_temp(e: Expression): (Statement, Expression) = { val n = namespace.newTemp - (DefWire(info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) + (DefWire(get_info(s), n, e.tpe), WRef(n, e.tpe, kind(e), gender(e))) } /** Replaces a subaccess in a given male expression @@ -94,9 +94,9 @@ object RemoveAccesses extends Pass { stmts += wire rs.zipWithIndex foreach { case (x, i) if i < temps.size => - stmts += Connect(info(s),getTemp(i),x.base) + stmts += Connect(get_info(s),getTemp(i),x.base) case (x, i) => - stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) + stmts += Conditionally(get_info(s),x.guard,Connect(get_info(s),getTemp(i),x.base),EmptyStmt) } temp } diff --git a/src/main/scala/firrtl/passes/SplitExpressions.scala b/src/main/scala/firrtl/passes/SplitExpressions.scala index 1c9674e1..3b6021ed 100644 --- a/src/main/scala/firrtl/passes/SplitExpressions.scala +++ b/src/main/scala/firrtl/passes/SplitExpressions.scala @@ -2,7 +2,7 @@ package firrtl package passes import firrtl.Mappers.{ExpMap, StmtMap} -import firrtl.Utils.{tpe, kind, gender, info} +import firrtl.Utils.{kind, gender, get_info} import firrtl.ir._ import scala.collection.mutable @@ -20,18 +20,18 @@ object SplitExpressions extends Pass { def split(e: Expression): Expression = e match { case e: DoPrim => { val name = namespace.newTemp - v += DefNode(info(s), name, e) - WRef(name, tpe(e), kind(e), gender(e)) + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) } case e: Mux => { val name = namespace.newTemp - v += DefNode(info(s), name, e) - WRef(name, tpe(e), kind(e), gender(e)) + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) } case e: ValidIf => { val name = namespace.newTemp - v += DefNode(info(s), name, e) - WRef(name, tpe(e), kind(e), gender(e)) + v += DefNode(get_info(s), name, e) + WRef(name, e.tpe, kind(e), gender(e)) } case e => e } diff --git a/src/main/scala/firrtl/passes/Uniquify.scala b/src/main/scala/firrtl/passes/Uniquify.scala index b1a20fdd..d034719a 100644 --- a/src/main/scala/firrtl/passes/Uniquify.scala +++ b/src/main/scala/firrtl/passes/Uniquify.scala @@ -109,8 +109,9 @@ object Uniquify extends Pass { val newName = findValidPrefix(f.name, Seq(""), namespace) namespace += newName Field(newName, f.flip, f.tpe) - } map { f => - if (f.tpe.isAggregate) { + } map { f => f.tpe match { + case _: GroundType => f + case _ => val tpe = recUniquifyNames(f.tpe, collection.mutable.HashSet()) val elts = enumerateNames(tpe) // Need leading _ for findValidPrefix, it doesn't add _ for checks @@ -123,8 +124,6 @@ object Uniquify extends Pass { } namespace ++= (elts map (e => LowerTypes.loweredName(prefix +: e))) Field(prefix, f.flip, tpe) - } else { - f } } BundleType(newFields) @@ -349,7 +348,9 @@ object Uniquify extends Pass { def uniquifyPorts(m: DefModule): DefModule = { def uniquifyPorts(ports: Seq[Port]): Seq[Port] = { - val portsType = BundleType(ports map (_.toField)) + val portsType = BundleType(ports map { + case Port(_, name, dir, tpe) => Field(name, to_flip(dir), tpe) + }) val uniquePortsType = uniquifyNames(portsType, collection.mutable.HashSet()) val localMap = createNameMapping(portsType, uniquePortsType) portNameMap += (m.name -> localMap) |
