diff options
Diffstat (limited to 'src/main/scala/firrtl/passes/memlib/InferReadWrite.scala')
| -rw-r--r-- | src/main/scala/firrtl/passes/memlib/InferReadWrite.scala | 19 |
1 files changed, 16 insertions, 3 deletions
diff --git a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala index 2d1d7f6b..3494de45 100644 --- a/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala +++ b/src/main/scala/firrtl/passes/memlib/InferReadWrite.scala @@ -8,10 +8,14 @@ import firrtl.ir._ import firrtl.Mappers._ import firrtl.PrimOps._ import firrtl.Utils.{one, zero, BoolType} +import firrtl.options.HasScoptOptions import MemPortUtils.memPortField import firrtl.passes.memlib.AnalysisUtils.{Connects, getConnects, getOrigin} import WrappedExpression.weq import annotations._ +import scopt.OptionParser +import firrtl.stage.RunFirrtlTransformAnnotation + case object InferReadWriteAnnotation extends NoTargetAnnotation @@ -72,10 +76,10 @@ object InferReadWritePass extends Pass { def replaceStmt(repl: Netlist)(s: Statement): Statement = s map replaceStmt(repl) map replaceExp(repl) match { - case Connect(_, EmptyExpression, _) => EmptyStmt + case Connect(_, EmptyExpression, _) => EmptyStmt case sx => sx } - + def inferReadWriteStmt(connects: Connects, repl: Netlist, stmts: Statements) @@ -143,9 +147,18 @@ object InferReadWritePass extends Pass { // Transform input: Middle Firrtl. Called after "HighFirrtlToMidleFirrtl" // To use this transform, circuit name should be annotated with its TransId. -class InferReadWrite extends Transform with SeqTransformBased { +class InferReadWrite extends Transform with SeqTransformBased with HasScoptOptions { def inputForm = MidForm def outputForm = MidForm + + def addOptions(parser: OptionParser[AnnotationSeq]): Unit = parser + .opt[Unit]("infer-rw") + .abbr("firw") + .valueName ("<circuit>") + .action( (_, c) => c ++ Seq(InferReadWriteAnnotation, RunFirrtlTransformAnnotation(new InferReadWrite)) ) + .maxOccurs(1) + .text("Enable readwrite port inference for the target circuit") + def transforms = Seq( InferReadWritePass, CheckInitialization, |
