diff options
| author | Adam Izraelevitz | 2016-10-17 15:10:12 -0700 |
|---|---|---|
| committer | GitHub | 2016-10-17 15:10:12 -0700 |
| commit | 7d08b9a1486fef0459481f6e542464a29fbe1db5 (patch) | |
| tree | e8b2289ac5cbecbd59d58cab8bd503287818ec5d /src | |
| parent | 2848d87721df110d0425114283cb5fa7e6c2ee03 (diff) | |
Add fixed point type (#322)
* WIP: Adding FixedType to Firrtl proper
Got simple example running through width inference
Checks should be ok
Need to look into FixedLiteral more
* Added simple test for fixed types
* Added asFixedPoint to primops
* Added tail case for FixedType
* Added ConvertFixedToSInt.scala
Added pass to MiddleToLowerFirrtl transform
* Replace AsFixedType with AsSInt in fixed removal
* Bugfix: constant from asFixed not deleted
* Added unit test for bulk connect
* Fixed partial connect bug #241
* Fixed missing case for FixedPoint in legalizeConnect
* Add FixedMathSpec that demonstrates some problems with FixedPointMath
* Fixed test and ConvertToSInt to pass.
Negative binary points not easily supported, needs much more time to
implement.
* Refactored checking neg widths
Make checking for negative binary points easier
* Added tests for inferring many FixedType ops
shl, shr, cat, bits, head, tail, setbp, shiftbp
* Handle bpshl, bpshr, bpset in ConvertFixedToSInt
Changed name from shiftbp -> bpshl, bpshr
Change name from setbp -> bpset
Added more tests
* Added set binary point test that fails
* Added simple test for zero binary point
* gitignore fixes for antlr intermediate dir and intellij dir
* removed unused imports
retool the fixed point with zero binary point test
* simplified example of inability to set binary point to zero
* Temporary fix for zero-width binary point
This fix allows for all widths to be zero, but since this is a feature I
am working on next, I'm not going to bother with a more stringent check.
* change version for dsp tools
* Removed extra temporary file
* Fixed merge bug
* Fixed another merge bug
* Removed commented out/unrelated files
* Removed snake case
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/antlr4/FIRRTL.g4 | 5 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 99 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 87 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Visitor.scala | 17 | ||||
| -rw-r--r-- | src/main/scala/firrtl/WIR.scala | 1 | ||||
| -rw-r--r-- | src/main/scala/firrtl/ir/IR.scala | 17 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo | bin | 0 -> 20480 bytes | |||
| -rw-r--r-- | src/main/scala/firrtl/passes/CheckChirrtl.scala | 13 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 78 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ConvertFixedToSInt.scala | 119 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/InferWidths.scala | 48 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 2 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala | 118 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala | 321 | ||||
| -rw-r--r-- | src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala | 221 |
16 files changed, 1070 insertions, 79 deletions
diff --git a/src/main/antlr4/FIRRTL.g4 b/src/main/antlr4/FIRRTL.g4 index 1232b65f..4ceed9f0 100644 --- a/src/main/antlr4/FIRRTL.g4 +++ b/src/main/antlr4/FIRRTL.g4 @@ -78,6 +78,7 @@ dir type : 'UInt' ('<' IntLit '>')? | 'SInt' ('<' IntLit '>')? + | 'Fixed' ('<' IntLit '>')? ('<' '<' IntLit '>' '>')? | 'Clock' | 'Analog' ('<' IntLit '>')? | '{' field* '}' // Bundle @@ -274,6 +275,10 @@ primop | 'bits(' | 'head(' | 'tail(' + | 'asFixedPoint(' + | 'bpshl(' + | 'bpshr(' + | 'bpset(' ; /*------------------------------------------------------------------ diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index eb44b4c2..307ef9d1 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -126,7 +126,8 @@ class MiddleFirrtlToLowFirrtl extends Transform with SimpleRun { passes.ResolveKinds, passes.InferTypes, passes.ResolveGenders, - passes.InferWidths) + passes.InferWidths, + passes.ConvertFixedToSInt) def execute(circuit: Circuit, annotationMap: AnnotationMap): TransformResult = run(circuit, passSeq) } diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index e3dd11c0..ea265369 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -28,7 +28,7 @@ MODIFICATIONS. package firrtl import firrtl.ir._ -import firrtl.Utils.{max, min, pow_minus_one} +import firrtl.Utils.{min, max, pow_minus_one} import com.typesafe.scalalogging.LazyLogging @@ -98,10 +98,18 @@ object PrimOps extends LazyLogging { case object Head extends PrimOp { override def toString = "head" } /** Tail */ case object Tail extends PrimOp { override def toString = "tail" } + /** Interpret as Fixed Point **/ + case object AsFixedPoint extends PrimOp { override def toString = "asFixedPoint" } + /** Shift Binary Point Left **/ + case object BPShl extends PrimOp { override def toString = "bpshl" } + /** Shift Binary Point Right **/ + case object BPShr extends PrimOp { override def toString = "bpshr" } + /** Set Binary Point **/ + case object BPSet extends PrimOp { override def toString = "bpset" } private lazy val builtinPrimOps: Seq[PrimOp] = Seq(Add, Sub, Mul, Div, Rem, Lt, Leq, Gt, Geq, Eq, Neq, Pad, AsUInt, AsSInt, AsClock, Shl, Shr, - Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail) + Dshl, Dshr, Neg, Cvt, Not, And, Or, Xor, Andr, Orr, Xorr, Cat, Bits, Head, Tail, AsFixedPoint, BPShl, BPShr, BPSet) private lazy val strToPrimOp: Map[String, PrimOp] = builtinPrimOps.map { case op : PrimOp=> op.toString -> op }.toMap /** Seq of String representations of [[ir.PrimOp]]s */ @@ -109,34 +117,38 @@ object PrimOps extends LazyLogging { /** Gets the corresponding [[ir.PrimOp]] from its String representation */ def fromString(op: String): PrimOp = strToPrimOp(op) + // Width Constraint Functions + def PLUS (w1:Width, w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(i + j) + case _ => PlusWidth(w1, w2) + } + def MAX (w1:Width, w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j)) + case _ => MaxWidth(Seq(w1, w2)) + } + def MINUS (w1:Width, w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(i - j) + case _ => MinusWidth(w1, w2) + } + def POW (w1:Width) : Width = w1 match { + case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i)) + case _ => ExpWidth(w1) + } + def MIN (w1:Width, w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j)) + case _ => MinWidth(Seq(w1, w2)) + } + // Borrowed from Stanza implementation def set_primop_type (e:DoPrim) : DoPrim = { //println-all(["Inferencing primop type: " e]) - def PLUS (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(i + j) - case _ => PlusWidth(w1, w2) - } - def MAX (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j)) - case _ => MaxWidth(Seq(w1, w2)) - } - def MINUS (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(i - j) - case _ => MinusWidth(w1, w2) - } - def POW (w1:Width) : Width = w1 match { - case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i)) - case _ => ExpWidth(w1) - } - def MIN (w1:Width, w2:Width) : Width = (w1, w2) match { - case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j)) - case _ => MinWidth(Seq(w1, w2)) - } def t1 = e.args.head.tpe def t2 = e.args(1).tpe def t3 = e.args(2).tpe def w1 = passes.getWidth(e.args.head.tpe) def w2 = passes.getWidth(e.args(1).tpe) + def p1 = t1 match { case FixedType(w, p) => p } //Intentional + def p2 = t2 match { case FixedType(w, p) => p } //Intentional def c1 = IntWidth(e.consts.head) def c2 = IntWidth(e.consts(1)) e copy (tpe = e.op match { @@ -145,6 +157,7 @@ object PrimOps extends LazyLogging { case (_: UIntType, _: SIntType) => SIntType(PLUS(MAX(w1, MINUS(w2, IntWidth(1))), IntWidth(2))) case (_: SIntType, _: UIntType) => SIntType(PLUS(MAX(w2, MINUS(w1, IntWidth(1))), IntWidth(2))) case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) + case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2), MAX(MINUS(w1, p1), MINUS(w2, p2))), IntWidth(1)), MAX(p1, p2)) case _ => UnknownType } case Sub => (t1, t2) match { @@ -152,6 +165,7 @@ object PrimOps extends LazyLogging { case (_: UIntType, _: SIntType) => SIntType(MAX(PLUS(w2, IntWidth(1)), PLUS(w1, IntWidth(2)))) case (_: SIntType, _: UIntType) => SIntType(MAX(PLUS(w1, IntWidth(1)), PLUS(w2, IntWidth(2)))) case (_: SIntType, _: SIntType) => SIntType(PLUS(MAX(w1, w2), IntWidth(1))) + case (_: FixedType, _: FixedType) => FixedType(PLUS(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))),IntWidth(1)), MAX(p1, p2)) case _ => UnknownType } case Mul => (t1, t2) match { @@ -159,6 +173,7 @@ object PrimOps extends LazyLogging { case (_: UIntType, _: SIntType) => SIntType(PLUS(w1, w2)) case (_: SIntType, _: UIntType) => SIntType(PLUS(w1, w2)) case (_: SIntType, _: SIntType) => SIntType(PLUS(w1, w2)) + case (_: FixedType, _: FixedType) => FixedType(PLUS(w1, w2), PLUS(p1, p2)) case _ => UnknownType } case Div => (t1, t2) match { @@ -180,6 +195,7 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: UIntType) => Utils.BoolType case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Leq => (t1, t2) match { @@ -187,6 +203,7 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: UIntType) => Utils.BoolType case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Gt => (t1, t2) match { @@ -194,6 +211,7 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: UIntType) => Utils.BoolType case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Geq => (t1, t2) match { @@ -201,6 +219,7 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: UIntType) => Utils.BoolType case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Eq => (t1, t2) match { @@ -208,6 +227,7 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: UIntType) => Utils.BoolType case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Neq => (t1, t2) match { @@ -215,16 +235,19 @@ object PrimOps extends LazyLogging { case (_: SIntType, _: UIntType) => Utils.BoolType case (_: UIntType, _: SIntType) => Utils.BoolType case (_: SIntType, _: SIntType) => Utils.BoolType + case (_: FixedType, _: FixedType) => Utils.BoolType case _ => UnknownType } case Pad => t1 match { case _: UIntType => UIntType(MAX(w1, c1)) case _: SIntType => SIntType(MAX(w1, c1)) + case _: FixedType => FixedType(MAX(w1, c1), p1) case _ => UnknownType } case AsUInt => t1 match { case _: UIntType => UIntType(w1) case _: SIntType => UIntType(w1) + case _: FixedType => UIntType(w1) case ClockType => UIntType(IntWidth(1)) case AnalogType(w) => UIntType(w1) case _ => UnknownType @@ -232,10 +255,19 @@ object PrimOps extends LazyLogging { case AsSInt => t1 match { case _: UIntType => SIntType(w1) case _: SIntType => SIntType(w1) + case _: FixedType => SIntType(w1) case ClockType => SIntType(IntWidth(1)) case _: AnalogType => SIntType(w1) case _ => UnknownType } + case AsFixedPoint => t1 match { + case _: UIntType => FixedType(w1, c1) + case _: SIntType => FixedType(w1, c1) + case _: FixedType => FixedType(w1, c1) + case ClockType => FixedType(IntWidth(1), c1) + case _: AnalogType => FixedType(w1, c1) + case _ => UnknownType + } case AsClock => t1 match { case _: UIntType => ClockType case _: SIntType => ClockType @@ -246,21 +278,25 @@ object PrimOps extends LazyLogging { case Shl => t1 match { case _: UIntType => UIntType(PLUS(w1, c1)) case _: SIntType => SIntType(PLUS(w1, c1)) + case _: FixedType => FixedType(PLUS(w1,c1), p1) case _ => UnknownType } case Shr => t1 match { case _: UIntType => UIntType(MAX(MINUS(w1, c1), IntWidth(1))) case _: SIntType => SIntType(MAX(MINUS(w1, c1), IntWidth(1))) + case _: FixedType => FixedType(MAX(MAX(MINUS(w1,c1), IntWidth(1)), p1), p1) case _ => UnknownType } case Dshl => t1 match { case _: UIntType => UIntType(PLUS(w1, POW(w2))) case _: SIntType => SIntType(PLUS(w1, POW(w2))) + case _: FixedType => FixedType(PLUS(w1, POW(w2)), p1) case _ => UnknownType } case Dshr => t1 match { case _: UIntType => UIntType(w1) case _: SIntType => SIntType(w1) + case _: FixedType => FixedType(w1, p1) case _ => UnknownType } case Cvt => t1 match { @@ -304,18 +340,33 @@ object PrimOps extends LazyLogging { } case Cat => (t1, t2) match { case (_: UIntType | _: SIntType, _: UIntType | _: SIntType) => UIntType(PLUS(w1, w2)) + case (_: FixedType, _: UIntType| _: SIntType) => FixedType(PLUS(w1, w2), PLUS(p1, w2)) + case (_: UIntType | _: SIntType, _: FixedType) => FixedType(PLUS(w1, w2), p1) case (t1, t2) => UnknownType } case Bits => t1 match { case (_: UIntType | _: SIntType) => UIntType(PLUS(MINUS(c1, c2), IntWidth(1))) + case _: FixedType => UIntType(PLUS(MINUS(c1, c2), IntWidth(1))) case _ => UnknownType } case Head => t1 match { - case (_: UIntType | _: SIntType) => UIntType(c1) + case (_: UIntType | _: SIntType | _: FixedType) => UIntType(c1) case _ => UnknownType } case Tail => t1 match { - case (_: UIntType | _: SIntType) => UIntType(MINUS(w1, c1)) + case (_: UIntType | _: SIntType | _: FixedType) => UIntType(MINUS(w1, c1)) + case _ => UnknownType + } + case BPShl => t1 match { + case _: FixedType => FixedType(PLUS(w1,c1), PLUS(p1, c1)) + case _ => UnknownType + } + case BPShr => t1 match { + case _: FixedType => FixedType(MINUS(w1,c1), MINUS(p1, c1)) + case _ => UnknownType + } + case BPSet => t1 match { + case _: FixedType => FixedType(PLUS(c1, MINUS(w1, p1)), c1) case _ => UnknownType } }) diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index a9f9418e..294afe57 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -158,10 +158,90 @@ object Utils extends LazyLogging { } //============== TYPES ================ +//<<<<<<< HEAD +// def mux_type (e1:Expression,e2:Expression) : Type = mux_type(tpe(e1),tpe(e2)) +// def mux_type (t1:Type,t2:Type) : Type = { +// if (wt(t1) == wt(t2)) { +// (t1,t2) match { +// case (t1:UIntType,t2:UIntType) => UIntType(UnknownWidth) +// case (t1:SIntType,t2:SIntType) => SIntType(UnknownWidth) +// case (t1:FixedType,t2:FixedType) => FixedType(UnknownWidth, UnknownWidth) +// case (t1:VectorType,t2:VectorType) => VectorType(mux_type(t1.tpe,t2.tpe),t1.size) +// case (t1:BundleType,t2:BundleType) => +// BundleType((t1.fields,t2.fields).zipped.map((f1,f2) => { +// Field(f1.name,f1.flip,mux_type(f1.tpe,f2.tpe)) +// })) +// } +// } else UnknownType +// } +// def mux_type_and_widths (e1:Expression,e2:Expression) : Type = mux_type_and_widths(tpe(e1),tpe(e2)) +// def PLUS (w1:Width,w2:Width) : Width = (w1, w2) match { +// case (IntWidth(i), IntWidth(j)) => IntWidth(i + j) +// case _ => PlusWidth(w1,w2) +// } +// def MAX (w1:Width,w2:Width) : Width = (w1, w2) match { +// case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j)) +// case _ => MaxWidth(Seq(w1,w2)) +// } +// def MINUS (w1:Width,w2:Width) : Width = (w1, w2) match { +// case (IntWidth(i), IntWidth(j)) => IntWidth(i - j) +// case _ => MinusWidth(w1,w2) +// } +// def POW (w1:Width) : Width = w1 match { +// case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i)) +// case _ => ExpWidth(w1) +// } +// def MIN (w1:Width,w2:Width) : Width = (w1, w2) match { +// case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j)) +// case _ => MinWidth(Seq(w1,w2)) +// } +// def mux_type_and_widths (t1:Type,t2:Type) : Type = { +// def wmax (w1:Width,w2:Width) : Width = { +// (w1,w2) match { +// case (w1:IntWidth,w2:IntWidth) => IntWidth(w1.width.max(w2.width)) +// case (w1,w2) => MaxWidth(Seq(w1,w2)) +// } +// } +// val wt1 = new WrappedType(t1) +// val wt2 = new WrappedType(t2) +// if (wt1 == wt2) { +// (t1,t2) match { +// case (t1:UIntType,t2:UIntType) => UIntType(wmax(t1.width,t2.width)) +// case (t1:SIntType,t2:SIntType) => SIntType(wmax(t1.width,t2.width)) +// case (FixedType(w1, p1), FixedType(w2, p2)) => +// FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2)) +// case (t1:VectorType,t2:VectorType) => VectorType(mux_type_and_widths(t1.tpe,t2.tpe),t1.size) +// case (t1:BundleType,t2:BundleType) => BundleType((t1.fields zip t2.fields).map{case (f1, f2) => Field(f1.name,f1.flip,mux_type_and_widths(f1.tpe,f2.tpe))}) +// } +// } else UnknownType +// } +// def module_type (m:DefModule) : Type = { +// BundleType(m.ports.map(p => p.toField)) +// } +// def sub_type (v:Type) : Type = { +// v match { +// case v:VectorType => v.tpe +// case v => UnknownType +// } +// } +// def field_type (v:Type,s:String) : Type = { +// v match { +// case v:BundleType => { +// val ft = v.fields.find(p => p.name == s) +// ft match { +// case ft:Some[Field] => ft.get.tpe +// case ft => UnknownType +// } +// } +// case v => UnknownType +// } +// } +//======= def mux_type(e1: Expression, e2: Expression): Type = mux_type(e1.tpe, e2.tpe) def mux_type(t1: Type, t2: Type): Type = (t1, t2) match { case (t1: UIntType, t2: UIntType) => UIntType(UnknownWidth) case (t1: SIntType, t2: SIntType) => SIntType(UnknownWidth) + case (t1: FixedType, t2: FixedType) => FixedType(UnknownWidth, UnknownWidth) case (t1: VectorType, t2: VectorType) => VectorType(mux_type(t1.tpe, t2.tpe), t1.size) case (t1: BundleType, t2: BundleType) => BundleType(t1.fields zip t2.fields map { case (f1, f2) => Field(f1.name, f1.flip, mux_type(f1.tpe, f2.tpe)) @@ -178,6 +258,8 @@ object Utils extends LazyLogging { (t1, t2) match { case (t1x: UIntType, t2x: UIntType) => UIntType(wmax(t1x.width, t2x.width)) case (t1x: SIntType, t2x: SIntType) => SIntType(wmax(t1x.width, t2x.width)) + case (FixedType(w1, p1), FixedType(w2, p2)) => + FixedType(PLUS(MAX(p1, p2),MAX(MINUS(w1, p1), MINUS(w2, p2))), MAX(p1, p2)) case (t1x: VectorType, t2x: VectorType) => VectorType( mux_type_and_widths(t1x.tpe, t2x.tpe), t1x.size) case (t1x: BundleType, t2x: BundleType) => BundleType(t1x.fields zip t2x.fields map { @@ -201,6 +283,7 @@ object Utils extends LazyLogging { } case vx => UnknownType } +//>>>>>>> e54fb610c6bf0a7fe5c9c0f0e0b3acbb3728cfd0 // ================================= def error(str: String) = throw new FIRRTLException(str) @@ -218,6 +301,7 @@ object Utils extends LazyLogging { (t1, t2) match { case (_: UIntType, _: UIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil case (_: SIntType, _: SIntType) => if (flip1 == flip2) Seq((0, 0)) else Nil + case (_: FixedType, _: FixedType) => if (flip1 == flip2) Seq((0, 0)) else Nil case (t1x: BundleType, t2x: BundleType) => def emptyMap = Map[String, (Type, Orientation, Int)]() val t1_fields = t1x.fields.foldLeft(emptyMap, 0) { case ((map, ilen), f1) => @@ -455,11 +539,8 @@ object Utils extends LazyLogging { "final", "first_match", "for", "force", "foreach", "forever", "fork", "forkjoin", "function", - "generate", "genvar", - "highz0", "highz1", - "if", "iff", "ifnone", "ignore_bins", "illegal_bins", "import", "incdir", "include", "initial", "initvar", "inout", "input", "inside", "instance", "int", "integer", "interconnect", diff --git a/src/main/scala/firrtl/Visitor.scala b/src/main/scala/firrtl/Visitor.scala index b8850e53..d45283c6 100644 --- a/src/main/scala/firrtl/Visitor.scala +++ b/src/main/scala/firrtl/Visitor.scala @@ -129,19 +129,28 @@ class Visitor(infoMode: InfoMode) extends FIRRTLBaseVisitor[FirrtlNode] { // Match on a type instead of on strings? private def visitType[FirrtlNode](ctx: FIRRTLParser.TypeContext): Type = { + def getWidth(n: TerminalNode): Width = IntWidth(string2BigInt(n.getText)) ctx.getChild(0) match { case term: TerminalNode => term.getText match { - case "UInt" => if (ctx.getChildCount > 1) UIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) + case "UInt" => if (ctx.getChildCount > 1) UIntType(IntWidth(string2BigInt(ctx.IntLit(0).getText))) else UIntType(UnknownWidth) - case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit.getText))) + case "SInt" => if (ctx.getChildCount > 1) SIntType(IntWidth(string2BigInt(ctx.IntLit(0).getText))) else SIntType(UnknownWidth) + case "Fixed" => ctx.IntLit.size match { + case 0 => FixedType(UnknownWidth, UnknownWidth) + case 1 => ctx.getChild(2).getText match { + case "<" => FixedType(UnknownWidth, getWidth(ctx.IntLit(0))) + case _ => FixedType(getWidth(ctx.IntLit(0)), UnknownWidth) + } + case 2 => FixedType(getWidth(ctx.IntLit(0)), getWidth(ctx.IntLit(1))) + } case "Clock" => ClockType - case "Analog" => if (ctx.getChildCount > 1) AnalogType(IntWidth(string2BigInt(ctx.IntLit.getText))) + case "Analog" => if (ctx.getChildCount > 1) AnalogType(IntWidth(string2BigInt(ctx.IntLit(0).getText))) else AnalogType(UnknownWidth) case "{" => BundleType(ctx.field.map(visitField)) } - case typeContext: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.IntLit.getText)) + case typeContext: TypeContext => new VectorType(visitType(ctx.`type`), string2Int(ctx.IntLit(0).getText)) } } diff --git a/src/main/scala/firrtl/WIR.scala b/src/main/scala/firrtl/WIR.scala index 2acc9b4c..3dfddc17 100644 --- a/src/main/scala/firrtl/WIR.scala +++ b/src/main/scala/firrtl/WIR.scala @@ -188,6 +188,7 @@ class WrappedType(val t: Type) { case (_: UIntType, _: UIntType) => true case (_: SIntType, _: SIntType) => true case (ClockType, ClockType) => true + case (_: FixedType, _: FixedType) => true // Analog totally skips out of the Firrtl type system. // The only way Analog can play with another Analog component is through Attach. // Ohterwise, we'd need to special case it during ExpandWhens, Lowering, diff --git a/src/main/scala/firrtl/ir/IR.scala b/src/main/scala/firrtl/ir/IR.scala index 3d65e3b1..f5b80ac6 100644 --- a/src/main/scala/firrtl/ir/IR.scala +++ b/src/main/scala/firrtl/ir/IR.scala @@ -127,6 +127,16 @@ case class SIntLiteral(value: BigInt, width: Width) extends Literal { def mapType(f: Type => Type): Expression = this def mapWidth(f: Width => Width): Expression = SIntLiteral(value, f(width)) } +case class FixedLiteral(value: BigInt, width: Width, point: Width) extends Literal { + def tpe = FixedType(width, point) + def serialize = { + val pstring = if(point == UnknownWidth) "" else s"<${point.serialize}>" + s"Fixed${width.serialize}$pstring(" + Utils.serialize(value) + ")" + } + def mapExpr(f: Expression => Expression): Expression = this + def mapType(f: Type => Type): Expression = this + def mapWidth(f: Width => Width): Expression = FixedLiteral(value, f(width), f(point)) +} case class DoPrim(op: PrimOp, args: Seq[Expression], consts: Seq[BigInt], tpe: Type) extends Expression { def serialize: String = op.serialize + "(" + (args.map(_.serialize) ++ consts.map(_.toString)).mkString(", ") + ")" @@ -381,6 +391,13 @@ case class SIntType(width: Width) extends GroundType { def serialize: String = "SInt" + width.serialize def mapWidth(f: Width => Width): Type = SIntType(f(width)) } +case class FixedType(width: Width, point: Width) extends GroundType { + override def serialize: String = { + val pstring = if(point == UnknownWidth) "" else s"<${point.serialize}>" + s"Fixed${width.serialize}$pstring" + } + def mapWidth(f: Width => Width): Type = FixedType(f(width), f(point)) +} case class BundleType(fields: Seq[Field]) extends AggregateType { def serialize: String = "{ " + (fields map (_.serialize) mkString ", ") + "}" def mapType(f: Type => Type): Type = diff --git a/src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo b/src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo Binary files differnew file mode 100644 index 00000000..abd7c349 --- /dev/null +++ b/src/main/scala/firrtl/passes/.ConvertFixedToSInt.scala.swo diff --git a/src/main/scala/firrtl/passes/CheckChirrtl.scala b/src/main/scala/firrtl/passes/CheckChirrtl.scala index 17e228eb..504702b5 100644 --- a/src/main/scala/firrtl/passes/CheckChirrtl.scala +++ b/src/main/scala/firrtl/passes/CheckChirrtl.scala @@ -68,22 +68,21 @@ object CheckChirrtl extends Pass { errors append new InvalidLOCException(info, mname) case _ => // Do Nothing } - def checkChirrtlW(info: Info, mname: String)(w: Width): Width = w match { - case w: IntWidth if w.width <= 0 => - errors append new NegWidthException(info, mname) + case w: IntWidth if (w.width < BigInt(0)) => + errors.append(new NegWidthException(info, mname)) w case _ => w } - def checkChirrtlT(info: Info, mname: String)(t: Type): Type = { + def checkChirrtlT(info: Info, mname: String)(t: Type): Type = t map checkChirrtlT(info, mname) match { case t: VectorType if t.size < 0 => errors append new NegVecSizeException(info, mname) - case _ => // Do nothing + t map checkChirrtlW(info, mname) + //case FixedType(width, point) => FixedType(checkChirrtlW(width), point) + case _ => t map checkChirrtlW(info, mname) } - t map checkChirrtlW(info, mname) map checkChirrtlT(info, mname) - } def validSubexp(info: Info, mname: String)(e: Expression): Expression = { e match { diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index e6cd1d04..8eebda2f 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -100,7 +100,7 @@ object CheckHighForm extends Pass { correctNum(Option(2), 0) case AsUInt | AsSInt | AsClock | Cvt | Neq | Not => correctNum(Option(1), 0) - case Pad | Shl | Shr | Head | Tail => + case AsFixedPoint | Pad | Shl | Shr | Head | Tail | BPShl | BPShr | BPSet => correctNum(Option(1), 1) case Bits => correctNum(Option(1), 2) @@ -132,21 +132,20 @@ object CheckHighForm extends Pass { def checkHighFormW(info: Info, mname: String)(w: Width): Width = { w match { - case wx: IntWidth if wx.width <= 0 => + case wx: IntWidth if wx.width < 0 => errors append new NegWidthException(info, mname) case wx => // Do nothing } w } - def checkHighFormT(info: Info, mname: String)(t: Type): Type = { - t match { - case tx: VectorType if tx.size < 0 => + def checkHighFormT(info: Info, mname: String)(t: Type): Type = + t map checkHighFormT(info, mname) match { + case tx: VectorType if tx.size < 0 => errors append new NegVecSizeException(info, mname) - case _ => // Do nothing + t + case _ => t map checkHighFormW(info, mname) } - t map checkHighFormW(info, mname) map checkHighFormT(info, mname) - } def validSubexp(info: Info, mname: String)(e: Expression): Expression = { e match { @@ -265,6 +264,7 @@ object CheckTypes extends Pass { s"$info: [module $mname] Primop $op requires all arguments to be UInt type.") class OpNotAllSameType(info: Info, mname: String, op: String) extends PassException( s"$info: [module $mname] Primop $op requires all operands to have the same type.") + class OpNoMixFix(info:Info, mname: String, op: String) extends PassException(s"${info}: [module ${mname}] Primop ${op} cannot operate on args of some, but not all, fixed type.") class OpNotAnalog(info: Info, mname: String, exp: String) extends PassException( s"$info: [module $mname] Attach requires all arguments to be Analog type: $exp.") class NodePassiveType(info: Info, mname: String) extends PassException( @@ -304,12 +304,39 @@ object CheckTypes extends Pass { if (ls exists (x => wt(ls.head.tpe) != wt(e.tpe))) errors append new OpNotAllSameType(info, mname, e.op.serialize) } - def allInt(ls: Seq[Expression]) { + def allUSC(ls: Seq[Expression]) { + val error = ls.foldLeft(false)((error, x) => x.tpe match { + case (_: UIntType| _: SIntType| ClockType) => error + case _ => true + }) + if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) + } + def allUSF(ls: Seq[Expression]) { + val error = ls.foldLeft(false)((error, x) => x.tpe match { + case (_: UIntType| _: SIntType| _: FixedType) => error + case _ => true + }) + if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) + } + def allUS(ls: Seq[Expression]) { if (ls exists (x => x.tpe match { case _: UIntType | _: SIntType => false case _ => true })) errors append new OpNotGround(info, mname, e.op.serialize) } + def allF(ls: Seq[Expression]) { + val error = ls.foldLeft(false)((error, x) => x.tpe match { + case _:FixedType => error + case _ => true + }) + if (error) errors.append(new OpNotGround(info, mname, e.op.serialize)) + } + def strictFix(ls: Seq[Expression]) = + ls.filter(!_.tpe.isInstanceOf[FixedType]).size match { + case 0 => + case x if(x == ls.size) => + case x => errors.append(new OpNoMixFix(info, mname, e.op.serialize)) + } def all_uint (ls: Seq[Expression]) { if (ls exists (x => x.tpe match { case _: UIntType => false @@ -323,10 +350,14 @@ object CheckTypes extends Pass { }) errors append new OpNotUInt(info, mname, e.op.serialize, x.serialize) } e.op match { - case AsUInt | AsSInt | AsClock => - case Dshl => is_uint(e.args(1)); allInt(e.args) - case Dshr => is_uint(e.args(1)); allInt(e.args) - case _ => allInt(e.args) + case AsUInt | AsSInt | AsFixedPoint => + case AsClock => allUSC(e.args) + case Dshl => is_uint(e.args(1)); allUSF(e.args) + case Dshr => is_uint(e.args(1)); allUSF(e.args) + case Add | Sub | Mul | Lt | Leq | Gt | Geq | Eq | Neq => allUSF(e.args); strictFix(e.args) + case Pad | Shl | Shr | Cat | Bits | Head | Tail => allUSF(e.args) + case BPShl | BPShr | BPSet => allF(e.args) + case _ => allUS(e.args) } } @@ -383,6 +414,7 @@ object CheckTypes extends Pass { case (ClockType, ClockType) => flip1 == flip2 case (_: UIntType, _: UIntType) => flip1 == flip2 case (_: SIntType, _: SIntType) => flip1 == flip2 + case (_: FixedType, _: FixedType) => flip1 == flip2 case (_: AnalogType, _: AnalogType) => false case (t1: BundleType, t2: BundleType) => val t1_fields = (t1.fields foldLeft Map[String, (Type, Orientation)]())( @@ -583,7 +615,7 @@ object CheckWidths extends Pass { def check_width_w(info: Info, mname: String)(w: Width): Width = { w match { - case w: IntWidth if w.width > 0 => + case w: IntWidth if w.width >= 0 => case _: IntWidth => errors append new NegWidthException(info, mname) case _ => @@ -592,6 +624,9 @@ object CheckWidths extends Pass { w } + def check_width_t(info: Info, mname: String)(t: Type): Type = + t map check_width_t(info, mname) map check_width_w(info, mname) + def check_width_e(info: Info, mname: String)(e: Expression): Expression = { e match { case e: UIntLiteral => e.width match { @@ -614,26 +649,25 @@ object CheckWidths extends Pass { errors append new WidthTooBig(info, mname) case _ => } - e map check_width_w(info, mname) map check_width_e(info, mname) + //e map check_width_t(info, mname) map check_width_e(info, mname) + e map check_width_e(info, mname) } + def check_width_s(minfo: Info, mname: String)(s: Statement): Statement = { val info = get_info(s) match { case NoInfo => minfo case x => x } - s map check_width_e(info, mname) map check_width_s(info, mname) match { - case Attach(infox, source, exprs) => + s map check_width_e(info, mname) map check_width_s(info, mname) map check_width_t(info, mname) match { + case Attach(infox, source, exprs) => exprs foreach ( e => if (bitWidth(e.tpe) != bitWidth(source.tpe)) errors append new AttachWidthsNotEqual(infox, mname, e.serialize, source.serialize) ) s case _ => s - } + } } - def check_width_p(minfo: Info, mname: String)(p: Port): Port = { - p.tpe map check_width_w(p.info, mname) - p - } + def check_width_p(minfo: Info, mname: String)(p: Port): Port = p.copy(tpe = check_width_t(p.info, mname)(p.tpe)) def check_width_m(m: DefModule) { m map check_width_p(m.info, m.name) map check_width_s(m.info, m.name) diff --git a/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala new file mode 100644 index 00000000..8cf0b890 --- /dev/null +++ b/src/main/scala/firrtl/passes/ConvertFixedToSInt.scala @@ -0,0 +1,119 @@ +package firrtl.passes + +import scala.collection.mutable +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import firrtl.Utils.{sub_type, module_type, field_type, BoolType, max, min, pow_minus_one} + +/** Replaces FixedType with SIntType, and correctly aligns all binary points + */ +object ConvertFixedToSInt extends Pass { + def name = "Convert Fixed Types to SInt Types" + def alignArg(e: Expression, point: BigInt): Expression = e.tpe match { + case FixedType(IntWidth(w), IntWidth(p)) => // assert(point >= p) + if((point - p) > 0) { + DoPrim(Shl, Seq(e), Seq(point - p), UnknownType) + } else if (point - p < 0) { + DoPrim(Shr, Seq(e), Seq(p - point), UnknownType) + } else e + case FixedType(w, p) => error("Shouldn't be here") + case _ => e + } + def calcPoint(es: Seq[Expression]): BigInt = + es.map(_.tpe match { + case FixedType(IntWidth(w), IntWidth(p)) => p + case _ => BigInt(0) + }).reduce(max(_, _)) + def toSIntType(t: Type): Type = t match { + case FixedType(IntWidth(w), IntWidth(p)) => SIntType(IntWidth(w)) + case FixedType(w, p) => error("Shouldn't be here") + case _ => t + } + def run(c: Circuit): Circuit = { + val moduleTypes = mutable.HashMap[String,Type]() + def onModule(m:DefModule) : DefModule = { + val types = mutable.HashMap[String,Type]() + def updateExpType(e:Expression): Expression = e match { + case DoPrim(Mul, args, consts, tpe) => e map updateExpType + case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) + case DoPrim(BPShl, args, consts, tpe) => DoPrim(Shl, args, consts, tpe) + case DoPrim(BPShr, args, consts, tpe) => DoPrim(Shr, args, consts, tpe) + case DoPrim(BPSet, args, consts, FixedType(w, IntWidth(p))) => alignArg(args.head, p) + case DoPrim(op, args, consts, tpe) => + val point = calcPoint(args) + val newExp = DoPrim(op, args.map(x => alignArg(x, point)), consts, UnknownType) + newExp map updateExpType match { + case DoPrim(AsFixedPoint, args, consts, tpe) => DoPrim(AsSInt, args, Seq.empty, tpe) + case e => e + } + case Mux(cond, tval, fval, tpe) => + val point = calcPoint(Seq(tval, fval)) + val newExp = Mux(cond, alignArg(tval, point), alignArg(fval, point), UnknownType) + newExp map updateExpType + case e: UIntLiteral => e + case e: SIntLiteral => e + case _ => e map updateExpType match { + case ValidIf(cond, value, tpe) => ValidIf(cond, value, value.tpe) + case WRef(name, tpe, k, g) => WRef(name, types(name), k, g) + case WSubField(exp, name, tpe, g) => WSubField(exp, name, field_type(exp.tpe, name), g) + case WSubIndex(exp, value, tpe, g) => WSubIndex(exp, value, sub_type(exp.tpe), g) + case WSubAccess(exp, index, tpe, g) => WSubAccess(exp, index, sub_type(exp.tpe), g) + } + } + def updateStmtType(s: Statement): Statement = s match { + case DefRegister(info, name, tpe, clock, reset, init) => + val newType = toSIntType(tpe) + types(name) = newType + DefRegister(info, name, newType, clock, reset, init) map updateExpType + case DefWire(info, name, tpe) => + val newType = toSIntType(tpe) + types(name) = newType + DefWire(info, name, newType) + case DefNode(info, name, value) => + val newValue = updateExpType(value) + val newType = toSIntType(newValue.tpe) + types(name) = newType + DefNode(info, name, newValue) + case DefMemory(info, name, dt, depth, wL, rL, rs, ws, rws, ruw) => + val newStmt = DefMemory(info, name, toSIntType(dt), depth, wL, rL, rs, ws, rws, ruw) + val newType = MemPortUtils.memType(newStmt) + types(name) = newType + newStmt + case WDefInstance(info, name, module, tpe) => + val newType = moduleTypes(module) + types(name) = newType + WDefInstance(info, name, module, newType) + case Connect(info, loc, exp) => + val point = calcPoint(Seq(loc)) + val newExp = alignArg(exp, point) + Connect(info, loc, newExp) map updateExpType + case PartialConnect(info, loc, exp) => + val point = calcPoint(Seq(loc)) + val newExp = alignArg(exp, point) + PartialConnect(info, loc, newExp) map updateExpType + // check Connect case, need to shl + case s => (s map updateStmtType) map updateExpType + } + + m.ports.foreach(p => types(p.name) = p.tpe) + m match { + case Module(info, name, ports, body) => Module(info,name,ports,updateStmtType(body)) + case m:ExtModule => m + } + } + + val newModules = for(m <- c.modules) yield { + val newPorts = m.ports.map(p => Port(p.info,p.name,p.direction,toSIntType(p.tpe))) + m match { + case Module(info, name, ports, body) => Module(info,name,newPorts,body) + case ExtModule(info, name, ports) => ExtModule(info,name,newPorts) + } + } + newModules.foreach(m => moduleTypes(m.name) = module_type(m)) + firrtl.passes.InferTypes.run(Circuit(c.info, newModules.map(onModule(_)), c.main )) + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/main/scala/firrtl/passes/InferWidths.scala b/src/main/scala/firrtl/passes/InferWidths.scala index 8fd5eef1..ac786386 100644 --- a/src/main/scala/firrtl/passes/InferWidths.scala +++ b/src/main/scala/firrtl/passes/InferWidths.scala @@ -200,14 +200,19 @@ object InferWidths extends Pass { def run (c: Circuit): Circuit = { val v = ArrayBuffer[WGeq]() - def get_constraints_t(t1: Type, t2: Type, f: Orientation): Seq[WGeq] = (t1,t2) match { + def get_constraints_t(t1: Type, t2: Type): Seq[WGeq] = (t1,t2) match { case (t1: UIntType, t2: UIntType) => Seq(WGeq(t1.width, t2.width)) case (t1: SIntType, t2: SIntType) => Seq(WGeq(t1.width, t2.width)) + case (ClockType, ClockType) => Nil + case (FixedType(w1, p1), FixedType(w2, p2)) => Seq(WGeq(w1,w2), WGeq(p1,p2)) case (t1: BundleType, t2: BundleType) => (t1.fields zip t2.fields foldLeft Seq[WGeq]()){case (res, (f1, f2)) => - res ++ get_constraints_t(f1.tpe, f2.tpe, times(f1.flip, f)) + res ++ (f1.flip match { + case Default => get_constraints_t(f1.tpe, f2.tpe) + case Flip => get_constraints_t(f2.tpe, f1.tpe) + }) } - case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe, f) + case (t1: VectorType, t2: VectorType) => get_constraints_t(t1.tpe, t2.tpe) } def get_constraints_e(e: Expression): Expression = { @@ -221,38 +226,44 @@ object InferWidths extends Pass { e map get_constraints_e } + def get_constraints_declared_type (t: Type): Type = t match { + case FixedType(_, p) => + v += WGeq(p,IntWidth(0)) + t + case _ => t map get_constraints_declared_type + } + def get_constraints_s(s: Statement): Statement = { - s match { + s map get_constraints_declared_type match { case (s: Connect) => val n = get_size(s.loc.tpe) val locs = create_exps(s.loc) val exps = create_exps(s.expr) - v ++= ((locs zip exps).zipWithIndex map {case ((locx, expx), i) => + v ++= ((locs zip exps).zipWithIndex flatMap {case ((locx, expx), i) => get_flip(s.loc.tpe, i, Default) match { - case Default => WGeq(getWidth(locx), getWidth(expx)) - case Flip => WGeq(getWidth(expx), getWidth(locx)) + case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx)) + case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx)) } }) case (s: PartialConnect) => val ls = get_valid_points(s.loc.tpe, s.expr.tpe, Default, Default) val locs = create_exps(s.loc) val exps = create_exps(s.expr) - v ++= (ls map {case (x, y) => + v ++= (ls flatMap {case (x, y) => val locx = locs(x) val expx = exps(y) get_flip(s.loc.tpe, x, Default) match { - case Default => WGeq(getWidth(locx), getWidth(expx)) - case Flip => WGeq(getWidth(expx), getWidth(locx)) + case Default => get_constraints_t(locx.tpe, expx.tpe)//WGeq(getWidth(locx), getWidth(expx)) + case Flip => get_constraints_t(expx.tpe, locx.tpe)//WGeq(getWidth(expx), getWidth(locx)) } }) - case (s:DefRegister) => v ++= (Seq( - WGeq(getWidth(s.reset), IntWidth(1)), - WGeq(IntWidth(1), getWidth(s.reset)) - ) ++ get_constraints_t(s.tpe, s.init.tpe, Default)) - case (s:Conditionally) => v ++= Seq( - WGeq(getWidth(s.pred), IntWidth(1)), - WGeq(IntWidth(1), getWidth(s.pred)) - ) + case (s: DefRegister) => v ++= ( + get_constraints_t(s.reset.tpe, UIntType(IntWidth(1))) ++ + get_constraints_t(UIntType(IntWidth(1)), s.reset.tpe) ++ + get_constraints_t(s.tpe, s.init.tpe)) + case (s:Conditionally) => v ++= + get_constraints_t(s.pred.tpe, UIntType(IntWidth(1))) ++ + get_constraints_t(UIntType(IntWidth(1)), s.pred.tpe) case (s: Attach) => v += WGeq(getWidth(s.source), MaxWidth(s.exprs map (e => getWidth(e.tpe)))) case _ => @@ -261,6 +272,7 @@ object InferWidths extends Pass { } c.modules foreach (_ map get_constraints_s) + c.modules foreach (_.ports foreach {p => get_constraints_declared_type(p.tpe)}) //println("======== ALL CONSTRAINTS ========") //for(x <- v) println(x) diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 11ea4951..0c4642e4 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -217,6 +217,7 @@ object Legalize extends Pass { expr.args.head match { case UIntLiteral(value, _) => UIntLiteral((value >> low.toInt) & mask, width) case SIntLiteral(value, _) => SIntLiteral((value >> low.toInt) & mask, width) + //case FixedLiteral case _ => expr } } @@ -237,6 +238,7 @@ object Legalize extends Pass { val expr = t match { case UIntType(_) => bits case SIntType(_) => DoPrim(AsSInt, Seq(bits), Seq(), SIntType(IntWidth(w))) + //case FixedType(width, point) => FixedType(width, point) } Connect(c.info, c.loc, expr) } diff --git a/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala new file mode 100644 index 00000000..85f34606 --- /dev/null +++ b/src/test/scala/firrtlTests/fixed/FixedPointMathSpec.scala @@ -0,0 +1,118 @@ +// See LICENSE for license details. + +package firrtlTests.fixed + +import java.io.StringWriter + +import firrtl.Annotations.AnnotationMap +import firrtl.{LowFirrtlCompiler, Parser} +import firrtl.Parser.IgnoreInfo +import firrtlTests.FirrtlFlatSpec + +class FixedPointMathSpec extends FirrtlFlatSpec { + def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) + + "Fixed types" should "parse" in { + val SumPattern = """.*output sum.*<(\d+)>.*.*""".r + val ProductPattern = """.*output product.*<(\d+)>.*""".r + val DifferencePattern = """.*output difference.*<(\d+)>.*""".r + + val AssignPattern = """\s*(\w+) <= (\w+)\((.*)\)\s*""".r + + for { + bits1 <- 1 to 4 +// binaryPoint1 <- -4 to 4 + binaryPoint1 <- 1 to 4 + bits2 <- 1 to 4 +// binaryPoint2 <- -4 to 4 + binaryPoint2 <- 1 to 4 + } { + val input = + s"""circuit Unit : + | module Unit : + | input a : Fixed<$bits1><<$binaryPoint1>> + | input b : Fixed<$bits2><<$binaryPoint2>> + | output sum : Fixed + | output product : Fixed + | output difference : Fixed + | sum <= add(a, b) + | product <= mul(a, b) + | difference <= sub(a, b) + | """. + stripMargin + + val lowerer = new LowFirrtlCompiler + + val writer = new StringWriter() + + lowerer.compile(parse(input), new AnnotationMap(Seq.empty), writer) + + val output = writer.toString.split("\n") + + def config = s"($bits1,$binaryPoint1)($bits2,$binaryPoint2)" + + def inferredAddWidth: Int = { + val binaryDifference = binaryPoint1 - binaryPoint2 + val (newW1, newW2) = if(binaryDifference > 0) { + (bits1, bits2 + binaryDifference) + } else { + (bits1 + binaryDifference.abs, bits2) + } + newW1.max(newW2) + 1 + } + + println(s"Test for configuratio $config") + + for(line <- output) { + line match { + case SumPattern(varWidth) => + assert(varWidth.toInt === inferredAddWidth, s"$config sum sint bits wrong for $line") + case ProductPattern(varWidth) => + assert(varWidth.toInt === bits1 + bits2, s"$config product bits wrong for $line") + case DifferencePattern(varWidth) => + assert(varWidth.toInt === inferredAddWidth, s"$config difference bits wrong for $line") + case AssignPattern(varName, operation, args) => + varName match { + case "sum" => + assert(operation === "add", s"var sum should be result of an add in $line") + if (binaryPoint1 > binaryPoint2) { + assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") + assert(args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), + s"$config second arg incorrect in $line") + } else if (binaryPoint1 < binaryPoint2) { + assert(args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), + s"$config second arg incorrect in $line") + assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") + } else { + assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") + assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") + } + case "product" => + assert(operation === "mul", s"var sum should be result of an add in $line") + assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") + assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") + case "difference" => + assert(operation === "sub", s"var difference should be result of an sub in $line") + if (binaryPoint1 > binaryPoint2) { + assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") + assert(args.contains(s"shl(b, ${binaryPoint1 - binaryPoint2})"), + s"$config second arg incorrect in $line") + } else if (binaryPoint1 < binaryPoint2) { + assert(args.contains(s"shl(a, ${(binaryPoint1 - binaryPoint2).abs})"), + s"$config second arg incorrect in $line") + assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") + } else { + assert(!args.contains("shl(a"), s"$config first arg should be just a in $line") + assert(!args.contains("shl(b"), s"$config second arg should be just b in $line") + } + case _ => + } + case _ => +// println(s"No pattern found for ${line}") + } + } + + println(writer.toString) + } + } +} diff --git a/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala new file mode 100644 index 00000000..7144f98a --- /dev/null +++ b/src/test/scala/firrtlTests/fixed/FixedTypeInferenceSpec.scala @@ -0,0 +1,321 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtlTests +package fixed + +import firrtl._ +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.Parser.IgnoreInfo + +class FixedTypeInferenceSpec extends FirrtlFlatSpec { + def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) + private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + val lines = c.serialize.split("\n") map normalized + + for(l <- lines) { + println(l) + } + expected foreach { e => + lines should contain(e) + } + } + + "Fixed types" should "infer add correctly" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10> + | input c : Fixed<4><<3>> + | output d : Fixed + | d <= add(a, add(b, c))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10><<0>> + | input c : Fixed<4><<3>> + | output d : Fixed<15><<3>> + | d <= add(a, add(b, c))""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "be correctly shifted left" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed + | d <= shl(a, 2)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<12><<2>> + | d <= shl(a, 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "be correctly shifted right" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed + | d <= shr(a, 2)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<8><<2>> + | d <= shr(a, 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "relatively move binary point left" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed + | d <= bpshl(a, 2)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<12><<4>> + | d <= bpshl(a, 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "relatively move binary point right" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed + | d <= bpshr(a, 2)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<8><<0>> + | d <= bpshr(a, 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "absolutely set binary point correctly" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed + | d <= bpset(a, 3)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<11><<3>> + | d <= bpset(a, 3)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "cat, head, tail, bits" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<7><<3>> + | input c : UInt<2> + | output cat : Fixed + | output head : UInt + | output tail : UInt + | output bits : UInt + | cat <= cat(a, c) + | head <= head(a, 3) + | tail <= tail(a, 3) + | bits <= bits(a, 6, 3)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<7><<3>> + | input c : UInt<2> + | output cat : Fixed<12><<4>> + | output head : UInt<3> + | output tail : UInt<7> + | output bits : UInt<4> + | cat <= cat(a, c) + | head <= head(a, 3) + | tail <= tail(a, 3) + | bits <= bits(a, 6, 3)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "be cast to" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths) + val input = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | output d : Fixed + | d <= asFixedPoint(a, 2)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | output d : Fixed<10><<2>> + | d <= asFixedPoint(a, 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "support binary point of zero" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + ConvertFixedToSInt) + val input = + """ + |circuit Unit : + | module Unit : + | input clk : Clock + | input reset : UInt<1> + | input io_in : Fixed<6><<0>> + | output io_out : Fixed<6><<0>> + | + | io_in is invalid + | io_out is invalid + | io_out <= io_in + """.stripMargin + val check = + """ + |circuit Unit : + | module Unit : + | input clk : Clock + | input reset : UInt<1> + | input io_in : SInt<6> + | output io_out : SInt<6> + | + | io_out <= io_in + | + """.stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } +} + +// vim: set ts=4 sw=4 et: diff --git a/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala new file mode 100644 index 00000000..d9d6dd27 --- /dev/null +++ b/src/test/scala/firrtlTests/fixed/RemoveFixedTypeSpec.scala @@ -0,0 +1,221 @@ +/* +Copyright (c) 2014 - 2016 The Regents of the University of +California (Regents). All Rights Reserved. Redistribution and use in +source and binary forms, with or without modification, are permitted +provided that the following conditions are met: + * Redistributions of source code must retain the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer. + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + two paragraphs of disclaimer in the documentation and/or other materials + provided with the distribution. + * Neither the name of the Regents nor the names of its contributors + may be used to endorse or promote products derived from this + software without specific prior written permission. +IN NO EVENT SHALL REGENTS BE LIABLE TO ANY PARTY FOR DIRECT, INDIRECT, +SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, INCLUDING LOST PROFITS, +ARISING OUT OF THE USE OF THIS SOFTWARE AND ITS DOCUMENTATION, EVEN IF +REGENTS HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +REGENTS SPECIFICALLY DISCLAIMS ANY WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE. THE SOFTWARE AND ACCOMPANYING DOCUMENTATION, IF +ANY, PROVIDED HEREUNDER IS PROVIDED "AS IS". REGENTS HAS NO OBLIGATION +TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR +MODIFICATIONS. +*/ + +package firrtlTests +package fixed + +import firrtl.Annotations.AnnotationMap +import firrtl._ +import firrtl.ir.Circuit +import firrtl.passes._ +import firrtl.Parser.IgnoreInfo + +class RemoveFixedTypeSpec extends FirrtlFlatSpec { + def parse (input:String) = Parser.parse(input.split("\n").toIterator, IgnoreInfo) + private def executeTest(input: String, expected: Seq[String], passes: Seq[Pass]) = { + val c = passes.foldLeft(Parser.parse(input.split("\n").toIterator)) { + (c: Circuit, p: Pass) => p.run(c) + } + val lines = c.serialize.split("\n") map normalized + + for(l <- lines) { + println(l) + } + expected foreach { e => + lines should contain(e) + } + } + + "Fixed types" should "be removed" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + ConvertFixedToSInt) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10> + | input c : Fixed<4><<3>> + | output d : Fixed<<5>> + | d <= add(a, add(b, c))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | input b : SInt<10> + | input c : SInt<4> + | output d : SInt<15> + | d <= shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + "Fixed types" should "be removed, even with a bulk connect" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + ConvertFixedToSInt) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | input b : Fixed<10> + | input c : Fixed<4><<3>> + | output d : Fixed<<5>> + | d <- add(a, add(b, c))""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | input b : SInt<10> + | input c : SInt<4> + | output d : SInt<15> + | d <- shl(add(shl(a, 1), add(shl(b, 3), c)), 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "remove binary point shift correctly" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + ConvertFixedToSInt) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<12><<4>> + | d <= bpshl(a, 2)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | output d : SInt<12> + | d <= shl(a, 2)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "remove binary point shift correctly in reverse" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + ConvertFixedToSInt) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed<9><<1>> + | d <= bpshr(a, 1)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | output d : SInt<9> + | d <= shr(a, 1)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed types" should "remove an absolutely set binary point correctly" in { + val passes = Seq( + ToWorkingIR, + CheckHighForm, + ResolveKinds, + InferTypes, + CheckTypes, + ResolveGenders, + CheckGenders, + InferWidths, + CheckWidths, + ConvertFixedToSInt) + val input = + """circuit Unit : + | module Unit : + | input a : Fixed<10><<2>> + | output d : Fixed + | d <= bpset(a, 3)""".stripMargin + val check = + """circuit Unit : + | module Unit : + | input a : SInt<10> + | output d : SInt<11> + | d <= shl(a, 1)""".stripMargin + executeTest(input, check.split("\n") map normalized, passes) + } + + "Fixed point numbers" should "allow binary point to be set to zero at creation" in { + val input = + """ + |circuit Unit : + | module Unit : + | input clk : Clock + | input reset : UInt<1> + | input io_in : Fixed<6><<0>> + | output io_out : Fixed + | + | io_in is invalid + | io_out is invalid + | io_out <= io_in + """.stripMargin + + class CheckChirrtlTransform extends Transform with SimpleRun { + val passSeq = Seq(passes.CheckChirrtl) + def execute (circuit: Circuit, annotationMap: AnnotationMap): TransformResult = + run(circuit, passSeq) + } + + val chirrtlTransform = new CheckChirrtlTransform + chirrtlTransform.execute(parse(input), new AnnotationMap(Seq.empty)) + } +} + +// vim: set ts=4 sw=4 et: |
