diff options
5 files changed, 56 insertions, 4 deletions
diff --git a/src/main/scala/firrtl/Compiler.scala b/src/main/scala/firrtl/Compiler.scala index 87662800..1ef35891 100644 --- a/src/main/scala/firrtl/Compiler.scala +++ b/src/main/scala/firrtl/Compiler.scala @@ -295,6 +295,9 @@ trait Emitter extends Transform { def emit(state: CircuitState, writer: Writer): Unit } +/** Wraps exceptions from CustomTransforms so they can be reported appropriately */ +case class CustomTransformException(cause: Throwable) extends Exception("", cause) + object CompilerUtils extends LazyLogging { /** Generates a sequence of [[Transform]]s to lower a Firrtl circuit * @@ -427,6 +430,15 @@ trait Compiler extends LazyLogging { compile(state.copy(annotations = emitAnno +: state.annotations), customTransforms) } + private def isCustomTransform(xform: Transform): Boolean = { + def getTopPackage(pack: java.lang.Package): java.lang.Package = + Package.getPackage(pack.getName.split('.').head) + // We use the top package of the Driver to get the top firrtl package + Option(xform.getClass.getPackage).map { p => + getTopPackage(p) != firrtl.Driver.getClass.getPackage + }.getOrElse(true) + } + /** Perform compilation * * Emission will only be performed if [[EmitAnnotation]]s are present @@ -440,7 +452,14 @@ trait Compiler extends LazyLogging { val allTransforms = CompilerUtils.mergeTransforms(transforms, customTransforms) :+ emitter val (timeMillis, finalState) = Utils.time { - allTransforms.foldLeft(state) { (in, xform) => xform.runTransform(in) } + allTransforms.foldLeft(state) { (in, xform) => + try { + xform.runTransform(in) + } catch { + // Wrap exceptions from custom transforms so they are reported as such + case e: Exception if isCustomTransform(xform) => throw CustomTransformException(e) + } + } } logger.error(f"Total FIRRTL Compile Time: $timeMillis%.1f ms") diff --git a/src/main/scala/firrtl/Driver.scala b/src/main/scala/firrtl/Driver.scala index 47841cec..c277e120 100644 --- a/src/main/scala/firrtl/Driver.scala +++ b/src/main/scala/firrtl/Driver.scala @@ -248,6 +248,8 @@ object Driver { case p: PassException => throw p case p: PassExceptions => throw p case p: FIRRTLException => throw p + // Propagate exceptions from custom transforms + case CustomTransformException(cause) => throw cause // Treat remaining exceptions as internal errors. case e: Exception => throwInternalError(exception = Some(e)) } diff --git a/src/test/scala/firrtlTests/CustomTransformSpec.scala b/src/test/scala/firrtlTests/CustomTransformSpec.scala index d1ff6fd1..1b0e8190 100644 --- a/src/test/scala/firrtlTests/CustomTransformSpec.scala +++ b/src/test/scala/firrtlTests/CustomTransformSpec.scala @@ -46,5 +46,30 @@ class CustomTransformSpec extends FirrtlFlatSpec { runFirrtlTest("CustomTransform", "/features", customTransforms = List(new ReplaceExtModuleTransform)) } + + they should "not cause \"Internal Errors\"" in { + val input = """ + |circuit test : + | module test : + | output out : UInt + | out <= UInt(123)""".stripMargin + val errorString = "My Custom Transform failed!" + class ErroringTransform extends Transform { + def inputForm = HighForm + def outputForm = HighForm + def execute(state: CircuitState): CircuitState = { + require(false, errorString) + state + } + } + val optionsManager = new ExecutionOptionsManager("test") with HasFirrtlOptions { + firrtlOptions = FirrtlExecutionOptions( + firrtlSource = Some(input), + customTransforms = List(new ErroringTransform)) + } + (the [java.lang.IllegalArgumentException] thrownBy { + Driver.execute(optionsManager) + }).getMessage should include (errorString) + } } diff --git a/src/test/scala/firrtlTests/InferReadWriteSpec.scala b/src/test/scala/firrtlTests/InferReadWriteSpec.scala index bffb1b51..db50b491 100644 --- a/src/test/scala/firrtlTests/InferReadWriteSpec.scala +++ b/src/test/scala/firrtlTests/InferReadWriteSpec.scala @@ -135,8 +135,11 @@ circuit sram6t : """.stripMargin val annos = Seq(memlib.InferReadWriteAnnotation) - intercept[InferReadWriteCheckException] { + intercept[Exception] { compileAndEmit(CircuitState(parse(input), ChirrtlForm, annos)) + } match { + case CustomTransformException(_: InferReadWriteCheckException) => // success + case _ => fail() } } diff --git a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala index de84d79d..c75e0914 100644 --- a/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala +++ b/src/test/scala/firrtlTests/annotationTests/EliminateTargetPathsSpec.scala @@ -260,16 +260,19 @@ class EliminateTargetPathsSpec extends FirrtlPropSpec with FirrtlMatchers { | m2.i <= m1.o | o <= m2.o """.stripMargin - intercept[NoSuchTargetException] { + val e1 = the [CustomTransformException] thrownBy { val Top_m1 = Top.instOf("m1", "MiddleX") val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m1))) new LowFirrtlCompiler().compile(inputState, customTransforms) } - intercept[NoSuchTargetException] { + e1.cause shouldBe a [NoSuchTargetException] + + val e2 = the [CustomTransformException] thrownBy { val Top_m2 = Top.instOf("x2", "Middle") val inputState = CircuitState(parse(input), ChirrtlForm, Seq(DummyAnnotation(Top_m2))) new LowFirrtlCompiler().compile(inputState, customTransforms) } + e2.cause shouldBe a [NoSuchTargetException] } property("No name conflicts between two new modules") { |
