aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/passes
diff options
context:
space:
mode:
authorAdam Izraelevitz2016-10-17 15:10:12 -0700
committerGitHub2016-10-17 15:10:12 -0700
commit7d08b9a1486fef0459481f6e542464a29fbe1db5 (patch)
treee8b2289ac5cbecbd59d58cab8bd503287818ec5d /src/main/scala/firrtl/passes
parent2848d87721df110d0425114283cb5fa7e6c2ee03 (diff)
Add fixed point type (#322)
* WIP: Adding FixedType to Firrtl proper Got simple example running through width inference Checks should be ok Need to look into FixedLiteral more * Added simple test for fixed types * Added asFixedPoint to primops * Added tail case for FixedType * Added ConvertFixedToSInt.scala Added pass to MiddleToLowerFirrtl transform * Replace AsFixedType with AsSInt in fixed removal * Bugfix: constant from asFixed not deleted * Added unit test for bulk connect * Fixed partial connect bug #241 * Fixed missing case for FixedPoint in legalizeConnect * Add FixedMathSpec that demonstrates some problems with FixedPointMath * Fixed test and ConvertToSInt to pass. Negative binary points not easily supported, needs much more time to implement. * Refactored checking neg widths Make checking for negative binary points easier * Added tests for inferring many FixedType ops shl, shr, cat, bits, head, tail, setbp, shiftbp * Handle bpshl, bpshr, bpset in ConvertFixedToSInt Changed name from shiftbp -> bpshl, bpshr Change name from setbp -> bpset Added more tests * Added set binary point test that fails * Added simple test for zero binary point * gitignore fixes for antlr intermediate dir and intellij dir * removed unused imports retool the fixed point with zero binary point test * simplified example of inability to set binary point to zero * Temporary fix for zero-width binary point This fix allows for all widths to be zero, but since this is a feature I am working on next, I'm not going to bother with a more stringent check. * change version for dsp tools * Removed extra temporary file * Fixed merge bug * Fixed another merge bug * Removed commented out/unrelated files * Removed snake case
Diffstat (limited to 'src/main/scala/firrtl/passes')
-rw-r--r--src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swobin0 -> 20480 bytes
-rw-r--r--src/main/scala/firrtl/passes/CheckChirrtl.scala13
-rw-r--r--src/main/scala/firrtl/passes/Checks.scala78
-rw-r--r--src/main/scala/firrtl/passes/ConvertFixedToSInt.scala119
-rw-r--r--src/main/scala/firrtl/passes/InferWidths.scala48
-rw-r--r--src/main/scala/firrtl/passes/Passes.scala2
6 files changed, 213 insertions, 47 deletions
diff --git a/src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo b/src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo
new file mode 100644
index 00000000..abd7c349
--- /dev/null
+++ b/src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo
Binary files differ
diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala
index 17e228eb..504702b5 100644
--- a/src/main/scala/firrtl/passes/CheckChirrtl.scala
+++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala
@@ -68,22 +68,21 @@ object CheckChirrtl extends Pass {
errors append new InvalidLOCException(info, mname)
case _ => // Do Nothing
}
-
def checkChirrtlW(info: Info, mname: String)(w: Width): Width = w match {
- case w: IntWidth if w.width <= 0 =>
- errors append new NegWidthException(info, mname)
+ case w: IntWidth if (w.width < BigInt(0)) =>
+ errors.append(new NegWidthException(info, mname))
w
case _ => w
}
- def checkChirrtlT(info: Info, mname: String)(t: Type): Type = {
+ def checkChirrtlT(info: Info, mname: String)(t: Type): Type =
t map checkChirrtlT(info, mname) match {
case t: VectorType if t.size < 0 =>
errors append new NegVecSizeException(info, mname)
- case _ => // Do nothing
+ t map checkChirrtlW(info, mname)
+ //case FixedType(width, point) => FixedType(checkChirrtlW(width), point)
+ case _ => t map checkChirrtlW(info, mname)
}
- t map checkChirrtlW(info, mname) map checkChirrtlT(info, mname)
- }
def validSubexp(info: Info, mname: String)(e: Expression): Expression = {
e match {
diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala
index e6cd1d04..8eebda2f 100644
--- a/src/main/scala/firrtl/passes/Checks.scala
+++ b/src/main/scala/firrtl/passes/Checks.scala
@@ -100,7 +100,7 @@ object CheckHighForm extends Pass {
correctNum(Option(2), 0)
case AsUInt | AsSInt | AsClock | Cvt | Neq | Not =>
correctNum(Option(1), 0)
- case Pad | Shl | Shr | Head | Tail =>
+ case AsFixedPoint | Pad | Shl | Shr | Head | Tail | BPShl | BPShr | BPSet =>
correctNum(Option(1), 1)
case Bits =>
correctNum(Option(1), 2)
@@ -132,21 +132,20 @@ object CheckHighForm extends Pass {
def checkHighFormW(info: Info, mname: String)(w: Width): Width = {
w match {
- case wx: IntWidth if wx.width <= 0 =>
+ case wx: IntWidth if wx.width < 0 =>
errors append new NegWidthException(info, mname)
case wx => // Do nothing
}
w
}
- def checkHighFormT(info: Info, mname: String)(t: Type): Type = {
- t match {
- case tx: VectorType if tx.size < 0 =>
+ def checkHighFormT(info: Info, mname: String)(t: Type): Type =
+ t map checkHighFormT(info, mname) match {
+ case tx: VectorType if tx.size < 0 =>
errors append new NegVecSizeException(info, mname)
- case _ => // Do nothing
+ t
+ case _ => t map checkHighFormW(info, mname)
}
- t map checkHighFormW(info, mname) map checkHighFormT(info, mname)
- }
def validSubexp(info: Info, mname: String)(e: Expression): Expression = {
e match {
@@ -265,6 +264,7 @@ object CheckTypes extends Pass {
s"$info: [module $mname] Primop $op requires all arguments to be UInt type.")
class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException(
s"$info: [module $mname] Primop $op requires all operands to have the same type.")
+ class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.")
class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException(
s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.")
class NodePassiveType(info: Info, mname: String) extends PassException(
@@ -304,12 +304,39 @@ object CheckTypes extends Pass {
if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe)))
errors append new OpNotAllSameType(info, mname, e.op.serialize)
}
- def allInt(ls: Seq[Expression]) {
+ def allUSC(ls: Seq[Expression]) {
+ val error = ls.foldLeft(false)((error, x) => x.tpe match {
+ case (_: UIntType| _: SIntType| ClockType) => error
+ case _ => true
+ })
+ if (error) errors.append(new OpNotGround(info, mname, e.op.serialize))
+ }
+ def allUSF(ls: Seq[Expression]) {
+ val error = ls.foldLeft(false)((error, x) => x.tpe match {
+ case (_: UIntType| _: SIntType| _: FixedType) => error
+ case _ => true
+ })
+ if (error) errors.append(new OpNotGround(info, mname, e.op.serialize))
+ }
+ def allUS(ls: Seq[Expression]) {
if (ls exists (x => x.tpe match {
case _: UIntType | _: SIntType => false
case _ => true
})) errors append new OpNotGround(info, mname, e.op.serialize)
}
+ def allF(ls: Seq[Expression]) {
+ val error = ls.foldLeft(false)((error, x) => x.tpe match {
+ case _:FixedType => error
+ case _ => true
+ })
+ if (error) errors.append(new OpNotGround(info, mname, e.op.serialize))
+ }
+ def strictFix(ls: Seq[Expression]) =
+ ls.filter(!_.tpe.isInstanceOf[FixedType]).size match {
+ case 0 =>
+ case x if(x == ls.size) =>
+ case x => errors.append(new OpNoMixFix(info, mname, e.op.serialize))
+ }
def all_uint (ls: Seq[Expression]) {
if (ls exists (x => x.tpe match {
case _: UIntType => false
@@ -323,10 +350,14 @@ object CheckTypes extends Pass {
}) errors append new OpNotUInt(info, mname, e.op.serialize, x.serialize)
}
e.op match {
- case AsUInt | AsSInt | AsClock =>
- case Dshl => is_uint(e.args(1)); allInt(e.args)
- case Dshr => is_uint(e.args(1)); allInt(e.args)
- case _ => allInt(e.args)
+ case AsUInt | AsSInt | AsFixedPoint =>
+ case AsClock => allUSC(e.args)
+ case Dshl => is_uint(e.args(1)); allUSF(e.args)
+ case Dshr => is_uint(e.args(1)); allUSF(e.args)
+ case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => allUSF(e.args); strictFix(e.args)
+ case Pad | Shl | Shr | Cat | Bits | Head | Tail => allUSF(e.args)
+ case BPShl | BPShr | BPSet => allF(e.args)
+ case _ => allUS(e.args)
}
}
@@ -383,6 +414,7 @@ object CheckTypes extends Pass {
case (ClockType, ClockType) => flip1 == flip2
case (_: UIntType, _: UIntType) => flip1 == flip2
case (_: SIntType, _: SIntType) => flip1 == flip2
+ case (_: FixedType, _: FixedType) => flip1 == flip2
case (_: AnalogType, _: AnalogType) => false
case (t1: BundleType, t2: BundleType) =>
val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())(
@@ -583,7 +615,7 @@ object CheckWidths extends Pass {
def check_width_w(info: Info, mname: String)(w: Width): Width = {
w match {
- case w: IntWidth if w.width > 0 =>
+ case w: IntWidth if w.width >= 0 =>
case _: IntWidth =>
errors append new NegWidthException(info, mname)
case _ =>
@@ -592,6 +624,9 @@ object CheckWidths extends Pass {
w
}
+ def check_width_t(info: Info, mname: String)(t: Type): Type =
+ t map check_width_t(info, mname) map check_width_w(info, mname)
+
def check_width_e(info: Info, mname: String)(e: Expression): Expression = {
e match {
case e: UIntLiteral => e.width match {
@@ -614,26 +649,25 @@ object CheckWidths extends Pass {
errors append new WidthTooBig(info, mname)
case _ =>
}
- e map check_width_w(info, mname) map check_width_e(info, mname)
+ //e map check_width_t(info, mname) map check_width_e(info, mname)
+ e map check_width_e(info, mname)
}
+
def check_width_s(minfo: Info, mname: String)(s: Statement): Statement = {
val info = get_info(s) match { case NoInfo => minfo case x => x }
- s map check_width_e(info, mname) map check_width_s(info, mname) match {
- case Attach(infox, source, exprs) =>
+ s map check_width_e(info, mname) map check_width_s(info, mname) map check_width_t(info, mname) match {
+ case Attach(infox, source, exprs) =>
exprs foreach ( e =>
if (bitWidth(e.tpe) != bitWidth(source.tpe))
errors append new AttachWidthsNotEqual(infox, mname, e.serialize, source.serialize)
)
s
case _ => s
- }
+ }
}
- def check_width_p(minfo: Info, mname: String)(p: Port): Port = {
- p.tpe map check_width_w(p.info, mname)
- p
- }
+ def check_width_p(minfo: Info, mname: String)(p: Port): Port = p.copy(tpe = check_width_t(p.info, mname)(p.tpe))
def check_width_m(m: DefModule) {
m map check_width_p(m.info, m.name) map check_width_s(m.info, m.name)
diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
new file mode 100644
index 00000000..8cf0b890
--- /dev/null
+++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala
@@ -0,0 +1,119 @@
+package firrtl.passes
+
+import scala.collection.mutable
+import firrtl.PrimOps._
+import firrtl.ir._
+import firrtl._
+import firrtl.Mappers._
+import firrtl.Utils.{sub_type, module_type, field_type, BoolType, max, min, pow_minus_one}
+
+/** Replaces FixedType with SIntType, and correctly aligns all binary points
+ */
+object ConvertFixedToSInt extends Pass {
+ def name = "Convert Fixed Types to SInt Types"
+ def alignArg(e: Expression, point: BigInt): Expression = e.tpe match {
+ case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p)
+ if((point - p) > 0) {
+ DoPrim(Shl, Seq(e), Seq(point - p), UnknownType)
+ } else if (point - p < 0) {
+ DoPrim(Shr, Seq(e), Seq(p - point), UnknownType)
+ } else e
+ case FixedType(w, p) => error("Shouldn't be here")
+ case _ => e
+ }
+ def calcPoint(es: Seq[Expression]): BigInt =
+ es.map(_.tpe match {
+ case FixedType(IntWidth(w), IntWidth(p)) => p
+ case _ => BigInt(0)
+ }).reduce(max(_, _))
+ def toSIntType(t: Type): Type = t match {
+ case FixedType(IntWidth(w), IntWidth(p)) => SIntType(IntWidth(w))
+ case FixedType(w, p) => error("Shouldn't be here")
+ case _ => t
+ }
+ def run(c: Circuit): Circuit = {
+ val moduleTypes = mutable.HashMap[String,Type]()
+ def onModule(m:DefModule) : DefModule = {
+ val types = mutable.HashMap[String,Type]()
+ def updateExpType(e:Expression): Expression = e match {
+ case DoPrim(Mul, args, consts, tpe) => e map updateExpType
+ case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe)
+ case DoPrim(BPShl, args, consts, tpe) => DoPrim(Shl, args, consts, tpe)
+ case DoPrim(BPShr, args, consts, tpe) => DoPrim(Shr, args, consts, tpe)
+ case DoPrim(BPSet, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p)
+ case DoPrim(op, args, consts, tpe) =>
+ val point = calcPoint(args)
+ val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType)
+ newExp map updateExpType match {
+ case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe)
+ case e => e
+ }
+ case Mux(cond, tval, fval, tpe) =>
+ val point = calcPoint(Seq(tval, fval))
+ val newExp = Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType)
+ newExp map updateExpType
+ case e: UIntLiteral => e
+ case e: SIntLiteral => e
+ case _ => e map updateExpType match {
+ case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe)
+ case WRef(name, tpe, k, g) => WRef(name, types(name), k, g)
+ case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g)
+ case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g)
+ case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g)
+ }
+ }
+ def updateStmtType(s: Statement): Statement = s match {
+ case DefRegister(info, name, tpe, clock, reset, init) =>
+ val newType = toSIntType(tpe)
+ types(name) = newType
+ DefRegister(info, name, newType, clock, reset, init) map updateExpType
+ case DefWire(info, name, tpe) =>
+ val newType = toSIntType(tpe)
+ types(name) = newType
+ DefWire(info, name, newType)
+ case DefNode(info, name, value) =>
+ val newValue = updateExpType(value)
+ val newType = toSIntType(newValue.tpe)
+ types(name) = newType
+ DefNode(info, name, newValue)
+ case DefMemory(info, name, dt, depth, wL, rL, rs, ws, rws, ruw) =>
+ val newStmt = DefMemory(info, name, toSIntType(dt), depth, wL, rL, rs, ws, rws, ruw)
+ val newType = MemPortUtils.memType(newStmt)
+ types(name) = newType
+ newStmt
+ case WDefInstance(info, name, module, tpe) =>
+ val newType = moduleTypes(module)
+ types(name) = newType
+ WDefInstance(info, name, module, newType)
+ case Connect(info, loc, exp) =>
+ val point = calcPoint(Seq(loc))
+ val newExp = alignArg(exp, point)
+ Connect(info, loc, newExp) map updateExpType
+ case PartialConnect(info, loc, exp) =>
+ val point = calcPoint(Seq(loc))
+ val newExp = alignArg(exp, point)
+ PartialConnect(info, loc, newExp) map updateExpType
+ // check Connect case, need to shl
+ case s => (s map updateStmtType) map updateExpType
+ }
+
+ m.ports.foreach(p => types(p.name) = p.tpe)
+ m match {
+ case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body))
+ case m:ExtModule => m
+ }
+ }
+
+ val newModules = for(m <- c.modules) yield {
+ val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe)))
+ m match {
+ case Module(info, name, ports, body) => Module(info,name,newPorts,body)
+ case ExtModule(info, name, ports) => ExtModule(info,name,newPorts)
+ }
+ }
+ newModules.foreach(m => moduleTypes(m.name) = module_type(m))
+ firrtl.passes.InferTypes.run(Circuit(c.info, newModules.map(onModule(_)), c.main ))
+ }
+}
+
+// vim: set ts=4 sw=4 et:
diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala
index 8fd5eef1..ac786386 100644
--- a/src/main/scala/firrtl/passes/InferWidths.scala
+++ b/src/main/scala/firrtl/passes/InferWidths.scala
@@ -200,14 +200,19 @@ object InferWidths extends Pass {
def run (c: Circuit): Circuit = {
val v = ArrayBuffer[WGeq]()
- def get_constraints_t(t1: Type, t2: Type, f: Orientation): Seq[WGeq] = (t1,t2) match {
+ def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match {
case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width))
case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width))
+ case (ClockType, ClockType) => Nil
+ case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2))
case (t1: BundleType, t2: BundleType) =>
(t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) =>
- res ++ get_constraints_t(f1.tpe, f2.tpe, times(f1.flip, f))
+ res ++ (f1.flip match {
+ case Default => get_constraints_t(f1.tpe, f2.tpe)
+ case Flip => get_constraints_t(f2.tpe, f1.tpe)
+ })
}
- case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe, f)
+ case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe)
}
def get_constraints_e(e: Expression): Expression = {
@@ -221,38 +226,44 @@ object InferWidths extends Pass {
e map get_constraints_e
}
+ def get_constraints_declared_type (t: Type): Type = t match {
+ case FixedType(_, p) =>
+ v += WGeq(p,IntWidth(0))
+ t
+ case _ => t map get_constraints_declared_type
+ }
+
def get_constraints_s(s: Statement): Statement = {
- s match {
+ s map get_constraints_declared_type match {
case (s: Connect) =>
val n = get_size(s.loc.tpe)
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
- v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) =>
+ v ++= ((locs zip exps).zipWithIndex flatMap {case ((locx, expx), i) =>
get_flip(s.loc.tpe, i, Default) match {
- case Default => WGeq(getWidth(locx), getWidth(expx))
- case Flip => WGeq(getWidth(expx), getWidth(locx))
+ case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
+ case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
}
})
case (s: PartialConnect) =>
val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default)
val locs = create_exps(s.loc)
val exps = create_exps(s.expr)
- v ++= (ls map {case (x, y) =>
+ v ++= (ls flatMap {case (x, y) =>
val locx = locs(x)
val expx = exps(y)
get_flip(s.loc.tpe, x, Default) match {
- case Default => WGeq(getWidth(locx), getWidth(expx))
- case Flip => WGeq(getWidth(expx), getWidth(locx))
+ case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx))
+ case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx))
}
})
- case (s:DefRegister) => v ++= (Seq(
- WGeq(getWidth(s.reset), IntWidth(1)),
- WGeq(IntWidth(1), getWidth(s.reset))
- ) ++ get_constraints_t(s.tpe, s.init.tpe, Default))
- case (s:Conditionally) => v ++= Seq(
- WGeq(getWidth(s.pred), IntWidth(1)),
- WGeq(IntWidth(1), getWidth(s.pred))
- )
+ case (s: DefRegister) => v ++= (
+ get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++
+ get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe) ++
+ get_constraints_t(s.tpe, s.init.tpe))
+ case (s:Conditionally) => v ++=
+ get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++
+ get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe)
case (s: Attach) =>
v += WGeq(getWidth(s.source), MaxWidth(s.exprs map (e => getWidth(e.tpe))))
case _ =>
@@ -261,6 +272,7 @@ object InferWidths extends Pass {
}
c.modules foreach (_ map get_constraints_s)
+ c.modules foreach (_.ports foreach {p => get_constraints_declared_type(p.tpe)})
//println("======== ALL CONSTRAINTS ========")
//for(x <- v) println(x)
diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala
index 11ea4951..0c4642e4 100644
--- a/src/main/scala/firrtl/passes/Passes.scala
+++ b/src/main/scala/firrtl/passes/Passes.scala
@@ -217,6 +217,7 @@ object Legalize extends Pass {
expr.args.head match {
case UIntLiteral(value, _) => UIntLiteral((value >> low.toInt) & mask, width)
case SIntLiteral(value, _) => SIntLiteral((value >> low.toInt) & mask, width)
+ //case FixedLiteral
case _ => expr
}
}
@@ -237,6 +238,7 @@ object Legalize extends Pass {
val expr = t match {
case UIntType(_) => bits
case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w)))
+ //case FixedType(width, point) => FixedType(width, point)
}
Connect(c.info, c.loc, expr)
}