diff options
Diffstat (limited to 'src/main')
| -rw-r--r-- | src/main/resources/logback.xml | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 7 | ||||
| -rw-r--r-- | src/main/scala/firrtl/PrimOps.scala | 575 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Utils.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Checks.scala | 2 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/Passes.scala | 158 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/RemoveAccesses.scala | 160 |
7 files changed, 470 insertions, 437 deletions
diff --git a/src/main/resources/logback.xml b/src/main/resources/logback.xml index 8c9304c2..d2f8beae 100644 --- a/src/main/resources/logback.xml +++ b/src/main/resources/logback.xml @@ -30,7 +30,7 @@ MODIFICATIONS. <pattern>[%-4level] %msg%n</pattern> </encoder> </appender> - <root level="info"> + <root level="warn"> <appender-ref ref="STDOUT" /> </root> </configuration> diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index 4d7ddfe0..7c239b10 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -108,12 +108,11 @@ class HighFirrtlToMiddleFirrtl () extends Transform with SimpleRun { passes.RemoveAccesses, passes.ExpandWhens, passes.CheckInitialization, - passes.ConstProp, passes.ResolveKinds, passes.InferTypes, - passes.ResolveGenders) - //passes.InferWidths, - //passes.CheckWidths) + passes.ResolveGenders, + passes.InferWidths, + passes.CheckWidths) def execute (circuit: Circuit, annotationMap: AnnotationMap): TransformResult = run(circuit, passSeq) } diff --git a/src/main/scala/firrtl/PrimOps.scala b/src/main/scala/firrtl/PrimOps.scala index 7d7524f6..1bf8947a 100644 --- a/src/main/scala/firrtl/PrimOps.scala +++ b/src/main/scala/firrtl/PrimOps.scala @@ -28,6 +28,7 @@ MODIFICATIONS. package firrtl import firrtl.ir._ +import firrtl.Utils.{max, min, pow_minus_one} import com.typesafe.scalalogging.LazyLogging @@ -111,11 +112,26 @@ object PrimOps extends LazyLogging { // Borrowed from Stanza implementation def set_primop_type (e:DoPrim) : DoPrim = { //println-all(["Inferencing primop type: " e]) - def PLUS (w1:Width,w2:Width) : Width = PlusWidth(w1,w2) - def MAX (w1:Width,w2:Width) : Width = MaxWidth(Seq(w1,w2)) - def MINUS (w1:Width,w2:Width) : Width = MinusWidth(w1,w2) - def POW (w1:Width) : Width = ExpWidth(w1) - def MIN (w1:Width,w2:Width) : Width = MinWidth(Seq(w1,w2)) + def PLUS (w1:Width,w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(i + j) + case _ => PlusWidth(w1,w2) + } + def MAX (w1:Width,w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(max(i,j)) + case _ => MaxWidth(Seq(w1,w2)) + } + def MINUS (w1:Width,w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(i - j) + case _ => MinusWidth(w1,w2) + } + def POW (w1:Width) : Width = w1 match { + case IntWidth(i) => IntWidth(pow_minus_one(BigInt(2), i)) + case _ => ExpWidth(w1) + } + def MIN (w1:Width,w2:Width) : Width = (w1, w2) match { + case (IntWidth(i), IntWidth(j)) => IntWidth(min(i,j)) + case _ => MinWidth(Seq(w1,w2)) + } val o = e.op val a = e.args val c = e.consts @@ -127,282 +143,279 @@ object PrimOps extends LazyLogging { def w3 () = Utils.widthBANG(a(2).tpe) def c1 () = IntWidth(c(0)) def c2 () = IntWidth(c(1)) - e.tpe match { - case UIntType(IntWidth(w)) => e - case SIntType(IntWidth(w)) => e - case _ => o match { - case Add => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Sub => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Mul => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => UIntType(PLUS(w1(),w2())) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) - case (t1:SIntType, t2:UIntType) => SIntType(PLUS(w1(),w2())) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Div => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => UIntType(w1()) - case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) - case (t1:SIntType, t2:UIntType) => SIntType(w1()) - case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Rem => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => UIntType(MIN(w1(),w2())) - case (t1:UIntType, t2:SIntType) => UIntType(MIN(w1(),w2())) - case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),Utils.ONE))) - case (t1:SIntType, t2:SIntType) => SIntType(MIN(w1(),w2())) - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Lt => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => Utils.BoolType - case (t1:SIntType, t2:UIntType) => Utils.BoolType - case (t1:UIntType, t2:SIntType) => Utils.BoolType - case (t1:SIntType, t2:SIntType) => Utils.BoolType - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Leq => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => Utils.BoolType - case (t1:SIntType, t2:UIntType) => Utils.BoolType - case (t1:UIntType, t2:SIntType) => Utils.BoolType - case (t1:SIntType, t2:SIntType) => Utils.BoolType - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Gt => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => Utils.BoolType - case (t1:SIntType, t2:UIntType) => Utils.BoolType - case (t1:UIntType, t2:SIntType) => Utils.BoolType - case (t1:SIntType, t2:SIntType) => Utils.BoolType - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Geq => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => Utils.BoolType - case (t1:SIntType, t2:UIntType) => Utils.BoolType - case (t1:UIntType, t2:SIntType) => Utils.BoolType - case (t1:SIntType, t2:SIntType) => Utils.BoolType - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Eq => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => Utils.BoolType - case (t1:SIntType, t2:UIntType) => Utils.BoolType - case (t1:UIntType, t2:SIntType) => Utils.BoolType - case (t1:SIntType, t2:SIntType) => Utils.BoolType - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Neq => { - val t = (t1(),t2()) match { - case (t1:UIntType, t2:UIntType) => Utils.BoolType - case (t1:SIntType, t2:UIntType) => Utils.BoolType - case (t1:UIntType, t2:SIntType) => Utils.BoolType - case (t1:SIntType, t2:SIntType) => Utils.BoolType - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Pad => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(MAX(w1(),c1())) - case (t1:SIntType) => SIntType(MAX(w1(),c1())) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case AsUInt => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(w1()) - case (t1:SIntType) => UIntType(w1()) - case ClockType => UIntType(Utils.ONE) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case AsSInt => { - val t = (t1()) match { - case (t1:UIntType) => SIntType(w1()) - case (t1:SIntType) => SIntType(w1()) - case ClockType => SIntType(Utils.ONE) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case AsClock => { - val t = (t1()) match { - case (t1:UIntType) => ClockType - case (t1:SIntType) => ClockType - case ClockType => ClockType - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Shl => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(PLUS(w1(),c1())) - case (t1:SIntType) => SIntType(PLUS(w1(),c1())) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Shr => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) - case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Dshl => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(PLUS(w1(),POW(w2()))) - case (t1:SIntType) => SIntType(PLUS(w1(),POW(w2()))) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Dshr => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(w1()) - case (t1:SIntType) => SIntType(w1()) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Cvt => { - val t = (t1()) match { - case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) - case (t1:SIntType) => SIntType(w1()) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Neg => { - val t = (t1()) match { - case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) - case (t1:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Not => { - val t = (t1()) match { - case (t1:UIntType) => UIntType(w1()) - case (t1:SIntType) => UIntType(w1()) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case And => { - val t = (t1(),t2()) match { - case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) - case (t1,t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Or => { - val t = (t1(),t2()) match { - case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) - case (t1,t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Xor => { - val t = (t1(),t2()) match { - case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) - case (t1,t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Andr => { - val t = (t1()) match { - case (_:UIntType|_:SIntType) => Utils.BoolType - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Orr => { - val t = (t1()) match { - case (_:UIntType|_:SIntType) => Utils.BoolType - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Xorr => { - val t = (t1()) match { - case (_:UIntType|_:SIntType) => Utils.BoolType - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Cat => { - val t = (t1(),t2()) match { - case (_:UIntType|_:SIntType,_:UIntType|_:SIntType) => UIntType(PLUS(w1(),w2())) - case (t1, t2) => UnknownType - } - DoPrim(o,a,c,t) - } - case Bits => { - val t = (t1()) match { - case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),Utils.ONE)) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Head => { - val t = (t1()) match { - case (_:UIntType|_:SIntType) => UIntType(c1()) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - case Tail => { - val t = (t1()) match { - case (_:UIntType|_:SIntType) => UIntType(MINUS(w1(),c1())) - case (t1) => UnknownType - } - DoPrim(o,a,c,t) - } - } + o match { + case Add => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Sub => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(MAX(w1(),w2()),Utils.ONE)) + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Mul => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(PLUS(w1(),w2())) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) + case (t1:SIntType, t2:UIntType) => SIntType(PLUS(w1(),w2())) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),w2())) + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Div => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(w1()) + case (t1:UIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:SIntType, t2:UIntType) => SIntType(w1()) + case (t1:SIntType, t2:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Rem => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => UIntType(MIN(w1(),w2())) + case (t1:UIntType, t2:SIntType) => UIntType(MIN(w1(),w2())) + case (t1:SIntType, t2:UIntType) => SIntType(MIN(w1(),PLUS(w2(),Utils.ONE))) + case (t1:SIntType, t2:SIntType) => SIntType(MIN(w1(),w2())) + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Lt => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Leq => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Gt => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Geq => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Eq => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Neq => { + val t = (t1(),t2()) match { + case (t1:UIntType, t2:UIntType) => Utils.BoolType + case (t1:SIntType, t2:UIntType) => Utils.BoolType + case (t1:UIntType, t2:SIntType) => Utils.BoolType + case (t1:SIntType, t2:SIntType) => Utils.BoolType + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Pad => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(MAX(w1(),c1())) + case (t1:SIntType) => SIntType(MAX(w1(),c1())) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case AsUInt => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(w1()) + case (t1:SIntType) => UIntType(w1()) + case ClockType => UIntType(Utils.ONE) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case AsSInt => { + val t = (t1()) match { + case (t1:UIntType) => SIntType(w1()) + case (t1:SIntType) => SIntType(w1()) + case ClockType => SIntType(Utils.ONE) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case AsClock => { + val t = (t1()) match { + case (t1:UIntType) => ClockType + case (t1:SIntType) => ClockType + case ClockType => ClockType + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Shl => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(PLUS(w1(),c1())) + case (t1:SIntType) => SIntType(PLUS(w1(),c1())) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Shr => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) + case (t1:SIntType) => SIntType(MAX(MINUS(w1(),c1()),Utils.ONE)) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Dshl => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(PLUS(w1(),POW(w2()))) + case (t1:SIntType) => SIntType(PLUS(w1(),POW(w2()))) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Dshr => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(w1()) + case (t1:SIntType) => SIntType(w1()) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Cvt => { + val t = (t1()) match { + case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:SIntType) => SIntType(w1()) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Neg => { + val t = (t1()) match { + case (t1:UIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1:SIntType) => SIntType(PLUS(w1(),Utils.ONE)) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Not => { + val t = (t1()) match { + case (t1:UIntType) => UIntType(w1()) + case (t1:SIntType) => UIntType(w1()) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case And => { + val t = (t1(),t2()) match { + case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) + case (t1,t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Or => { + val t = (t1(),t2()) match { + case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) + case (t1,t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Xor => { + val t = (t1(),t2()) match { + case (_:SIntType|_:UIntType, _:SIntType|_:UIntType) => UIntType(MAX(w1(),w2())) + case (t1,t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Andr => { + val t = (t1()) match { + case (_:UIntType|_:SIntType) => Utils.BoolType + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Orr => { + val t = (t1()) match { + case (_:UIntType|_:SIntType) => Utils.BoolType + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Xorr => { + val t = (t1()) match { + case (_:UIntType|_:SIntType) => Utils.BoolType + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Cat => { + val t = (t1(),t2()) match { + case (_:UIntType|_:SIntType,_:UIntType|_:SIntType) => UIntType(PLUS(w1(),w2())) + case (t1, t2) => UnknownType + } + DoPrim(o,a,c,t) + } + case Bits => { + val t = (t1()) match { + case (_:UIntType|_:SIntType) => UIntType(PLUS(MINUS(c1(),c2()),Utils.ONE)) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Head => { + val t = (t1()) match { + case (_:UIntType|_:SIntType) => UIntType(c1()) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + case Tail => { + val t = (t1()) match { + case (_:UIntType|_:SIntType) => UIntType(MINUS(w1(),c1())) + case (t1) => UnknownType + } + DoPrim(o,a,c,t) + } + } } diff --git a/src/main/scala/firrtl/Utils.scala b/src/main/scala/firrtl/Utils.scala index 76c8e61e..9d11ca2f 100644 --- a/src/main/scala/firrtl/Utils.scala +++ b/src/main/scala/firrtl/Utils.scala @@ -74,6 +74,9 @@ object Utils extends LazyLogging { implicit def toWrappedExpression (x:Expression) = new WrappedExpression(x) def ceil_log2(x: BigInt): BigInt = (x-1).bitLength def ceil_log2(x: Int): Int = scala.math.ceil(scala.math.log(x) / scala.math.log(2)).toInt + def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b + def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a + def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 val gen_names = Map[String,Int]() val delin = "_" val BoolType = UIntType(IntWidth(1)) diff --git a/src/main/scala/firrtl/passes/Checks.scala b/src/main/scala/firrtl/passes/Checks.scala index 9bbf7fa8..ab6db202 100644 --- a/src/main/scala/firrtl/passes/Checks.scala +++ b/src/main/scala/firrtl/passes/Checks.scala @@ -256,7 +256,7 @@ object CheckHighForm extends Pass with LazyLogging { if (!c.modules.map(_.name).contains(s.module)) errors.append(new ModuleNotDefinedException(s.module)) // Check to see if a recursive module instantiation has occured - val childToParent = moduleGraph.add(mname, s.module) + val childToParent = moduleGraph.add(m.name, s.module) if(childToParent.nonEmpty) { errors.append(new InstanceLoop(childToParent.mkString("->"))) } diff --git a/src/main/scala/firrtl/passes/Passes.scala b/src/main/scala/firrtl/passes/Passes.scala index 120c81a9..bd9563dc 100644 --- a/src/main/scala/firrtl/passes/Passes.scala +++ b/src/main/scala/firrtl/passes/Passes.scala @@ -490,9 +490,6 @@ object InferWidths extends Pass { if(in.isEmpty) Seq(default) else in - def max(a: BigInt, b: BigInt): BigInt = if (a >= b) a else b - def min(a: BigInt, b: BigInt): BigInt = if (a >= b) b else a - def pow_minus_one(a: BigInt, b: BigInt): BigInt = a.pow(b.toInt) - 1 def solve(w: Width): Option[BigInt] = w match { case (w: VarWidth) => @@ -522,14 +519,20 @@ object InferWidths extends Pass { //println-all-debug(["WITH: " wx]) wx } + def reduce_var_widths_s (s: Statement): Statement = { + def onType(t: Type): Type = t map onType map reduce_var_widths_w + s map onType + } val modulesx = c.modules.map{ m => { val portsx = m.ports.map{ p => { Port(p.info,p.name,p.direction,mapr(reduce_var_widths_w _,p.tpe)) }} (m) match { case (m:ExtModule) => ExtModule(m.info,m.name,portsx) - case (m:Module) => mname = m.name; Module(m.info,m.name,portsx,mapr(reduce_var_widths_w _,m.body)) }}} - Circuit(c.info,modulesx,c.main) + case (m:Module) => + mname = m.name + Module(m.info,m.name,portsx,m.body map reduce_var_widths_s _) }}} + InferTypes.run(Circuit(c.info,modulesx,c.main)) } def run (c:Circuit): Circuit = { @@ -744,151 +747,6 @@ object ExpandConnects extends Pass { } } -case class Location(base:Expression,guard:Expression) -object RemoveAccesses extends Pass { - private var mname = "" - def name = "Remove Accesses" - def get_locations (e:Expression) : Seq[Location] = { - e match { - case (e:WRef) => create_exps(e).map(Location(_,one)) - case (e:WSubIndex) => { - val ls = get_locations(e.exp) - val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = ArrayBuffer[Location]() - var c = 0 - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { - lsx += ls(i) - } - } - lsx - } - case (e:WSubField) => { - val ls = get_locations(e.exp) - val start = get_point(e) - val end = start + get_size(tpe(e)) - val stride = get_size(tpe(e.exp)) - val lsx = ArrayBuffer[Location]() - var c = 0 - for (i <- 0 until ls.size) { - if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } - } - lsx - } - case (e:WSubAccess) => { - val ls = get_locations(e.exp) - val stride = get_size(tpe(e)) - val wrap = tpe(e.exp).asInstanceOf[VectorType].size - val lsx = ArrayBuffer[Location]() - var c = 0 - for (i <- 0 until ls.size) { - if ((c % wrap) == 0) { c = 0 } - val basex = ls(i).base - val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) - lsx += Location(basex,guardx) - if ((i + 1) % stride == 0) { - c = c + 1 - } - } - lsx - } - } - } - def has_access (e:Expression) : Boolean = { - var ret:Boolean = false - def rec_has_access (e:Expression) : Expression = { - e match { - case (e:WSubAccess) => { ret = true; e } - case (e) => e map (rec_has_access) - } - } - rec_has_access(e) - ret - } - def run (c:Circuit): Circuit = { - def remove_m (m:Module) : Module = { - val namespace = Namespace(m) - mname = m.name - def remove_s (s:Statement) : Statement = { - val stmts = ArrayBuffer[Statement]() - def create_temp (e:Expression) : Expression = { - val n = namespace.newTemp - stmts += DefWire(info(s),n,tpe(e)) - WRef(n,tpe(e),kind(e),gender(e)) - } - def remove_e (e:Expression) : Expression = { //NOT RECURSIVE (except primops) INTENTIONALLY! - e match { - case (e:DoPrim) => e map (remove_e) - case (e:Mux) => e map (remove_e) - case (e:ValidIf) => e map (remove_e) - case (e:SIntLiteral) => e - case (e:UIntLiteral) => e - case x => { - val e = x match { - case (w:WSubAccess) => WSubAccess(w.exp,remove_e(w.index),w.tpe,w.gender) - case _ => x - } - if (has_access(e)) { - val rs = get_locations(e) - val foo = rs.find(x => {x.guard != one}) - foo match { - case None => error("Shouldn't be here") - case foo:Some[Location] => { - val temp = create_temp(e) - val temps = create_exps(temp) - def get_temp (i:Int) = temps(i % temps.size) - (rs,0 until rs.size).zipped.foreach { - (x,i) => { - if (i < temps.size) { - stmts += Connect(info(s),get_temp(i),x.base) - } else { - stmts += Conditionally(info(s),x.guard,Connect(info(s),get_temp(i),x.base),EmptyStmt) - } - } - } - temp - } - } - } else { e} - } - } - } - - val sx = s match { - case (s:Connect) => { - if (has_access(s.loc)) { - val ls = get_locations(s.loc) - val locx = - if (ls.size == 1 & weq(ls(0).guard,one)) s.loc - else { - val temp = create_temp(s.loc) - for (x <- ls) { stmts += Conditionally(s.info,x.guard,Connect(s.info,x.base,temp),EmptyStmt) } - temp - } - Connect(s.info,locx,remove_e(s.expr)) - } else { Connect(s.info,s.loc,remove_e(s.expr)) } - } - case (s) => s map (remove_e) map (remove_s) - } - stmts += sx - if (stmts.size != 1) Block(stmts) else stmts(0) - } - Module(m.info,m.name,m.ports,remove_s(m.body)) - } - - val modulesx = c.modules.map{ - m => { - m match { - case (m:ExtModule) => m - case (m:Module) => remove_m(m) - } - } - } - Circuit(c.info,modulesx,c.main) - } -} // Replace shr by amount >= arg width with 0 for UInts and MSB for SInts // TODO replace UInt with zero-width wire instead diff --git a/src/main/scala/firrtl/passes/RemoveAccesses.scala b/src/main/scala/firrtl/passes/RemoveAccesses.scala new file mode 100644 index 00000000..d3340f2d --- /dev/null +++ b/src/main/scala/firrtl/passes/RemoveAccesses.scala @@ -0,0 +1,160 @@ +package firrtl.passes + +import firrtl.ir._ +import firrtl.{WRef, WSubAccess, WSubIndex, WSubField} +import firrtl.Mappers._ +import firrtl.Utils._ +import firrtl.WrappedExpression._ +import firrtl.Namespace +import scala.collection.mutable + + +/** Removes all [[firrtl.WSubAccess]] from circuit + */ +object RemoveAccesses extends Pass { + def name = "Remove Accesses" + + /** Container for a base expression and its corresponding guard + */ + case class Location(base: Expression, guard: Expression) + + /** Walks a referencing expression and returns a list of valid references + * (base) and the corresponding guard which, if true, returns that base. + * E.g. if called on a[i] where a: UInt[2], we would return: + * Seq(Location(a[0], UIntLiteral(0)), Location(a[1], UIntLiteral(1))) + */ + def getLocations(e: Expression): Seq[Location] = e match { + case e: WRef => create_exps(e).map(Location(_,one)) + case e: WSubIndex => + val ls = getLocations(e.exp) + val start = get_point(e) + val end = start + get_size(tpe(e)) + val stride = get_size(tpe(e.exp)) + val lsx = mutable.ArrayBuffer[Location]() + for (i <- 0 until ls.size) { + if (((i % stride) >= start) & ((i % stride) < end)) { + lsx += ls(i) + } + } + lsx + case e: WSubField => + val ls = getLocations(e.exp) + val start = get_point(e) + val end = start + get_size(tpe(e)) + val stride = get_size(tpe(e.exp)) + val lsx = mutable.ArrayBuffer[Location]() + for (i <- 0 until ls.size) { + if (((i % stride) >= start) & ((i % stride) < end)) { lsx += ls(i) } + } + lsx + case e: WSubAccess => + val ls = getLocations(e.exp) + val stride = get_size(tpe(e)) + val wrap = tpe(e.exp).asInstanceOf[VectorType].size + val lsx = mutable.ArrayBuffer[Location]() + for (i <- 0 until ls.size) { + val c = (i / stride) % wrap + val basex = ls(i).base + val guardx = AND(ls(i).guard,EQV(uint(c),e.index)) + lsx += Location(basex,guardx) + } + lsx + } + /** Returns true if e contains a [[firrtl.WSubAccess]] + */ + def hasAccess(e: Expression): Boolean = { + var ret: Boolean = false + def rec_has_access(e: Expression): Expression = e match { + case (e:WSubAccess) => { ret = true; e } + case (e) => e map (rec_has_access) + } + rec_has_access(e) + ret + } + def run(c: Circuit): Circuit = { + def remove_m(m: Module): Module = { + val namespace = Namespace(m) + def onStmt(s: Statement): Statement = { + val stmts = mutable.ArrayBuffer[Statement]() + def create_temp(e: Expression): Expression = { + val n = namespace.newTemp + stmts += DefWire(info(s), n, tpe(e)) + WRef(n, tpe(e), kind(e), gender(e)) + } + + /** Replaces a subaccess in a given male expression + */ + def removeMale(e: Expression): Expression = e match { + case (_:WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(e)) => + val rs = getLocations(e) + val foo = rs.find(x => {x.guard != one}) + foo match { + case None => error("Shouldn't be here") + case foo: Some[Location] => + val temp = create_temp(e) + val temps = create_exps(temp) + def getTemp(i: Int) = temps(i % temps.size) + for((x, i) <- rs.zipWithIndex) { + if (i < temps.size) { + stmts += Connect(info(s),getTemp(i),x.base) + } else { + stmts += Conditionally(info(s),x.guard,Connect(info(s),getTemp(i),x.base),EmptyStmt) + } + } + temp + } + case _ => e + } + + /** Replaces a subaccess in a given female expression + */ + def removeFemale(info: Info, loc: Expression): Expression = loc match { + case (_: WSubAccess| _: WSubField| _: WSubIndex| _: WRef) if (hasAccess(loc)) => + val ls = getLocations(loc) + if (ls.size == 1 & weq(ls(0).guard,one)) loc + else { + val temp = create_temp(loc) + for (x <- ls) { stmts += Conditionally(info,x.guard,Connect(info,x.base,temp),EmptyStmt) } + temp + } + case _ => loc + } + + /** Recursively walks a male expression and fixes all subaccesses + * If we see a sub-access, replace it. + * Otherwise, map to children. + */ + def fixMale(e: Expression): Expression = e match { + case w: WSubAccess => removeMale(WSubAccess(w.exp, fixMale(w.index), w.tpe, w.gender)) + //case w: WSubIndex => removeMale(w) + //case w: WSubField => removeMale(w) + case x => x map fixMale + } + + /** Recursively walks a female expression and fixes all subaccesses + * If we see a sub-access, its index is a male expression, and we must replace it. + * Otherwise, map to children. + */ + def fixFemale(e: Expression): Expression = e match { + case w: WSubAccess => WSubAccess(fixFemale(w.exp), fixMale(w.index), w.tpe, w.gender) + case x => x map fixFemale + } + + val sx = s match { + case Connect(info, loc, exp) => + Connect(info, removeFemale(info, fixFemale(loc)), fixMale(exp)) + case (s) => s map (fixMale) map (onStmt) + } + stmts += sx + if (stmts.size != 1) Block(stmts) else stmts(0) + } + Module(m.info, m.name, m.ports, onStmt(m.body)) + } + + val newModules = c.modules.map( _ match { + case m: ExtModule => m + case m: Module => remove_m(m) + }) + Circuit(c.info, newModules, c.main) + } +} |
