aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorjackkoenig2015-12-03 01:27:42 -0800
committerjackkoenig2015-12-03 01:27:42 -0800
commit8e050ba48063d7f33551abcbb5c924b5d484aab7 (patch)
tree695b2bfb094ae592cd00a8bf40c5dce339d8d9d6 /src
parent9509036a2dbbe48af168762b634d96c4289eefe6 (diff)
Seem to be able to generate simulation wrapper module from DefInst
Diffstat (limited to 'src')
-rw-r--r--src/main/scala/midas/Fame.scala118
1 files changed, 91 insertions, 27 deletions
diff --git a/src/main/scala/midas/Fame.scala b/src/main/scala/midas/Fame.scala
index 2d32f220..c2a7ff22 100644
--- a/src/main/scala/midas/Fame.scala
+++ b/src/main/scala/midas/Fame.scala
@@ -137,22 +137,91 @@ object Fame1 {
}
private def getDefInsts(m: Module): Seq[DefInst] = getDefInsts(m.stmt)
- // Find the top module of a firrtl.Circuit
- private def findTop(c: Circuit): Module = {
- val moduleMap = c.modules.map(m => m.name -> m)(collection.breakOut): Map[String, Module]
- moduleMap(c.name)
+ private def getInstRef(inst: DefInst): Ref = {
+ inst.module match {
+ case ref: Ref => ref
+ case _ => throw new Exception("Invalid module expression for DefInst: " + inst.serialize)
+ }
+ }
+
+ // DefInsts have an Expression for the module, this expression should be a reference this
+ // reference has a tpe that should be a bundle representing the IO of that module class
+ private def getDefInstType(inst: DefInst): BundleType = {
+ val ref = getInstRef(inst)
+ ref.tpe match {
+ case b: BundleType => b
+ case _ => throw new Exception("Invalid reference type for DefInst: " + inst.serialize)
+ }
+ }
+
+ private def getModuleFromDefInst(nameToModule: Map[String, Module], inst: DefInst): Module = {
+ val instModule = getInstRef(inst)
+ if(!nameToModule.contains(instModule.name))
+ throw new Exception(s"Module ${instModule.name} not found in circuit!")
+ else
+ nameToModule(instModule.name)
+ }
+
+ private def genWrapperModuleName(name: String): String = s"SimWrap_${name}"
+ private val readyValidPair = BundleType(Seq(Field("ready", Reverse, UIntType(IntWidth(1))),
+ Field("valid", Default, UIntType(IntWidth(1)))))
+ private def getPortDir(dir: FieldDir): PortDir = {
+ dir match {
+ case Default => Output
+ case Reverse => Input
+ }
}
+ // Takes a set of strings and returns equivalent subfield node
+ // eg. Seq(io, port, ready) corresponds to io.port.ready
+ private def namesToSubfield(names: Seq[String]): Subfield = {
+ def rec(names: Seq[String]): Exp = {
+ if( names.length == 1 ) Ref(names.head, UnknownType)
+ else Subfield(rec(names.tail), names.head, UnknownType)
+ }
+ rec(names.reverse) match {
+ case s: Subfield => s
+ case _ => throw new Exception("Subfield requires more than 1 name!")
+ }
+ }
+ private def genAndReduce(args: Seq[Seq[String]]): DoPrimop = {
+ if( args.length == 2 )
+ DoPrimop(And, Seq(namesToSubfield(args.head), namesToSubfield(args.last)), Seq(), UnknownType)
+ else
+ DoPrimop(And, Seq(namesToSubfield(args.head), genAndReduce(args.tail)), Seq(), UnknownType)
+ }
+ private def genWrapperModule(inst: DefInst, connections: Seq[String]): Module = {
+ val instIO = getDefInstType(inst)
+ // Add ports for each connection
+ val simInputPorts = connections.map(s => Port(inst.info, s"simInput_${s}", Input, readyValidPair))
+ val simOutputPorts = connections.map(s => Port(inst.info, s"simOutput_${s}", Output, readyValidPair))
+ val rtlPorts = instIO.fields.map(f => Port(inst.info, f.name, getPortDir(f.dir), f.tpe))
+ val ports = rtlPorts ++ simInputPorts ++ simOutputPorts
+
+ val simFireInputs = simInputPorts.map(p => Seq(p.name, "valid")) ++ simOutputPorts.map(p => Seq(p.name, "ready"))
+ val simFire = DefNode(inst.info, "simFire", genAndReduce(simFireInputs))
+ val simClock = DefNode(inst.info, "simClock", DoPrimop(And,
+ Seq(Ref(simFire.name, UnknownType), Ref("clock", UnknownType)), Seq(), UnknownType))
+ val inputsReady = simInputPorts.map(p =>
+ Connect(inst.info, namesToSubfield(Seq(p.name, "ready")), UIntValue(1, IntWidth(1))))
+ val outputsValid = simOutputPorts.map(p =>
+ Connect(inst.info, namesToSubfield(Seq(p.name, "valid")), Ref(simFire.name, UnknownType)))
+ val instIOConnect = instIO.fields.map{ io =>
+ io.tpe match {
+ case ClockType => Connect(inst.info, Ref(io.name, io.tpe), Ref(simClock.name, UnknownType))
+ case _ =>
+ io.dir match {
+ case Default => Connect(inst.info, Ref(io.name, io.tpe), namesToSubfield(Seq(inst.name, io.name)))
+ case Reverse => Connect(inst.info, namesToSubfield(Seq(inst.name, io.name)), Ref(io.name, io.tpe))
+ } } }
+ val stmts = Block(Seq(simFire, simClock, inst) ++ inputsReady ++ outputsValid ++ instIOConnect)
- private def genWrapperModuleName(m: Module): String = s"SimWrap_${m.name}"
- private def genWrapperModule(m: Module, connections: Seq[String]): Module = {
- Module(m.info, genWrapperModuleName(m), m.ports, m.stmt)
+ Module(inst.info, s"SimWrap_${inst.name}", ports, stmts)
}
def transform(c: Circuit): Circuit = {
- //Circuit(c.info, c.name, c.modules.map(fame1Transform(_, c.name)))
- //println(s"In circuit ${c.name}, we have instances: ")
- //println(insts)
- val top = findTop(c)
+ val nameToModule = c.modules.map(m => m.name -> m)(collection.breakOut): Map[String, Module]
+ val top = nameToModule(c.name)
+
println("Top Module:")
println(top.serialize)
@@ -160,23 +229,18 @@ object Fame1 {
println(s"In top module ${top.name}, we have instances: ")
insts.foreach(i => println(" " + i.name))
- //val
-
- //val simInsts = insts.map(convertInst(c, _))
- //println("Simulation instances: ")
- //simInsts.foreach{i =>
- // println(s"${i.name} : ${i.module.name}")
- // i.ports.foreach{p =>
- // val endpoint = p.endpoint match {
- // case UnknownSimInst => "?"
- // case SimTopIO => "TopIO"
- // case inst: SimInst => inst.name
- // }
- // println(s" ${p.port.name} : ${p.port.dir.serialize} : ${endpoint}")
- // }
- //}
- //println(simInsts)
+ val connections = Seq("topIO") ++ insts.map(_.name)
+ println(connections)
+
+ val wrappers = insts.map { inst =>
+ genWrapperModule(inst, connections.filter(_ != inst.name))
+ }
+
+ wrappers.foreach { w => println(w.serialize) }
+ //val wrappers = insts.map(genWrapperModule(_, connections))
+ //wrappers.foreach(println(_.serialize))
+
c
}