diff options
| author | jackkoenig | 2015-12-03 01:27:42 -0800 |
|---|---|---|
| committer | jackkoenig | 2015-12-03 01:27:42 -0800 |
| commit | 8e050ba48063d7f33551abcbb5c924b5d484aab7 (patch) | |
| tree | 695b2bfb094ae592cd00a8bf40c5dce339d8d9d6 /src | |
| parent | 9509036a2dbbe48af168762b634d96c4289eefe6 (diff) | |
Seem to be able to generate simulation wrapper module from DefInst
Diffstat (limited to 'src')
| -rw-r--r-- | src/main/scala/midas/Fame.scala | 118 |
1 files changed, 91 insertions, 27 deletions
diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala index 2d32f220..c2a7ff22 100644 --- a/src/main/scala/midas/Fame.scala +++ b/src/main/scala/midas/Fame.scala @@ -137,22 +137,91 @@ object Fame1 { } private def getDefInsts(m: Module): Seq[DefInst] = getDefInsts(m.stmt) - // Find the top module of a firrtl.Circuit - private def findTop(c: Circuit): Module = { - val moduleMap = c.modules.map(m => m.name -> m)(collection.breakOut): Map[String, Module] - moduleMap(c.name) + private def getInstRef(inst: DefInst): Ref = { + inst.module match { + case ref: Ref => ref + case _ => throw new Exception("Invalid module expression for DefInst: " + inst.serialize) + } + } + + // DefInsts have an Expression for the module, this expression should be a reference this + // reference has a tpe that should be a bundle representing the IO of that module class + private def getDefInstType(inst: DefInst): BundleType = { + val ref = getInstRef(inst) + ref.tpe match { + case b: BundleType => b + case _ => throw new Exception("Invalid reference type for DefInst: " + inst.serialize) + } + } + + private def getModuleFromDefInst(nameToModule: Map[String, Module], inst: DefInst): Module = { + val instModule = getInstRef(inst) + if(!nameToModule.contains(instModule.name)) + throw new Exception(s"Module ${instModule.name} not found in circuit!") + else + nameToModule(instModule.name) + } + + private def genWrapperModuleName(name: String): String = s"SimWrap_${name}" + private val readyValidPair = BundleType(Seq(Field("ready", Reverse, UIntType(IntWidth(1))), + Field("valid", Default, UIntType(IntWidth(1))))) + private def getPortDir(dir: FieldDir): PortDir = { + dir match { + case Default => Output + case Reverse => Input + } } + // Takes a set of strings and returns equivalent subfield node + // eg. Seq(io, port, ready) corresponds to io.port.ready + private def namesToSubfield(names: Seq[String]): Subfield = { + def rec(names: Seq[String]): Exp = { + if( names.length == 1 ) Ref(names.head, UnknownType) + else Subfield(rec(names.tail), names.head, UnknownType) + } + rec(names.reverse) match { + case s: Subfield => s + case _ => throw new Exception("Subfield requires more than 1 name!") + } + } + private def genAndReduce(args: Seq[Seq[String]]): DoPrimop = { + if( args.length == 2 ) + DoPrimop(And, Seq(namesToSubfield(args.head), namesToSubfield(args.last)), Seq(), UnknownType) + else + DoPrimop(And, Seq(namesToSubfield(args.head), genAndReduce(args.tail)), Seq(), UnknownType) + } + private def genWrapperModule(inst: DefInst, connections: Seq[String]): Module = { + val instIO = getDefInstType(inst) + // Add ports for each connection + val simInputPorts = connections.map(s => Port(inst.info, s"simInput_${s}", Input, readyValidPair)) + val simOutputPorts = connections.map(s => Port(inst.info, s"simOutput_${s}", Output, readyValidPair)) + val rtlPorts = instIO.fields.map(f => Port(inst.info, f.name, getPortDir(f.dir), f.tpe)) + val ports = rtlPorts ++ simInputPorts ++ simOutputPorts + + val simFireInputs = simInputPorts.map(p => Seq(p.name, "valid")) ++ simOutputPorts.map(p => Seq(p.name, "ready")) + val simFire = DefNode(inst.info, "simFire", genAndReduce(simFireInputs)) + val simClock = DefNode(inst.info, "simClock", DoPrimop(And, + Seq(Ref(simFire.name, UnknownType), Ref("clock", UnknownType)), Seq(), UnknownType)) + val inputsReady = simInputPorts.map(p => + Connect(inst.info, namesToSubfield(Seq(p.name, "ready")), UIntValue(1, IntWidth(1)))) + val outputsValid = simOutputPorts.map(p => + Connect(inst.info, namesToSubfield(Seq(p.name, "valid")), Ref(simFire.name, UnknownType))) + val instIOConnect = instIO.fields.map{ io => + io.tpe match { + case ClockType => Connect(inst.info, Ref(io.name, io.tpe), Ref(simClock.name, UnknownType)) + case _ => + io.dir match { + case Default => Connect(inst.info, Ref(io.name, io.tpe), namesToSubfield(Seq(inst.name, io.name))) + case Reverse => Connect(inst.info, namesToSubfield(Seq(inst.name, io.name)), Ref(io.name, io.tpe)) + } } } + val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect) - private def genWrapperModuleName(m: Module): String = s"SimWrap_${m.name}" - private def genWrapperModule(m: Module, connections: Seq[String]): Module = { - Module(m.info, genWrapperModuleName(m), m.ports, m.stmt) + Module(inst.info, s"SimWrap_${inst.name}", ports, stmts) } def transform(c: Circuit): Circuit = { - //Circuit(c.info, c.name, c.modules.map(fame1Transform(_, c.name))) - //println(s"In circuit ${c.name}, we have instances: ") - //println(insts) - val top = findTop(c) + val nameToModule = c.modules.map(m => m.name -> m)(collection.breakOut): Map[String, Module] + val top = nameToModule(c.name) + println("Top Module:") println(top.serialize) @@ -160,23 +229,18 @@ object Fame1 { println(s"In top module ${top.name}, we have instances: ") insts.foreach(i => println(" " + i.name)) - //val - - //val simInsts = insts.map(convertInst(c, _)) - //println("Simulation instances: ") - //simInsts.foreach{i => - // println(s"${i.name} : ${i.module.name}") - // i.ports.foreach{p => - // val endpoint = p.endpoint match { - // case UnknownSimInst => "?" - // case SimTopIO => "TopIO" - // case inst: SimInst => inst.name - // } - // println(s" ${p.port.name} : ${p.port.dir.serialize} : ${endpoint}") - // } - //} - //println(simInsts) + val connections = Seq("topIO") ++ insts.map(_.name) + println(connections) + + val wrappers = insts.map { inst => + genWrapperModule(inst, connections.filter(_ != inst.name)) + } + + wrappers.foreach { w => println(w.serialize) } + //val wrappers = insts.map(genWrapperModule(_, connections)) + //wrappers.foreach(println(_.serialize)) + c } |
