diff options
| author | Adam Izraelevitz | 2016-01-16 11:42:05 -0800 |
|---|---|---|
| committer | Adam Izraelevitz | 2016-01-16 11:42:05 -0800 |
| commit | b70bce82f6de8ff18fca6d2744f7bf914a37c37b (patch) | |
| tree | 9c8979d0f6107dc26417b9d0d305475af83964a8 /src | |
| parent | 2beab33ac298470bc04caf1c3b7a5a0d17d465d4 (diff) | |
| parent | f00e3c651c816fa8e5eb2f2154f7374602aa8c83 (diff) | |
Merge pull request #52 from ucb-bar/scala-paul
Add a renameall pass that renames nodes according to a user-provided map.
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/firrtl/Driver.scala | 151 | ||||
| -rw-r--r-- | src/main/scala/firrtl/Passes.scala | 76 |
2 files changed, 171 insertions, 56 deletions
diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 141a326a..4b3b2967 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -17,13 +17,13 @@ object Driver // Appends 0 to the filename and appends .tmp to the extension private def genTempFilename(filename: String): String = { - val pat = """(.*/)([^/]*)([.][^/.]*)""".r + val pat = """(.*/)([^/]*)([.][^/.]*)""".r val (path, name, ext) = filename match { case pat(path, name, ext) => (path, name, ext + ".tmp") case _ => ("./", "temp", ".tmp") } var count = 0 - while( Files.exists(Paths.get(path + name + count + ext )) ) + while( Files.exists(Paths.get(path + name + count + ext )) ) count += 1 path + name + count + ext } @@ -49,58 +49,101 @@ object Driver executePassesWithLogger(ast, passes) } + trait Pass + case class StanzaPass(val name : String) extends Pass + case class AggregatedStanzaPass(val passes : Seq[StanzaPass]) extends Pass + case class ScalaPass(val func : Circuit => Circuit) extends Pass + + def aggregateStanzaPasses(l : Seq[Pass]) : Seq[Pass] = { + if (l.isEmpty) return Seq() + val span = l.span(x => x match { + case p : StanzaPass => true + case _ => false + }) + if (span._1.isEmpty) { + val tail = if(span._2.length > 1) + aggregateStanzaPasses(span._2.tail) + else + Seq() + Seq(span._2.head) ++ tail + } else { + Seq(AggregatedStanzaPass(span._1.asInstanceOf[Seq[StanzaPass]])) ++ aggregateStanzaPasses(span._2) + } + } + + def run(pass : Pass, input : String, output : String)(implicit logger : Logger) : Unit = pass match { + case p : StanzaPass => + val cmd = Seq("firrtl-stanza", "-i", input, "-o", output, "-b", "firrtl", "-x", p.name) + println(cmd.mkString(" ")) + val ret = cmd.!! + println(ret) + case p : AggregatedStanzaPass => + val cmd = Seq("firrtl-stanza", "-i", input, "-o", output, "-b", "firrtl") ++ p.passes.flatMap(x=>Seq("-x", x.name)) + println(cmd.mkString(" ")) + val ret = cmd.!! + println(ret) + case p : ScalaPass => + var ast = Parser.parse(input, Source.fromFile(input).getLines) + val newast = p.func(ast) + println("Writing to " + output) + val writer = new PrintWriter(new File(output)) + writer.write(newast.serialize()) + writer.close() + case _ => logger.warn("Pass " + pass + " cannot be run") + } + private def verilog(input: String, output: String)(implicit logger: Logger) { - val stanzaPass = //List( - List("rem-spec-chars", "high-form-check", - "temp-elim", "to-working-ir", "resolve-kinds", "infer-types", - "resolve-genders", "check-genders", "check-kinds", "check-types", - "expand-accessors", "lower-to-ground", "inline-indexers", "infer-types", - "check-genders", "expand-whens", "infer-widths", "real-ir", "width-check", - "pad-widths", "const-prop", "split-expressions", "width-check", - "high-form-check", "low-form-check", "check-init") - //) - val scalaPass = List(List[String]()) - - val mapString2Pass = Map[String, Circuit => Circuit] ( - "infer-types" -> inferTypes - ) - - //if (stanza.isEmpty || !Files.exists(Paths.get(stanza))) - // throw new FileNotFoundException("Stanza binary not found! " + stanza) - - // For now, just use the stanza implementation in its entirety - val cmd = Seq("firrtl-stanza", "-i", input, "-o", output, "-b", "verilog") ++ stanzaPass.flatMap(Seq("-x", _)) + + val passes = aggregateStanzaPasses(Seq( + StanzaPass("rem-spec-chars"), + StanzaPass("high-form-check"), + ScalaPass(renameall(Map( + "c"->"ccc", + "z"->"zzz", + "top"->"its_a_top_module" + ))), + StanzaPass("temp-elim"), + StanzaPass("to-working-ir"), + StanzaPass("resolve-kinds"), + StanzaPass("infer-types"), + StanzaPass("resolve-genders"), + StanzaPass("check-genders"), + StanzaPass("check-kinds"), + StanzaPass("check-types"), + StanzaPass("expand-accessors"), + StanzaPass("lower-to-ground"), + StanzaPass("inline-indexers"), + StanzaPass("infer-types"), + //ScalaPass(inferTypes), + StanzaPass("check-genders"), + StanzaPass("expand-whens"), + StanzaPass("infer-widths"), + StanzaPass("real-ir"), + StanzaPass("width-check"), + StanzaPass("pad-widths"), + StanzaPass("const-prop"), + StanzaPass("split-expressions"), + StanzaPass("width-check"), + StanzaPass("high-form-check"), + StanzaPass("low-form-check"), + StanzaPass("check-init")//, + //ScalaPass(renamec) + )) + + val outfile = passes.foldLeft( input ) ( (infile, pass) => { + val outfile = genTempFilename(output) + run(pass, infile, outfile) + outfile + }) + + println(outfile) + + // finally, convert to verilog at the end + val cmd = Seq("firrtl-stanza", "-i", outfile, "-o", output, "-X", "verilog") println(cmd.mkString(" ")) val ret = cmd.!! println(ret) - - // Switch between stanza and scala implementations - //var scala2Stanza = input - //for ((stanzaPass, scalaPass) <- stanzaPass zip scalaPass) { - // val stanza2Scala = genTempFilename(output) - // val cmd: Seq[String] = Seq[String](stanza, "-i", scala2Stanza, "-o", stanza2Scala, "-b", "firrtl") ++ stanzaPass.flatMap(Seq("-x", _)) - // println(cmd.mkString(" ")) - // val ret = cmd.!! - // println(ret) - - // if( scalaPass.isEmpty ) { - // scala2Stanza = stanza2Scala - // } else { - // var ast = Parser.parse(input, stanza2Scala) - // //scalaPass.foreach( f => (ast = f(ast)) ) // Does this work? - // for ( f <- scalaPass ) yield { ast = mapString2Pass(f)(ast) } - - // scala2Stanza = genTempFilename(output) - // val writer = new PrintWriter(new File(scala2Stanza)) - // writer.write(ast.serialize()) - // writer.close() - // } - //} - //val cmd = Seq(stanza, "-i", scala2Stanza, "-o", output, "-b", "verilog") - //println(cmd.mkString(" ")) - //val ret = cmd.!! - //println(ret) } def main(args: Array[String]) @@ -115,7 +158,7 @@ object Driver case _ => 'debug } - def nextPrintVar(syms: List[Symbol], chars: List[Char]): List[Symbol] = + def nextPrintVar(syms: List[Symbol], chars: List[Char]): List[Symbol] = chars match { case Nil => syms case 't' :: tail => nextPrintVar(syms ++ List('types), tail) @@ -125,16 +168,16 @@ object Driver case 'g' :: tail => nextPrintVar(syms ++ List('genders), tail) case 'c' :: tail => nextPrintVar(syms ++ List('circuit), tail) case 'd' :: tail => nextPrintVar(syms ++ List('debug), tail) // Currently ignored - case 'i' :: tail => nextPrintVar(syms ++ List('info), tail) + case 'i' :: tail => nextPrintVar(syms ++ List('info), tail) case char :: tail => throw new Exception("Unknown print option " + char) } def nextOption(map: OptionMap, list: List[String]): OptionMap = { list match { case Nil => map - case "-X" :: value :: tail => + case "-X" :: value :: tail => nextOption(map ++ Map('compiler -> value), tail) - case "-d" :: value :: tail => + case "-d" :: value :: tail => nextOption(map ++ Map('debugMode -> value), tail) case "-l" :: value :: tail => nextOption(map ++ Map('log -> value), tail) @@ -146,7 +189,7 @@ object Driver nextOption(map ++ Map('output -> value), tail) case ("-h" | "--help") :: tail => nextOption(map ++ Map('help -> true), tail) - case option :: tail => + case option :: tail => throw new Exception("Unknown option " + option) } } diff --git a/src/main/scala/firrtl/Passes.scala b/src/main/scala/firrtl/Passes.scala index f64d67bb..f5691c45 100644 --- a/src/main/scala/firrtl/Passes.scala +++ b/src/main/scala/firrtl/Passes.scala @@ -17,6 +17,8 @@ object Passes { //mapNameToPass.getOrElse(name, throw new Exception("No Standard FIRRTL Pass of name " + name)) name match { case "infer-types" => inferTypes + // errrrrrrrrrr... + case "renameall" => renameall(Map()) } } @@ -28,7 +30,7 @@ object Passes { } } - /** INFER TYPES + /** INFER TYPES * * This pass infers the type field in all IR nodes by updating * and passing an environment to all statements in pre-order @@ -86,7 +88,7 @@ object Passes { case s: DefInst => (s, typeMap ++ Map(s.name -> s.module.getType)) case s: DefNode => (s, typeMap ++ Map(s.name -> s.value.getType)) case s: DefPoison => (s, typeMap ++ Map(s.name -> s.tpe)) - case s: DefAccessor => (s, typeMap ++ Map(s.name -> getVectorSubtype(s.source.getType))) + case s: DefAccessor => (s, typeMap ++ Map(s.name -> getVectorSubtype(s.source.getType))) case s: When => { // TODO Check: Assuming else block won't see when scope val (conseq, cMap) = inferTypes(typeMap, s.conseq) val (alt, aMap) = inferTypes(typeMap, s.alt) @@ -112,4 +114,74 @@ object Passes { Circuit(c.info, c.name, c.modules.map(inferTypes(typeMap, _))) } + def renameall(s : String)(implicit map : Map[String,String]) : String = + map getOrElse (s, s) + + def renameall(e : Exp)(implicit logger : Logger, map : Map[String,String]) : Exp = { + logger.trace(s"renameall called on expression ${e.toString}") + e match { + case p : Ref => + Ref(renameall(p.name), p.tpe) + case p : Subfield => + Subfield(renameall(p.exp), renameall(p.name), p.tpe) + case p : Index => + Index(renameall(p.exp), p.value, p.tpe) + case p : DoPrimop => + println( p.args.map(x => renameall(x)) ) + DoPrimop(p.op, p.args.map(x => renameall(x)), p.consts, p.tpe) + case p : Exp => p + } + } + + def renameall(s : Stmt)(implicit logger : Logger, map : Map[String,String]) : Stmt = { + logger.trace(s"renameall called on statement ${s.toString}") + + s match { + case p : DefWire => + DefWire(p.info, renameall(p.name), p.tpe) + case p: DefReg => + DefReg(p.info, renameall(p.name), p.tpe, p.clock, p.reset) + case p : DefMemory => + DefMemory(p.info, renameall(p.name), p.seq, p.tpe, p.clock) + case p : DefInst => + DefInst(p.info, renameall(p.name), renameall(p.module)) + case p : DefNode => + DefNode(p.info, renameall(p.name), renameall(p.value)) + case p : DefPoison => + DefPoison(p.info, renameall(p.name), p.tpe) + case p : DefAccessor => + DefAccessor(p.info, renameall(p.name), p.dir, renameall(p.source), renameall(p.index)) + case p : OnReset => + OnReset(p.info, renameall(p.lhs), renameall(p.rhs)) + case p : Connect => + Connect(p.info, renameall(p.lhs), renameall(p.rhs)) + case p : BulkConnect => + BulkConnect(p.info, renameall(p.lhs), renameall(p.rhs)) + case p : When => + When(p.info, renameall(p.pred), renameall(p.conseq), renameall(p.alt)) + case p : Assert => + Assert(p.info, renameall(p.pred)) + case p : Block => + Block(p.stmts.map(renameall)) + case p : Stmt => p + } + } + + def renameall(p : Port)(implicit logger : Logger, map : Map[String,String]) : Port = { + logger.trace(s"renameall called on port ${p.name}") + Port(p.info, renameall(p.name), p.dir, p.tpe) + } + + def renameall(m : Module)(implicit logger : Logger, map : Map[String,String]) : Module = { + logger.trace(s"renameall called on module ${m.name}") + Module(m.info, renameall(m.name), m.ports.map(renameall(_)), renameall(m.stmt)) + } + + def renameall(map : Map[String,String])(implicit logger : Logger) : Circuit => Circuit = { + c => { + implicit val imap = map + logger.trace(s"renameall called on circuit ${c.name} with %{renameto}") + Circuit(c.info, renameall(c.name), c.modules.map(renameall(_))) + } + } } |
