diff options
Diffstat (limited to 'src/main/scala')
| -rw-r--r-- | src/main/scala/firrtl/LoweringCompilers.scala | 3 | ||||
| -rw-r--r-- | src/main/scala/firrtl/passes/ZeroWidth.scala | 65 |
2 files changed, 67 insertions, 1 deletions
diff --git a/src/main/scala/firrtl/LoweringCompilers.scala b/src/main/scala/firrtl/LoweringCompilers.scala index ab9f6ea4..c2238c2d 100644 --- a/src/main/scala/firrtl/LoweringCompilers.scala +++ b/src/main/scala/firrtl/LoweringCompilers.scala @@ -68,7 +68,8 @@ class HighFirrtlToMiddleFirrtl extends CoreTransform { passes.ResolveGenders, passes.InferWidths, passes.CheckWidths, - passes.ConvertFixedToSInt) + passes.ConvertFixedToSInt, + passes.ZeroWidth) } /** Expands all aggregate types into many ground-typed components. Must diff --git a/src/main/scala/firrtl/passes/ZeroWidth.scala b/src/main/scala/firrtl/passes/ZeroWidth.scala new file mode 100644 index 00000000..8638ea68 --- /dev/null +++ b/src/main/scala/firrtl/passes/ZeroWidth.scala @@ -0,0 +1,65 @@ +// See LICENSE for license details. + +package firrtl.passes + +import scala.collection.mutable +import firrtl.PrimOps._ +import firrtl.ir._ +import firrtl._ +import firrtl.Mappers._ +import firrtl.Utils.throwInternalError + + +object ZeroWidth extends Pass { + def name = this.getClass.getName + private val ZERO = BigInt(0) + private def removeZero(t: Type): Option[Type] = t match { + case GroundType(IntWidth(ZERO)) => None + case BundleType(fields) => + fields map (f => (f, removeZero(f.tpe))) collect { + case (Field(name, flip, _), Some(t)) => Field(name, flip, t) + } match { + case Nil => None + case seq => Some(BundleType(seq)) + } + case VectorType(t, size) => removeZero(t) map (VectorType(_, size)) + case x => Some(x) + } + private def onExp(e: Expression): Expression = removeZero(e.tpe) match { + case None => e.tpe match { + case UIntType(x) => UIntLiteral(ZERO, IntWidth(BigInt(1))) + case SIntType(x) => SIntLiteral(ZERO, IntWidth(BigInt(1))) + case _ => throwInternalError + } + case Some(t) => + def replaceType(x: Type): Type = t + (e map replaceType) map onExp + } + private def onStmt(s: Statement): Statement = s match { + case sx: IsDeclaration => + var removed = false + def applyRemoveZero(t: Type): Type = removeZero(t) match { + case None => removed = true; t + case Some(tx) => tx + } + val sxx = (sx map onExp) map applyRemoveZero + if(removed) EmptyStmt else sxx + case Connect(info, loc, exp) => removeZero(loc.tpe) match { + case None => EmptyStmt + case Some(t) => Connect(info, loc, onExp(exp)) + } + case sx => sx map onStmt + } + private def onModule(m: DefModule): DefModule = { + val ports = m.ports map (p => (p, removeZero(p.tpe))) collect { + case (Port(info, name, dir, _), Some(t)) => Port(info, name, dir, t) + } + m match { + case ext: ExtModule => ext.copy(ports = ports) + case in: Module => in.copy(ports = ports, body = onStmt(in.body)) + } + } + def run(c: Circuit): Circuit = { + InferTypes.run(c.copy(modules = c.modules map onModule)) + } +} |
