diff options
| author | Adam Izraelevitz | 2018-02-22 17:25:55 -0800 |
|---|---|---|
| committer | GitHub | 2018-02-22 17:25:55 -0800 |
| commit | 46b78943a726e4c9bf85ffb25a2ccf926b10dda7 (patch) | |
| tree | 39f9363400fdd39e2e55f3dc8c5221461941edec /src | |
| parent | 65bbf155003a86cd836f7ff4a2def6af91794780 (diff) | |
Add tests for #702. Adds Utility functions. Allows clock muxing in FIRRTL, but not Emitter. (#717)
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Emitter.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 123 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 8 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveCHIRRTL.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala | 6 | ||||
| -rw-r--r-- | src/main/scala/firrtl/transforms/ConstantPropagation.scala | 49 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/ChirrtlMemSpec.scala | 146 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/UnitTests.scala | 4 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/VerilogEmitterTests.scala | 19 |
9 files changed, 248 insertions, 117 deletions
diff --git a/src/main/scala/firrtl/Emitter.scala b/src/main/scala/firrtl/Emitter.scala index 5753fc17..cf356dcb 100644 --- a/src/main/scala/firrtl/Emitter.scala +++ b/src/main/scala/firrtl/Emitter.scala @@ -218,7 +218,10 @@ class VerilogEmitter extends SeqTransform with Emitter { } x match { case (e: DoPrim) => emit(op_stream(e), top + 1) - case (e: Mux) => emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) + case (e: Mux) => { + if(e.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") + emit(Seq(e.cond," ? ",cast(e.tval)," : ",cast(e.fval)),top + 1) + } case (e: ValidIf) => emit(Seq(cast(e.value)),top + 1) case (e: WRef) => w write e.serialize case (e: WSubField) => w write LowerTypes.loweredName(e) @@ -319,7 +322,7 @@ class VerilogEmitter extends SeqTransform with Emitter { } case AsUInt => Seq("$unsigned(", a0, ")") case AsSInt => Seq("$signed(", a0, ")") - case AsClock => Seq("$unsigned(", a0, ")") + case AsClock => Seq(a0) case Dshlw => Seq(cast(a0), " << ", a1) case Dshl => Seq(cast(a0), " << ", a1) case Dshr => doprim.tpe match { @@ -433,6 +436,7 @@ class VerilogEmitter extends SeqTransform with Emitter { } expr match { case m: Mux if canFlatten(m) => + if(m.tpe == ClockType) throw EmitterException("Cannot emit clock muxes directly") val ifStatement = Seq(tabs, "if (", m.cond, ") begin") val trueCase = addUpdate(m.tval, tabs + tab) val elseStatement = Seq(tabs, "end else begin") diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index e5a7e6be..0c684c5d 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -226,6 +226,10 @@ object Utils extends LazyLogging { case SIntLiteral(value, _) => "_" + value } } + /** Maps node name to value */ + type NodeMap = mutable.HashMap[String, Expression] + + def isTemp(str: String): Boolean = str.head == '_' /** Indent the results of [[ir.FirrtlNode.serialize]] */ def indent(str: String) = str replaceAllLiterally ("\n", "\n ") @@ -306,88 +310,45 @@ object Utils extends LazyLogging { case _ => false } -//============== TYPES ================ -//<<<<<<< HEAD -// def mux_type (e1:Expression,e2:Expression) : Type = mux_type(tpe(e1),tpe(e2)) -// def mux_type (t1:Type,t2:Type) : Type = { -// if (wt(t1) == wt(t2)) { -// (t1,t2) match { -// case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth) -// case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth) -// case (t1:FixedType,t2:FixedType) => FixedType(UnknownWidth, UnknownWidth) -// case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size) -// case (t1:BundleType,t2:BundleType) => -// BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => { -// Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe)) -// })) -// } -// } else UnknownType -// } -// def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2)) -// def PLUS (w1:Width,w2:Width) : Width = (w1, w2) match { -// case (IntWidth(i), IntWidth(j)) => IntWidth(i + j) -// case _ => PlusWidth(w1,w2) -// } -// def MAX (w1:Width,w2:Width) : Width = (w1, w2) match { -// case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j)) -// case _ => MaxWidth(Seq(w1,w2)) -// } -// def MINUS (w1:Width,w2:Width) : Width = (w1, w2) match { -// case (IntWidth(i), IntWidth(j)) => IntWidth(i - j) -// case _ => MinusWidth(w1,w2) -// } -// def POW (w1:Width) : Width = w1 match { -// case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i)) -// case _ => ExpWidth(w1) -// } -// def MIN (w1:Width,w2:Width) : Width = (w1, w2) match { -// case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j)) -// case _ => MinWidth(Seq(w1,w2)) -// } -// def mux_type_and_widths (t1:Type,t2:Type) : Type = { -// def wmax (w1:Width,w2:Width) : Width = { -// (w1,w2) match { -// case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width.max(w2.width)) -// case (w1,w2) => MaxWidth(Seq(w1,w2)) -// } -// } -// val wt1 = new WrappedType(t1) -// val wt2 = new WrappedType(t2) -// if (wt1 == wt2) { -// (t1,t2) match { -// case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width)) -// case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width)) -// case (FixedType(w1, p1), FixedType(w2, p2)) => -// FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2)) -// case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size) -// case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))}) -// } -// } else UnknownType -// } -// def module_type (m:DefModule) : Type = { -// BundleType(m.ports.map(p => p.toField)) -// } -// def sub_type (v:Type) : Type = { -// v match { -// case v:VectorType => v.tpe -// case v => UnknownType -// } -// } -// def field_type (v:Type,s:String) : Type = { -// v match { -// case v:BundleType => { -// val ft = v.fields.find(p => p.name == s) -// ft match { -// case ft:Some[Field] => ft.get.tpe -// case ft => UnknownType -// } -// } -// case v => UnknownType -// } -// } -//======= + /** Returns children Expressions of e */ + def getKids(e: Expression): Seq[Expression] = { + val kids = mutable.ArrayBuffer[Expression]() + def addKids(e: Expression): Expression = { + kids += e + e + } + e map addKids + kids + } + + /** Walks two expression trees and returns a sequence of tuples of where they differ */ + def diff(e1: Expression, e2: Expression): Seq[(Expression, Expression)] = { + if(weq(e1, e2)) Nil + else { + val (e1Kids, e2Kids) = (getKids(e1), getKids(e2)) + + if(e1Kids == Nil || e2Kids == Nil || e1Kids.size != e2Kids.size) Seq((e1, e2)) + else { + e1Kids.zip(e2Kids).flatMap { case (e1k, e2k) => diff(e1k, e2k) } + } + } + } + + /** Returns an inlined expression (replacing node references with values), + * stopping on a stopping condition or until the reference is not a node + */ + def inline(nodeMap: NodeMap, stop: String => Boolean = {x: String => false})(e: Expression): Expression = { + def onExp(e: Expression): Expression = e map onExp match { + case Reference(name, _) if nodeMap.contains(name) && !stop(name) => onExp(nodeMap(name)) + case WRef(name, _, _, _) if nodeMap.contains(name) && !stop(name) => onExp(nodeMap(name)) + case other => other + } + onExp(e) + } + def mux_type(e1: Expression, e2: Expression): Type = mux_type(e1.tpe, e2.tpe) def mux_type(t1: Type, t2: Type): Type = (t1, t2) match { + case (ClockType, ClockType) => ClockType case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth) case (t1: FixedType, t2: FixedType) => FixedType(UnknownWidth, UnknownWidth) @@ -405,6 +366,7 @@ object Utils extends LazyLogging { case (w1x, w2x) => MaxWidth(Seq(w1x, w2x)) } (t1, t2) match { + case (ClockType, ClockType) => ClockType case (t1x: UIntType, t2x: UIntType) => UIntType(wmax(t1x.width, t2x.width)) case (t1x: SIntType, t2x: SIntType) => SIntType(wmax(t1x.width, t2x.width)) case (FixedType(w1, p1), FixedType(w2, p2)) => @@ -432,7 +394,6 @@ object Utils extends LazyLogging { } case vx => UnknownType } -//>>>>>>> e54fb610c6bf0a7fe5c9c0f0e0b3acbb3728cfd0 // ================================= def error(str: String, cause: Throwable = null) = throw new FIRRTLException(str, cause) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 6934fca2..5b198064 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -259,8 +259,8 @@ object CheckTypes extends Pass { s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") class NodePassiveType(info: Info, mname: String) extends PassException( s"$info: [module $mname] Node must be a passive type.") - class MuxSameType(info: Info, mname: String) extends PassException( - s"$info: [module $mname] Must mux between equivalent types.") + class MuxSameType(info: Info, mname: String, t1: String, t2: String) extends PassException( + s"$info: [module $mname] Must mux between equivalent types: $t1 != $t2.") class MuxPassiveTypes(info: Info, mname: String) extends PassException( s"$info: [module $mname] Must mux between passive types.") class MuxCondUInt(info: Info, mname: String) extends PassException( @@ -361,15 +361,13 @@ object CheckTypes extends Pass { case (e: DoPrim) => check_types_primop(info, mname, e) case (e: Mux) => if (wt(e.tval.tpe) != wt(e.fval.tpe)) - errors.append(new MuxSameType(info, mname)) + errors.append(new MuxSameType(info, mname, e.tval.tpe.serialize, e.fval.tpe.serialize)) if (!passive(e.tpe)) errors.append(new MuxPassiveTypes(info, mname)) e.cond.tpe match { case _: UIntType => case _ => errors.append(new MuxCondUInt(info, mname)) } - if ((e.tval.tpe == ClockType) || (e.fval.tpe == ClockType)) - errors.append(new MuxClock(info, mname)) case (e: ValidIf) => if (!passive(e.tpe)) errors.append(new ValidIfPassiveTypes(info, mname)) diff --git a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala index c841dc32..6b3508a6 100644 --- a/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala +++ b/src/main/scala/firrtl/passes/RemoveCHIRRTL.scala @@ -4,11 +4,11 @@ package firrtl.passes // Datastructures import scala.collection.mutable.ArrayBuffer - import firrtl._ import firrtl.ir._ import firrtl.Utils._ import firrtl.Mappers._ +import firrtl.PrimOps.AsClock case class MPort(name: String, clk: Expression) case class MPorts(readers: ArrayBuffer[MPort], writers: ArrayBuffer[MPort], readwriters: ArrayBuffer[MPort]) diff --git a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala index 79ecd9cd..0424b1dd 100644 --- a/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala +++ b/src/main/scala/firrtl/passes/memlib/ResolveMaskGranularity.scala @@ -77,7 +77,7 @@ object AnalysisUtils { /** Checks whether the two memories are equivalent in all respects except name */ - def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]) = + def eqMems(a: DefAnnotatedMemory, b: DefAnnotatedMemory, noDeDupeMems: Seq[String]): Boolean = a == b.copy(info = a.info, name = a.name, memRef = a.memRef) && !(noDeDupeMems.contains(a.name) || noDeDupeMems.contains(b.name)) } @@ -120,6 +120,6 @@ object ResolveMaskGranularity extends Pass { case sx => sx map updateStmts(connects) } - def annotateModMems(m: DefModule) = m map updateStmts(getConnects(m)) - def run(c: Circuit) = c copy (modules = c.modules map annotateModMems) + def annotateModMems(m: DefModule): DefModule = m map updateStmts(getConnects(m)) + def run(c: Circuit): Circuit = c copy (modules = c.modules map annotateModMems) } diff --git a/src/main/scala/firrtl/transforms/ConstantPropagation.scala b/src/main/scala/firrtl/transforms/ConstantPropagation.scala index 086f1cee..04ad2cb2 100644 --- a/src/main/scala/firrtl/transforms/ConstantPropagation.scala +++ b/src/main/scala/firrtl/transforms/ConstantPropagation.scala @@ -251,6 +251,24 @@ class ConstantPropagation extends Transform { // Is "a" a "better name" than "b"? private def betterName(a: String, b: String): Boolean = (a.head != '_') && (b.head == '_') + def optimize(e: Expression): Expression = constPropExpression(new NodeMap(), Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + def optimize(e: Expression, nodeMap: NodeMap): Expression = constPropExpression(nodeMap, Map.empty[String, String], Map.empty[String, Map[String, Literal]])(e) + private def constPropExpression(nodeMap: NodeMap, instMap: Map[String, String], constSubOutputs: Map[String, Map[String, Literal]])(e: Expression): Expression = { + val old = e map constPropExpression(nodeMap, instMap, constSubOutputs) + val propagated = old match { + case p: DoPrim => constPropPrim(p) + case m: Mux => constPropMux(m) + case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => + constPropNodeRef(ref, nodeMap(rname)) + case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) => + val module = instMap(inst) + // Check constSubOutputs to see if the submodule is driving a constant + constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) + case x => x + } + propagated + } + /** Constant propagate a Module * * Two pass process @@ -279,7 +297,7 @@ class ConstantPropagation extends Transform { ): (Module, Map[String, Literal], Map[String, Map[String, Seq[Literal]]]) = { var nPropagated = 0L - val nodeMap = mutable.HashMap.empty[String, Expression] + val nodeMap = new NodeMap() // For cases where we are trying to constprop a bad name over a good one, we swap their names // during the second pass val swapMap = mutable.HashMap.empty[String, String] @@ -325,21 +343,6 @@ class ConstantPropagation extends Transform { case other => other map backPropStmt } - def constPropExpression(e: Expression): Expression = { - val old = e map constPropExpression - val propagated = old match { - case p: DoPrim => constPropPrim(p) - case m: Mux => constPropMux(m) - case ref @ WRef(rname, _,_, MALE) if nodeMap.contains(rname) => - constPropNodeRef(ref, nodeMap(rname)) - case ref @ WSubField(WRef(inst, _, InstanceKind, _), pname, _, MALE) => - val module = instMap(inst) - // Check constSubOutputs to see if the submodule is driving a constant - constSubOutputs.get(module).flatMap(_.get(pname)).getOrElse(ref) - case x => x - } - propagated - } // When propagating a reference, check if we want to keep the name that would be deleted def propagateRef(lname: String, value: Expression): Unit = { @@ -354,31 +357,31 @@ class ConstantPropagation extends Transform { } def constPropStmt(s: Statement): Statement = { - val stmtx = s map constPropStmt map constPropExpression + val stmtx = s map constPropStmt map constPropExpression(nodeMap, instMap, constSubOutputs) stmtx match { case x: DefNode if !dontTouches.contains(x.name) => propagateRef(x.name, x.value) case Connect(_, WRef(wname, wtpe, WireKind, _), expr: Literal) if !dontTouches.contains(wname) => - val exprx = constPropExpression(pad(expr, wtpe)) + val exprx = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(expr, wtpe)) propagateRef(wname, exprx) // Record constants driving outputs case Connect(_, WRef(pname, ptpe, PortKind, _), lit: Literal) if !dontTouches.contains(pname) => - val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] constOutputs(pname) = paddedLit // Const prop registers that are fed only a constant or a mux between and constant and the // register itself // This requires that reset has been made explicit case Connect(_, lref @ WRef(lname, ltpe, RegKind, _), expr) if !dontTouches.contains(lname) => expr match { case lit: Literal => - nodeMap(lname) = constPropExpression(pad(lit, ltpe)) + nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ltpe)) case Mux(_, tval: WRef, fval: Literal, _) if weq(lref, tval) => - nodeMap(lname) = constPropExpression(pad(fval, ltpe)) + nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(fval, ltpe)) case Mux(_, tval: Literal, fval: WRef, _) if weq(lref, fval) => - nodeMap(lname) = constPropExpression(pad(tval, ltpe)) + nodeMap(lname) = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(tval, ltpe)) case _ => } // Mark instance inputs connected to a constant case Connect(_, lref @ WSubField(WRef(inst, _, InstanceKind, _), port, ptpe, _), lit: Literal) => - val paddedLit = constPropExpression(pad(lit, ptpe)).asInstanceOf[Literal] + val paddedLit = constPropExpression(nodeMap, instMap, constSubOutputs)(pad(lit, ptpe)).asInstanceOf[Literal] val module = instMap(inst) val portsMap = constSubInputs.getOrElseUpdate(module, mutable.HashMap.empty) portsMap(port) = paddedLit +: portsMap.getOrElse(port, List.empty) diff --git a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala index 6fac5047..d039cc96 100644 --- a/src/test/scala/firrtlTests/ChirrtlMemSpec.scala +++ b/src/test/scala/firrtlTests/ChirrtlMemSpec.scala @@ -8,6 +8,8 @@ import firrtl.passes._ import firrtl.transforms._ import firrtl.Mappers._ import annotations._ +import FirrtlCheckers._ +import firrtl.PrimOps.AsClock class ChirrtlMemSpec extends LowTransformSpec { object MemEnableCheckPass extends Pass { @@ -107,4 +109,148 @@ circuit foo : // Check correctness of firrtl parse(res.getEmittedCircuit.value) } + + ignore should "Memories should not have validif on port clocks when declared in a when" in { + val input = + """;buildInfoPackage: chisel3, version: 3.0-SNAPSHOT, scalaVersion: 2.11.11, sbtVersion: 0.13.16, builtAtString: 2017-10-06 20:55:20.367, builtAtMillis: 1507323320367 + |circuit Stack : + | module Stack : + | input clock : Clock + | input reset : UInt<1> + | output io : {flip push : UInt<1>, flip pop : UInt<1>, flip en : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>} + | + | clock is invalid + | reset is invalid + | io is invalid + | cmem stack_mem : UInt<32>[4] @[Stack.scala 15:22] + | reg sp : UInt<3>, clock with : (reset => (reset, UInt<3>("h00"))) @[Stack.scala 16:26] + | reg out : UInt<32>, clock with : (reset => (reset, UInt<32>("h00"))) @[Stack.scala 17:26] + | when io.en : @[Stack.scala 19:16] + | node _T_14 = lt(sp, UInt<3>("h04")) @[Stack.scala 20:25] + | node _T_15 = and(io.push, _T_14) @[Stack.scala 20:18] + | when _T_15 : @[Stack.scala 20:42] + | node _T_16 = bits(sp, 1, 0) + | infer mport _T_17 = stack_mem[_T_16], clock + | _T_17 <= io.dataIn @[Stack.scala 21:21] + | node _T_19 = add(sp, UInt<1>("h01")) @[Stack.scala 22:16] + | node _T_20 = tail(_T_19, 1) @[Stack.scala 22:16] + | sp <= _T_20 @[Stack.scala 22:10] + | skip @[Stack.scala 20:42] + | else : @[Stack.scala 23:39] + | node _T_22 = gt(sp, UInt<1>("h00")) @[Stack.scala 23:31] + | node _T_23 = and(io.pop, _T_22) @[Stack.scala 23:24] + | when _T_23 : @[Stack.scala 23:39] + | node _T_25 = sub(sp, UInt<1>("h01")) @[Stack.scala 24:16] + | node _T_26 = asUInt(_T_25) @[Stack.scala 24:16] + | node _T_27 = tail(_T_26, 1) @[Stack.scala 24:16] + | sp <= _T_27 @[Stack.scala 24:10] + | skip @[Stack.scala 23:39] + | node _T_29 = gt(sp, UInt<1>("h00")) @[Stack.scala 26:14] + | when _T_29 : @[Stack.scala 26:21] + | node _T_31 = sub(sp, UInt<1>("h01")) @[Stack.scala 27:27] + | node _T_32 = asUInt(_T_31) @[Stack.scala 27:27] + | node _T_33 = tail(_T_32, 1) @[Stack.scala 27:27] + | node _T_34 = bits(_T_33, 1, 0) + | infer mport _T_35 = stack_mem[_T_34], clock + | out <= _T_35 @[Stack.scala 27:11] + | skip @[Stack.scala 26:21] + | skip @[Stack.scala 19:16] + | io.dataOut <= out @[Stack.scala 31:14] + """.stripMargin + val annotationMap = AnnotationMap(Nil) + val res = (new LowFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), Nil).circuit + assert(res search { + case Connect(_, WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_35",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + case Connect(_, WSubField(WSubField(WRef("stack_mem", _, _, _), "_T_17",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + }) + } + + ignore should "Mem non-local clock port assignment should be ok assign in only one side of when" in { + val input = + """circuit foo : + | module foo : + | input clock : Clock + | input en : UInt<1> + | input addr: UInt<2> + | output out: UInt<32> + | out is invalid + | cmem mem : UInt<32>[4] + | when en: + | read mport bar = mem[addr], clock + | out <= bar + |""".stripMargin + val annotationMap = AnnotationMap(Nil) + val res = (new LowFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), Nil).circuit + assert(res search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + }) + } + + ignore should "Mem local clock port assignment should be ok" in { + val input = + """circuit foo : + | module foo : + | input clock : Clock + | input en : UInt<1> + | input addr: UInt<2> + | output out: UInt<32> + | out is invalid + | cmem mem : UInt<32>[4] + | when en: + | node local = clock + | read mport bar = mem[addr], local + | out <= bar + |""".stripMargin + val annotationMap = AnnotationMap(Nil) + val res = new LowFirrtlCompiler().compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), Nil).circuit + assert(res search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), WRef("clock", _, _, _)) => true + }) + } + + ignore should "Mem local nested clock port assignment should be ok" in { + val input = + """circuit foo : + | module foo : + | input clock : Clock + | input en : UInt<1> + | input addr: UInt<2> + | output out: UInt<32> + | out is invalid + | cmem mem : UInt<32>[4] + | when en: + | node local = clock + | read mport bar = mem[addr], asClock(local) + | out <= bar + |""".stripMargin + val annotationMap = AnnotationMap(Nil) + + val res = new LowFirrtlCompiler().compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), Nil).circuit + assert(res search { + case Connect(_, WSubField(WSubField(WRef("mem", _, _, _), "bar",_, _), "clk", _, _), DoPrim(AsClock, Seq(WRef("clock", _, _, _)), Nil, _)) => true + }) + } + + + ignore should "Mem non-local nested clock port assignment should be ok" in { + val input = + """circuit foo : + | module foo : + | input clock : Clock + | input en : UInt<1> + | input addr: UInt<2> + | output out: UInt<32> + | out is invalid + | cmem mem : UInt<32>[4] + | when en: + | read mport bar = mem[addr], asClock(clock) + | out <= bar + |""".stripMargin + val annotationMap = AnnotationMap(Nil) + + val res = (new HighFirrtlCompiler).compile(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)), Nil).circuit + assert(res search { + case Connect(_, SubField(SubField(Reference("mem", _), "bar", _), "clk", _), DoPrim(AsClock, Seq(Reference("clock", _)), _, _)) => true + }) + } } diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala index 44799829..018a35f6 100644 --- a/src/test/scala/firrtlTests/UnitTests.scala +++ b/src/test/scala/firrtlTests/UnitTests.scala @@ -313,7 +313,7 @@ class UnitTests extends FirrtlFlatSpec { } } - "Conditional conection of clocks" should "throw an exception" in { + "Conditional connection of clocks" should "throw an exception" in { val input = """circuit Unit : | module Unit : @@ -325,7 +325,7 @@ class UnitTests extends FirrtlFlatSpec { | when sel : | clock3 <= clock2 |""".stripMargin - intercept[PassExceptions] { // Both MuxClock and InvalidConnect are thrown + intercept[EmitterException] { compileToVerilog(input) } } diff --git a/src/test/scala/firrtlTests/VerilogEmitterTests.scala b/src/test/scala/firrtlTests/VerilogEmitterTests.scala index 6928718a..40b66917 100644 --- a/src/test/scala/firrtlTests/VerilogEmitterTests.scala +++ b/src/test/scala/firrtlTests/VerilogEmitterTests.scala @@ -130,4 +130,23 @@ class VerilogEmitterSpec extends FirrtlFlatSpec { """.stripMargin compiler.compile(CircuitState(parse(input), ChirrtlForm), new java.io.StringWriter) } + "AsClock" should "emit correctly" in { + val compiler = new VerilogCompiler + val input = + """circuit Test : + | module Test : + | input in : UInt<1> + | output out : Clock + | out <= asClock(in) + |""".stripMargin + val check = + """module Test( + | input in, + | output out + |); + | assign out = in; + |endmodule + |""".stripMargin.split("\n") map normalized + executeTest(input, check, compiler) + } } |
