aboutsummaryrefslogtreecommitdiff
path: root/src/main/scala/firrtl/proto/ToProto.scala
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/scala/firrtl/proto/ToProto.scala')
-rw-r--r--src/main/scala/firrtl/proto/ToProto.scala413
1 files changed, 413 insertions, 0 deletions
diff --git a/src/main/scala/firrtl/proto/ToProto.scala b/src/main/scala/firrtl/proto/ToProto.scala
new file mode 100644
index 00000000..b3fb9a0c
--- /dev/null
+++ b/src/main/scala/firrtl/proto/ToProto.scala
@@ -0,0 +1,413 @@
+// See LICENSE for license details.
+
+package firrtl
+package proto
+
+import java.io.{BufferedOutputStream, OutputStream}
+
+import FirrtlProtos._
+import Firrtl.Expression.PrimOp.Op
+import com.google.protobuf.{CodedOutputStream, WireFormat}
+import firrtl.PrimOps._
+
+import scala.collection.JavaConverters._
+
+object ToProto {
+
+
+ /** Serialize a FIRRTL Circuit to an Output Stream as a ProtoBuf message
+ *
+ * @param ostream Output stream that will be written
+ * @param circuit The Circuit to serialize
+ */
+ def writeToStream(ostream: OutputStream, circuit: ir.Circuit): Unit = {
+ writeToStreamFast(ostream, circuit.info, circuit.modules.map(() => _), circuit.main)
+ }
+
+ /** Serialized a deconstructed Circuit with lazy Modules
+ *
+ * This serializer allows intermediate objects to be garbage collected during serialization
+ * to save time and memory
+ *
+ * @param ostream Output stream that will be written
+ * @param info Info of Circuit
+ * @param modules Functions to generate Modules lazily
+ * @param main Top-level module of the Circuit
+ */
+ // Note this function is sensitive to changes to the Firrtl and Circuit protobuf message definitions
+ def writeToStreamFast(
+ ostream: OutputStream,
+ info: ir.Info,
+ modules: Seq[() => ir.DefModule],
+ main: String
+ ): Unit = {
+ val costream = CodedOutputStream.newInstance(ostream)
+
+ // Write each module for the circuit
+ val ostreamInner = new java.io.ByteArrayOutputStream()
+ val costreamInner = CodedOutputStream.newInstance(ostreamInner)
+ for (mod <- modules) {
+ costreamInner.writeMessage(Firrtl.Circuit.MODULE_FIELD_NUMBER, convert(mod()).build)
+ }
+ val top = Firrtl.Top.newBuilder().setName(main).build
+ costreamInner.writeMessage(Firrtl.Circuit.TOP_FIELD_NUMBER, top)
+
+ // Write Circuit header first
+ costream.writeTag(Firrtl.CIRCUIT_FIELD_NUMBER, WireFormat.WIRETYPE_LENGTH_DELIMITED)
+ costream.writeUInt32NoTag(costreamInner.getTotalBytesWritten)
+ costream.flush()
+
+ // Write Modules
+ costreamInner.flush()
+ ostreamInner.writeTo(ostream)
+ ostreamInner.flush()
+ }
+
+ val convert: Map[ir.PrimOp, Op] = Map(
+ Add -> Op.OP_ADD,
+ Sub -> Op.OP_SUB,
+ Mul -> Op.OP_TIMES,
+ Div -> Op.OP_DIVIDE,
+ Rem -> Op.OP_REM,
+ Lt -> Op.OP_LESS,
+ Leq -> Op.OP_LESS_EQ,
+ Gt -> Op.OP_GREATER,
+ Geq -> Op.OP_GREATER_EQ,
+ Eq -> Op.OP_EQUAL,
+ Neq -> Op.OP_NOT_EQUAL,
+ Pad -> Op.OP_PAD,
+ AsUInt -> Op.OP_AS_UINT,
+ AsSInt -> Op.OP_AS_SINT,
+ AsClock -> Op.OP_AS_CLOCK,
+ AsFixedPoint -> Op.OP_AS_FIXED_POINT,
+ Shl -> Op.OP_SHIFT_LEFT,
+ Shr -> Op.OP_SHIFT_RIGHT,
+ Dshl -> Op.OP_DYNAMIC_SHIFT_LEFT,
+ Dshr -> Op.OP_DYNAMIC_SHIFT_RIGHT,
+ Cvt -> Op.OP_CONVERT,
+ Neg -> Op.OP_NEG,
+ Not -> Op.OP_BIT_NOT,
+ And -> Op.OP_BIT_AND,
+ Or -> Op.OP_BIT_OR,
+ Xor -> Op.OP_BIT_XOR,
+ Andr -> Op.OP_AND_REDUCE,
+ Orr -> Op.OP_OR_REDUCE,
+ Xorr -> Op.OP_XOR_REDUCE,
+ Cat -> Op.OP_CONCAT,
+ Bits -> Op.OP_EXTRACT_BITS,
+ Head -> Op.OP_HEAD,
+ Tail -> Op.OP_TAIL,
+ BPShl -> Op.OP_SHIFT_BINARY_POINT_LEFT,
+ BPShr -> Op.OP_SHIFT_BINARY_POINT_RIGHT,
+ BPSet -> Op.OP_SET_BINARY_POINT
+ )
+
+ def convertToIntegerLiteral(value: BigInt): Firrtl.Expression.IntegerLiteral.Builder = {
+ Firrtl.Expression.IntegerLiteral.newBuilder()
+ .setValue(value.toString)
+ }
+
+ def convertToBigInt(value: BigInt): Firrtl.BigInt.Builder = {
+ Firrtl.BigInt.newBuilder()
+ .setValue(com.google.protobuf.ByteString.copyFrom(value.toByteArray))
+ }
+
+ def convert(info: ir.Info): Firrtl.SourceInfo.Builder = {
+ val ib = Firrtl.SourceInfo.newBuilder()
+ info match {
+ case ir.NoInfo =>
+ ib.setNone(Firrtl.SourceInfo.None.newBuilder)
+ case ir.FileInfo(ir.StringLit(text)) =>
+ ib.setText(text)
+ // TODO properly implement MultiInfo
+ case ir.MultiInfo(infos) =>
+ val x = if (infos.nonEmpty) infos.head else ir.NoInfo
+ convert(x)
+ }
+ }
+
+ def convert(expr: ir.Expression): Firrtl.Expression.Builder = {
+ val eb = Firrtl.Expression.newBuilder()
+ expr match {
+ case ir.Reference(name, _) =>
+ val rb = Firrtl.Expression.Reference.newBuilder()
+ .setId(name)
+ eb.setReference(rb)
+ case ir.SubField(e, name, _) =>
+ val sb = Firrtl.Expression.SubField.newBuilder()
+ .setExpression(convert(e))
+ .setField(name)
+ eb.setSubField(sb)
+ case ir.SubIndex(e, value, _) =>
+ val sb = Firrtl.Expression.SubIndex.newBuilder()
+ .setExpression(convert(e))
+ .setIndex(convertToIntegerLiteral(value))
+ eb.setSubIndex(sb)
+ case ir.SubAccess(e, index, _) =>
+ val sb = Firrtl.Expression.SubAccess.newBuilder()
+ .setExpression(convert(e))
+ .setIndex(convert(index))
+ eb.setSubAccess(sb)
+ case ir.UIntLiteral(value, width) =>
+ val ub = Firrtl.Expression.UIntLiteral.newBuilder()
+ .setValue(convertToIntegerLiteral(value))
+ convert(width).foreach(ub.setWidth)
+ eb.setUintLiteral(ub)
+ case ir.SIntLiteral(value, width) =>
+ val sb = Firrtl.Expression.SIntLiteral.newBuilder()
+ .setValue(convertToIntegerLiteral(value))
+ convert(width).foreach(sb.setWidth)
+ eb.setSintLiteral(sb)
+ case ir.FixedLiteral(value, width, point) =>
+ val fb = Firrtl.Expression.FixedLiteral.newBuilder()
+ .setValue(convertToBigInt(value))
+ convert(width).foreach(fb.setWidth)
+ convert(point).foreach(fb.setPoint)
+ eb.setFixedLiteral(fb)
+ case ir.DoPrim(op, args, consts, _) =>
+ val db = Firrtl.Expression.PrimOp.newBuilder()
+ .setOp(convert(op))
+ consts.foreach(c => db.addConst(convertToIntegerLiteral(c)))
+ args.foreach(a => db.addArg(convert(a)))
+ eb.setPrimOp(db)
+ case ir.Mux(cond, tval, fval, _) =>
+ val mb = Firrtl.Expression.Mux.newBuilder()
+ .setCondition(convert(cond))
+ .setTValue(convert(tval))
+ .setFValue(convert(fval))
+ eb.setMux(mb)
+ }
+ }
+
+ def convert(dir: MPortDir): Firrtl.Statement.MemoryPort.Direction = {
+ import Firrtl.Statement.MemoryPort.Direction._
+ dir match {
+ case MInfer => MEMORY_PORT_DIRECTION_INFER
+ case MRead => MEMORY_PORT_DIRECTION_READ
+ case MWrite => MEMORY_PORT_DIRECTION_WRITE
+ case MReadWrite => MEMORY_PORT_DIRECTION_READ_WRITE
+ }
+ }
+
+ def convert(stmt: ir.Statement): Seq[Firrtl.Statement.Builder] = {
+ stmt match {
+ case ir.Block(stmts) => stmts.flatMap(convert(_))
+ case ir.EmptyStmt => Seq.empty
+ case other =>
+ val sb = Firrtl.Statement.newBuilder()
+ other match {
+ case ir.DefNode(_, name, expr) =>
+ val nb = Firrtl.Statement.Node.newBuilder()
+ .setId(name)
+ .setExpression(convert(expr))
+ sb.setNode(nb)
+ case ir.DefWire(_, name, tpe) =>
+ val wb = Firrtl.Statement.Wire.newBuilder()
+ .setId(name)
+ .setType(convert(tpe))
+ sb.setWire(wb)
+ case ir.DefRegister(_, name, tpe, clock, reset, init) =>
+ val rb = Firrtl.Statement.Register.newBuilder()
+ .setId(name)
+ .setType(convert(tpe))
+ .setClock(convert(clock))
+ .setReset(convert(reset))
+ .setInit(convert(init))
+ sb.setRegister(rb)
+ case ir.DefInstance(_, name, module) =>
+ val ib = Firrtl.Statement.Instance.newBuilder()
+ .setId(name)
+ .setModuleId(module)
+ sb.setInstance(ib)
+ case ir.Connect(_, loc, expr) =>
+ val cb = Firrtl.Statement.Connect.newBuilder()
+ .setLocation(convert(loc))
+ .setExpression(convert(expr))
+ sb.setConnect(cb)
+ case ir.PartialConnect(_, loc, expr) =>
+ val cb = Firrtl.Statement.PartialConnect.newBuilder()
+ .setLocation(convert(loc))
+ .setExpression(convert(expr))
+ sb.setPartialConnect(cb)
+ case ir.Conditionally(_, pred, conseq, alt) =>
+ val cs = convert(conseq)
+ val as = convert(alt)
+ val wb = Firrtl.Statement.When.newBuilder()
+ .setPredicate(convert(pred))
+ cs.foreach(wb.addConsequent)
+ as.foreach(wb.addOtherwise)
+ sb.setWhen(wb)
+ case ir.Print(_, string, args, clk, en) =>
+ val pb = Firrtl.Statement.Printf.newBuilder()
+ .setValue(string.string)
+ .setClk(convert(clk))
+ .setEn(convert(en))
+ args.foreach(a => pb.addArg(convert(a)))
+ sb.setPrintf(pb)
+ case ir.Stop(_, ret, clk, en) =>
+ val stopb = Firrtl.Statement.Stop.newBuilder()
+ .setReturnValue(ret)
+ .setClk(convert(clk))
+ .setEn(convert(en))
+ sb.setStop(stopb)
+ case ir.IsInvalid(_, expr) =>
+ val ib = Firrtl.Statement.IsInvalid.newBuilder()
+ .setExpression(convert(expr))
+ sb.setIsInvalid(ib)
+ case ir.DefMemory(_, name, dtype, depth, wlat, rlat, rs, ws, rws, _) =>
+ val mem = Firrtl.Statement.Memory.newBuilder()
+ .setId(name)
+ .setType(convert(dtype))
+ .setDepth(depth)
+ .setWriteLatency(wlat)
+ .setReadLatency(rlat)
+ mem.addAllReaderId(rs.asJava)
+ mem.addAllWriterId(ws.asJava)
+ mem.addAllReadwriterId(rws.asJava)
+ sb.setMemory(mem)
+ case CDefMemory(_, name, tpe, size, seq) =>
+ val tpeb = convert(ir.VectorType(tpe, size))
+ val mb = Firrtl.Statement.CMemory.newBuilder()
+ .setId(name)
+ .setType(tpeb)
+ .setSyncRead(seq)
+ sb.setCmemory(mb)
+ case CDefMPort(_, name, _, mem, exprs, dir) =>
+ val pb = Firrtl.Statement.MemoryPort.newBuilder()
+ .setId(name)
+ .setMemoryId(mem)
+ .setMemoryIndex(convert(exprs.head))
+ .setExpression(convert(exprs(1)))
+ .setDirection(convert(dir))
+ sb.setMemoryPort(pb)
+ case ir.Attach(_, exprs) =>
+ val ab = Firrtl.Statement.Attach.newBuilder()
+ exprs.foreach(e => ab.addExpression(convert(e)))
+ sb.setAttach(ab)
+ }
+ stmt match {
+ case hasInfo: ir.HasInfo => sb.setSourceInfo(convert(hasInfo.info))
+ case _ => // Do nothing
+ }
+ Seq(sb)
+ }
+ }
+
+ def convert(field: ir.Field): Firrtl.Type.BundleType.Field.Builder = {
+ val b = Firrtl.Type.BundleType.Field.newBuilder()
+ .setId(field.name)
+ .setIsFlipped(field.flip == ir.Flip)
+ .setType(convert(field.tpe))
+ b
+ }
+
+ /** Converts a Width to a ProtoBuf Width Builder
+ *
+ * @param width Input width
+ * @return Option width where None means the width field should be cleared in the parent object
+ */
+ def convert(width: ir.Width): Option[Firrtl.Width.Builder] = width match {
+ case ir.IntWidth(w) => Some(Firrtl.Width.newBuilder().setValue(w.toInt))
+ case ir.UnknownWidth => None
+ }
+
+ def convert(vtpe: ir.VectorType): Firrtl.Type.VectorType.Builder =
+ Firrtl.Type.VectorType.newBuilder()
+ .setType(convert(vtpe.tpe))
+ .setSize(vtpe.size)
+
+ def convert(tpe: ir.Type): Firrtl.Type.Builder = {
+ val tb = Firrtl.Type.newBuilder()
+ tpe match {
+ case ir.UIntType(width) =>
+ val ut = Firrtl.Type.UIntType.newBuilder()
+ convert(width).foreach(ut.setWidth)
+ tb.setUintType(ut)
+ case ir.SIntType(width) =>
+ val st = Firrtl.Type.SIntType.newBuilder()
+ convert(width).foreach(st.setWidth)
+ tb.setSintType(st)
+ case ir.FixedType(width, point) =>
+ val ft = Firrtl.Type.FixedType.newBuilder()
+ convert(width).foreach(ft.setWidth)
+ convert(point).foreach(ft.setPoint)
+ tb.setFixedType(ft)
+ case ir.ClockType =>
+ val ct = Firrtl.Type.ClockType.newBuilder()
+ tb.setClockType(ct)
+ case ir.AnalogType(width) =>
+ val at = Firrtl.Type.AnalogType.newBuilder()
+ convert(width).foreach(at.setWidth)
+ tb.setAnalogType(at)
+ case ir.BundleType(fields) =>
+ val bt = Firrtl.Type.BundleType.newBuilder()
+ fields.foreach(f => bt.addField(convert(f)))
+ tb.setBundleType(bt)
+ case vtpe: ir.VectorType =>
+ val vtb = convert(vtpe)
+ tb.setVectorType(vtb)
+ }
+ }
+
+ def convert(direction: ir.Direction): Firrtl.Port.Direction = direction match {
+ case ir.Input => Firrtl.Port.Direction.PORT_DIRECTION_IN
+ case ir.Output => Firrtl.Port.Direction.PORT_DIRECTION_OUT
+ }
+
+ def convert(port: ir.Port): Firrtl.Port.Builder = {
+ Firrtl.Port.newBuilder()
+ .setId(port.name)
+ .setDirection(convert(port.direction))
+ .setType(convert(port.tpe))
+ }
+
+ def convert(param: ir.Param): Firrtl.Module.ExternalModule.Parameter.Builder = {
+ import Firrtl.Module.ExternalModule._
+ val pb = Parameter.newBuilder()
+ .setId(param.name)
+ param match {
+ case ir.IntParam(_, value) =>
+ pb.setInteger(convertToBigInt(value))
+ case ir.DoubleParam(_, value) =>
+ pb.setDouble(value)
+ case ir.StringParam(_, ir.StringLit(value)) =>
+ pb.setString(value)
+ case ir.RawStringParam(_, value) =>
+ pb.setRawString(value)
+ }
+ }
+
+ def convert(module: ir.DefModule): Firrtl.Module.Builder = {
+ val ports = module.ports.map(convert(_))
+ val b = Firrtl.Module.newBuilder()
+ module match {
+ case mod: ir.Module =>
+ val stmts = convert(mod.body)
+ val mb = Firrtl.Module.UserModule.newBuilder()
+ .setId(mod.name)
+ ports.foreach(mb.addPort)
+ stmts.foreach(mb.addStatement)
+ b.setUserModule(mb)
+ case ext: ir.ExtModule =>
+ val eb = Firrtl.Module.ExternalModule.newBuilder()
+ .setId(ext.name)
+ .setDefinedName(ext.defname)
+ ports.foreach(eb.addPort)
+ val params = ext.params.map(convert(_))
+ params.foreach(eb.addParameter)
+ b.setExternalModule(eb)
+ }
+ }
+
+ def convert(circuit: ir.Circuit): Firrtl = {
+ val moduleBuilders = circuit.modules.map(convert(_))
+ val cb = Firrtl.Circuit.newBuilder
+ .addTop(Firrtl.Top.newBuilder().setName(circuit.main))
+ for (m <- moduleBuilders) {
+ cb.addModule(m)
+ }
+ Firrtl.newBuilder()
+ .addCircuit(cb.build())
+ .build()
+ }
+}