aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-09-22 13:08:32 -0700
committerJack Koenig2016-09-22 13:08:32 -0700
commit24ffde94c89ef67eed4df30ba49ccf48ca46c9a9 (patch)
treee00151cd32f8ce326778662e2a1755e76a21990c /src
parent8b12dcbb76896a19f95dc4da19b3b8c74c1ddda3 (diff)
Fixed width inference for add, sub (#312)
Fixes #308 Fixes #193
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/PrimOps.scala8
-rw-r--r--src/test/scala/firrtlTests/WidthSpec.scala87
2 files changed, 91 insertions, 4 deletions
diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala
index 8d677104..cf2cb74e 100644
--- a/src/main/scala/firrtl/PrimOps.scala
+++ b/src/main/scala/firrtl/PrimOps.scala
@@ -142,15 +142,15 @@ object PrimOps extends LazyLogging {
e copy (tpe = (e.op match {
case Add => (t1, t2) match {
case (_: UIntType, _: UIntType) => UIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
+ case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, MINUS(w2, IntWidth(1))), IntWidth(2)))
+ case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w2, MINUS(w1, IntWidth(1))), IntWidth(2)))
case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
case _ => UnknownType
}
case Sub => (t1, t2) match {
case (_: UIntType, _: UIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
- case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
+ case (_: UIntType, _: SIntType) => SIntType(MAX(PLUS(w2, IntWidth(1)), PLUS(w1, IntWidth(2))))
+ case (_: SIntType, _: UIntType) => SIntType(MAX(PLUS(w1, IntWidth(1)), PLUS(w2, IntWidth(2))))
case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1)))
case _ => UnknownType
}
diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala
new file mode 100644
index 00000000..4c9ad687
--- /dev/null
+++ b/src/test/scala/firrtlTests/WidthSpec.scala
@@ -0,0 +1,87 @@
+/*
+Copyright (c) 2014 - 2016 The Regents of the University of
+California (Regents). All Rights Reserved. Redistribution and use in
+source and binary forms, with or without modification, are permitted
+provided that the following conditions are met:
+ * Redistributions of source code must retain the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer.
+ * Redistributions in binary form must reproduce the above
+ copyright notice, this list of conditions and the following
+ two paragraphs of disclaimer in the documentation and/or other materials
+ provided with the distribution.
+ * Neither the name of the Regents nor the names of its contributors
+ may be used to endorse or promote products derived from this
+ software without specific prior written permission.
+IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT,
+SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS,
+ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF
+REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF
+ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION
+TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR
+MODIFICATIONS.
+*/
+
+package firrtlTests
+
+import java.io._
+import org.scalatest._
+import org.scalatest.prop._
+import firrtl._
+import firrtl.ir.Circuit
+import firrtl.passes._
+import firrtl.Parser.IgnoreInfo
+
+class WidthSpec extends FirrtlFlatSpec {
+ def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo)
+ private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = {
+ val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) {
+ (c: Circuit, p: Pass) => p.run(c)
+ }
+ val lines = c.serialize.split("\n") map normalized
+
+ expected foreach { e =>
+ lines should contain(e)
+ }
+ }
+
+ "Add of UInt<2> and SInt<2>" should "return SInt<4>" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ InferWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input x: UInt<2>
+ | input y: SInt<2>
+ | output z: SInt
+ | z <= add(x, y)""".stripMargin
+ val check = Seq( "output z : SInt<4>")
+ executeTest(input, check, passes)
+ }
+ "SInt<2> - UInt<3>" should "return SInt<5>" in {
+ val passes = Seq(
+ ToWorkingIR,
+ CheckHighForm,
+ ResolveKinds,
+ InferTypes,
+ CheckTypes,
+ InferWidths)
+ val input =
+ """circuit Unit :
+ | module Unit :
+ | input x: UInt<3>
+ | input y: SInt<2>
+ | output z: SInt
+ | z <= sub(y, x)""".stripMargin
+ val check = Seq( "output z : SInt<5>")
+ executeTest(input, check, passes)
+ }
+}