aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorDonggyu2017-03-09 17:29:45 -0800
committerAdam Izraelevitz2017-03-09 17:29:45 -0800
commite571ef88f7f69b2374fa9ba86e219523645213c6 (patch)
treec47192c502e5d7a39143f8244a9d6423ef16abd5 /src
parent664d5b33094b7158bb6f8a583a89d83ac69be83e (diff)
make sure infer-rw works for exclusive when statements (#481)
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/passes/memlib/InferReadWrite.scala3
-rw-r--r--src/test/scala/firrtlTests/InferReadWriteSpec.scala31
2 files changed, 34 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
index b941503f..9bd6a4ab 100644
--- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
+++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala
@@ -39,8 +39,11 @@ object InferReadWritePass extends Pass {
def getProductTerms(connects: Connects)(e: Expression): Seq[Expression] = e match {
// No ConstProp yet...
+ // TODO: do const prop before
case Mux(cond, tval, fval, _) if weq(tval, one) && weq(fval, zero) =>
getProductTerms(connects)(cond)
+ case Mux(cond, tval, fval, _) if weq(fval, zero) =>
+ getProductTerms(connects)(cond) ++ getProductTerms(connects)(tval)
// Visit each term of AND operation
case DoPrim(op, args, consts, tpe) if op == And =>
e +: (args flatMap getProductTerms(connects))
diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
index 91dc911c..73fdbe91 100644
--- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala
+++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala
@@ -82,6 +82,37 @@ circuit sram6t :
parse(res.getEmittedCircuit.value)
}
+ "Infer ReadWrite Ports" should "infer readwrite ports from exclusive when statements" in {
+ val input = """
+circuit sram6t :
+ module sram6t :
+ input clock : Clock
+ input reset : UInt<1>
+ output io : { flip addr : UInt<11>, flip ren : UInt<1>, flip wen : UInt<1>, flip dataIn : UInt<32>, dataOut : UInt<32>}
+
+ io is invalid
+ smem mem : UInt<32> [2048]
+ when io.wen :
+ write mport _T_14 = mem[io.addr], clock
+ _T_14 <= io.dataIn
+ node _T_16 = eq(io.wen, UInt<1>("h0"))
+ when _T_16 :
+ wire _T_18 : UInt
+ _T_18 is invalid
+ when io.ren :
+ _T_18 <= io.addr
+ node _T_20 = or(_T_18, UInt<11>("h0"))
+ node _T_21 = bits(_T_20, 10, 0)
+ read mport _T_22 = mem[_T_21], clock
+ io.dataOut <= _T_22
+""".stripMargin
+
+ val annotationMap = AnnotationMap(Seq(memlib.InferReadWriteAnnotation("sram6t")))
+ val res = compileAndEmit(CircuitState(parse(input), ChirrtlForm, Some(annotationMap)))
+ // Check correctness of firrtl
+ parse(res.getEmittedCircuit.value)
+ }
+
"Infer ReadWrite Ports" should "not infer readwrite ports for the difference clocks" in {
val input = """
circuit sram6t :