aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorJack Koenig2018-06-11 13:27:56 -0700
committerGitHub2018-06-11 13:27:56 -0700
commit535d8025412a64471d8cc9c315505a8e2cbddbe0 (patch)
tree4c52dd0a665192b55763fec2d3b47db23b561bad /src
parent9bd639acf58ad3a6c13b858d65845a95ddac1610 (diff)
Add utilities for UInt and SInt literals (#815)
Also minor cleanup to literal construction in Visitor
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/firrtl/Utils.scala5
-rw-r--r--src/main/scala/firrtl/Visitor.scala45
-rw-r--r--src/main/scala/firrtl/ir/IR.scala8
-rw-r--r--src/main/scala/firrtl/passes/RemoveAccesses.scala2
-rw-r--r--src/test/scala/firrtlTests/UnitTests.scala2
-rw-r--r--src/test/scala/firrtlTests/WidthSpec.scala24
6 files changed, 53 insertions, 33 deletions
diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala
index 30236231..dfb635c2 100644
--- a/src/main/scala/firrtl/Utils.scala
+++ b/src/main/scala/firrtl/Utils.scala
@@ -237,9 +237,8 @@ object Utils extends LazyLogging {
def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a
def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1
val BoolType = UIntType(IntWidth(1))
- val one = UIntLiteral(BigInt(1), IntWidth(1))
- val zero = UIntLiteral(BigInt(0), IntWidth(1))
- def uint(i: BigInt): UIntLiteral = UIntLiteral(i, IntWidth(1 max i.bitLength))
+ val one = UIntLiteral(1)
+ val zero = UIntLiteral(0)
def create_exps(n: String, t: Type): Seq[Expression] =
create_exps(WRef(n, t, ExpKind, UNKNOWNGENDER))
diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala
index c45d7f56..64249c11 100644
--- a/src/main/scala/firrtl/Visitor.scala
+++ b/src/main/scala/firrtl/Visitor.scala
@@ -284,9 +284,6 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] {
}
}
- // TODO
- // - Add mux
- // - Add validif
private def visitExp[FirrtlNode](ctx: FIRRTLParser.ExpContext): Expression = {
val ctx_exp = ctx.exp.asScala
if (ctx.getChildCount == 1)
@@ -294,32 +291,24 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] {
else
ctx.getChild(0).getText match {
case "UInt" =>
- // This could be better
- val (width, value) =
- if (ctx.getChildCount > 4)
- (IntWidth(string2BigInt(ctx.intLit(0).getText)), string2BigInt(ctx.intLit(1).getText))
- else {
- val bigint = string2BigInt(ctx.intLit(0).getText)
- (IntWidth(BigInt(scala.math.max(bigint.bitLength, 1))), bigint)
- }
- UIntLiteral(value, width)
+ if (ctx.getChildCount > 4) {
+ val width = IntWidth(string2BigInt(ctx.intLit(0).getText))
+ val value = string2BigInt(ctx.intLit(1).getText)
+ UIntLiteral(value, width)
+ } else {
+ val value = string2BigInt(ctx.intLit(0).getText)
+ UIntLiteral(value)
+ }
case "SInt" =>
- val (width, value) =
- if (ctx.getChildCount > 4) {
- val width = string2BigInt(ctx.intLit(0).getText)
- val value = string2BigInt(ctx.intLit(1).getText)
- (IntWidth(width), value)
- } else {
- val str = ctx.intLit(0).getText
- val value = string2BigInt(str)
- // To calculate bitwidth of negative number,
- // 1) negate number and subtract one to get the maximum positive value.
- // 2) get bitwidth of max positive number
- // 3) add one to account for the signed representation
- val width = if (value < 0) (value.abs - BigInt(1)).bitLength + 1 else value.bitLength + 1
- (IntWidth(BigInt(width)), value)
- }
- SIntLiteral(value, width)
+ if (ctx.getChildCount > 4) {
+ val width = string2BigInt(ctx.intLit(0).getText)
+ val value = string2BigInt(ctx.intLit(1).getText)
+ SIntLiteral(value, IntWidth(width))
+ } else {
+ val str = ctx.intLit(0).getText
+ val value = string2BigInt(str)
+ SIntLiteral(value)
+ }
case "validif(" => ValidIf(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), UnknownType)
case "mux(" => Mux(visitExp(ctx_exp(0)), visitExp(ctx_exp(1)), visitExp(ctx_exp(2)), UnknownType)
case _ =>
diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala
index 3887f17d..fc741b28 100644
--- a/src/main/scala/firrtl/ir/IR.scala
+++ b/src/main/scala/firrtl/ir/IR.scala
@@ -161,6 +161,10 @@ case class UIntLiteral(value: BigInt, width: Width) extends Literal {
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = UIntLiteral(value, f(width))
}
+object UIntLiteral {
+ def minWidth(value: BigInt): Width = IntWidth(math.max(value.bitLength, 1))
+ def apply(value: BigInt): UIntLiteral = new UIntLiteral(value, minWidth(value))
+}
case class SIntLiteral(value: BigInt, width: Width) extends Literal {
def tpe = SIntType(width)
def serialize = s"""SInt${width.serialize}("h""" + value.toString(16)+ """")"""
@@ -168,6 +172,10 @@ case class SIntLiteral(value: BigInt, width: Width) extends Literal {
def mapType(f: Type => Type): Expression = this
def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width))
}
+object SIntLiteral {
+ def minWidth(value: BigInt): Width = IntWidth(value.bitLength + 1)
+ def apply(value: BigInt): SIntLiteral = new SIntLiteral(value, minWidth(value))
+}
case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Literal {
def tpe = FixedType(width, point)
def serialize = {
diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala
index 30aae284..961f5aba 100644
--- a/src/main/scala/firrtl/passes/RemoveAccesses.scala
+++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala
@@ -54,7 +54,7 @@ object RemoveAccesses extends Pass {
ls.zipWithIndex map {case (l, i) =>
val c = (i / stride) % wrap
val basex = l.base
- val guardx = AND(l.guard,EQV(uint(c),e.index))
+ val guardx = AND(l.guard,EQV(UIntLiteral(c),e.index))
Location(basex,guardx)
}
}
diff --git a/src/test/scala/firrtlTests/UnitTests.scala b/src/test/scala/firrtlTests/UnitTests.scala
index a38a8def..62ed561e 100644
--- a/src/test/scala/firrtlTests/UnitTests.scala
+++ b/src/test/scala/firrtlTests/UnitTests.scala
@@ -405,7 +405,7 @@ class UnitTests extends FirrtlFlatSpec {
val result = execute(input, passes)
- def u(value: Int) = UIntLiteral(BigInt(value), IntWidth(scala.math.max(BigInt(value).bitLength, 1)))
+ def u(value: Int) = UIntLiteral(BigInt(value))
val ut16 = UIntType(IntWidth(BigInt(16)))
val ut2 = UIntType(IntWidth(BigInt(2)))
diff --git a/src/test/scala/firrtlTests/WidthSpec.scala b/src/test/scala/firrtlTests/WidthSpec.scala
index d1d02ee2..9ca965f6 100644
--- a/src/test/scala/firrtlTests/WidthSpec.scala
+++ b/src/test/scala/firrtlTests/WidthSpec.scala
@@ -22,6 +22,30 @@ class WidthSpec extends FirrtlFlatSpec {
}
}
+ case class LiteralWidthCheck(lit: BigInt, uIntWidth: Option[BigInt], sIntWidth: BigInt)
+ val litChecks = Seq(
+ LiteralWidthCheck(-4, None, 3),
+ LiteralWidthCheck(-3, None, 3),
+ LiteralWidthCheck(-2, None, 2),
+ LiteralWidthCheck(-1, None, 1),
+ LiteralWidthCheck(0, Some(1), 1), // TODO https://github.com/freechipsproject/firrtl/pull/530
+ LiteralWidthCheck(1, Some(1), 2),
+ LiteralWidthCheck(2, Some(2), 3),
+ LiteralWidthCheck(3, Some(2), 3),
+ LiteralWidthCheck(4, Some(3), 4)
+ )
+ for (LiteralWidthCheck(lit, uwo, sw) <- litChecks) {
+ import firrtl.ir.{UIntLiteral, SIntLiteral, IntWidth}
+ s"$lit" should s"have signed width $sw" in {
+ SIntLiteral(lit).width should equal (IntWidth(sw))
+ }
+ uwo.foreach { uw =>
+ it should s"have unsigned width $uw" in {
+ UIntLiteral(lit).width should equal (IntWidth(uw))
+ }
+ }
+ }
+
"Dshl by 20 bits" should "result in an error" in {
val passes = Seq(
ToWorkingIR,