1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
|
package firrtlTests.execution
import firrtl._
import firrtl.ir._
/** Naming rules shared by the execution-test harness and the DUT it instruments.
  *
  * The harness adds components with fixed names to the DUT's main module: a `step` counter
  * register and one `<flattened-expression>_poke` register per poked expression. Because these
  * naming patterns are static, the DUT must be checked up front for collisions with them.
  */
object DUTRules {
  val dutName: String = "dut"
  val clock: Reference = Reference("clock", ClockType)
  val reset: Reference = Reference("reset", Utils.BoolType)

  /** Reference to the step counter register the harness adds to the main module. */
  val counter: Reference = Reference("step", UnknownType)

  // Need a flat name for the register that latches poke values: every `[`, `]` and `.` in the
  // serialized expression is replaced by `_`.
  val illegal: scala.util.matching.Regex = raw"[\[\]\.]".r
  val pokeRegSuffix: String = "_poke"

  /** Flat register name used to latch values poked into `e`.
    *
    * For example, an expression serializing to `io.in[0]` yields `io_in_0__poke`.
    */
  def pokeRegName(e: Expression): String = illegal.replaceAllIn(e.serialize, "_") + pokeRegSuffix

  /** Returns true if the main module's namespace already contains the counter's name, or any
    * name containing the poke-register suffix.
    *
    * @throws NoSuchElementException if `c` has no module named `c.main`
    */
  def hasNameConflicts(c: Circuit): Boolean = {
    // Same exception type as the former `.get`, but with a message that names the problem.
    val top = c.modules
      .find(_.name == c.main)
      .getOrElse(throw new NoSuchElementException(s"Circuit has no main module named '${c.main}'"))
    val names = Namespace(top).cloneUnderlying
    names.contains(counter.name) || names.exists(_.contains(pokeRegSuffix))
  }
}
object ExecutionTestHelper {

  /** Type of the step counter register: 32-bit unsigned. */
  val counterType: UIntType = UIntType(IntWidth(32))

  /** Builds the initial test state for a DUT described by `body`.
    *
    * The DUT text is parsed with `ParseStatement.makeDUT`, checked against [[DUTRules]] naming
    * rules, and paired with setup statements that declare the step counter register (reset to
    * zero) and increment it by one.
    *
    * @param body FIRRTL statements forming the DUT
    * @throws IllegalArgumentException if the DUT uses names reserved by the harness
    */
  def apply(body: String): ExecutionTestHelper = {
    val circuit = ParseStatement.makeDUT(body)
    require(!DUTRules.hasNameConflicts(circuit), "Avoid using 'step' or 'poke' in DUT component names")

    val stepRef = DUTRules.counter
    val counterReg =
      DefRegister(NoInfo, stepRef.name, counterType, DUTRules.clock, DUTRules.reset, Utils.zero)
    val plusOne = DoPrim(PrimOps.Add, Seq(stepRef, UIntLiteral(1)), Nil, UnknownType)
    val counterIncrement = Connect(NoInfo, stepRef, plusOne)

    new ExecutionTestHelper(
      dut = circuit,
      setup = Seq(counterReg, counterIncrement),
      pokeRegs = Map.empty[Expression, Expression],
      completedSteps = Nil,
      activeStep = Nil
    )
  }
}
/** Immutable builder describing a step-by-step execution test for a DUT circuit.
  *
  * Statements are collected into the "active step". Closing a step wraps them in a
  * `when (step == <index>)` block, where `step` is the counter reference from [[DUTRules]].
  * Every builder method returns a new instance.
  *
  * @param dut            circuit under test
  * @param setup          statements appended once to the main module (counter, poke registers,
  *                       default connections)
  * @param pokeRegs       map from each poked expression to the register latching its last value
  * @param completedSteps one guarded block per closed step, in step order
  * @param activeStep     statements accumulated for the step currently being built
  */
case class ExecutionTestHelper(
  dut: Circuit,
  setup: Seq[Statement],
  pokeRegs: Map[Expression, Expression],
  completedSteps: Seq[Conditionally],
  activeStep: Seq[Statement]) {

  /** Closes the active step, then closes `n - 1` further empty steps.
    *
    * @throws IllegalArgumentException if `n` is not positive
    */
  def step(n: Int): ExecutionTestHelper = {
    require(n > 0, "Step length must be positive")
    (0 until n).foldLeft(this) { case (eth, _) => eth.next }
  }

  /** Drives `expString` with `value` in the active step, and latches `value` into that
    * expression's poke register.
    *
    * The poke register drives the expression by default (via setup), so the value is held on
    * later steps until the expression is poked again.
    */
  def poke(expString: String, value: Literal): ExecutionTestHelper = {
    val pokeExp = ParseExpression(expString)
    val pokeable = ensurePokeable(pokeExp)
    pokeable.addStatements(Connect(NoInfo, pokeExp, value), Connect(NoInfo, pokeable.pokeRegs(pokeExp), value))
  }

  /** Marks `expString` as invalid in the active step. */
  def invalidate(expString: String): ExecutionTestHelper = {
    addStatements(IsInvalid(NoInfo, ParseExpression(expString)))
  }

  /** Adds a `stop` with exit code 1 to the active step.
    *
    * The stop fires when `expString` is not equal to `value`.
    */
  def expect(expString: String, value: Literal): ExecutionTestHelper = {
    val peekExp = ParseExpression(expString)
    val neq = DoPrim(PrimOps.Neq, Seq(peekExp, value), Nil, Utils.BoolType)
    addStatements(Stop(NoInfo, 1, DUTRules.clock, neq))
  }

  /** Adds an unconditional `stop` with exit code 0 to the active step, then closes that step. */
  def finish(): ExecutionTestHelper = {
    addStatements(Stop(NoInfo, 0, DUTRules.clock, Utils.one)).next
  }

  // Private helper methods

  /** Index of the active step: the number of steps already closed. */
  private def t: Int = completedSteps.length

  private def addStatements(stmts: Statement*): ExecutionTestHelper =
    copy(activeStep = activeStep ++ stmts)

  /** Wraps the active step in `when (step == t)` and starts a new, empty active step. */
  private def next: ExecutionTestHelper = {
    val count = Reference(DUTRules.counter.name, DUTRules.counter.tpe)
    val ifStep = DoPrim(PrimOps.Eq, Seq(count, UIntLiteral(t)), Nil, Utils.BoolType)
    val onThisStep = Conditionally(NoInfo, ifStep, Block(activeStep), EmptyStmt)
    copy(completedSteps = completedSteps :+ onThisStep, activeStep = Nil)
  }

  /** Returns `dut` with the harness logic appended to the body of its main module.
    *
    * The harness logic is the setup statements, followed by every step block, followed by a
    * final step containing the finishing `stop`. All other modules are returned unchanged.
    */
  private[execution] def emit: Circuit = {
    val finished = finish()
    val testLogic = setup ++ finished.completedSteps
    val modulesX = dut.modules.map {
      case m: Module if m.name == dut.main => m.copy(body = Block(m.body +: testLogic))
      case other                           => other
    }
    dut.copy(modules = modulesX)
  }

  /** Returns a state in which `pokeExp` has a poke register.
    *
    * If the register does not exist yet, this creates it. Its reset signal is the constant
    * zero and its init value is itself. A default connection `pokeExp <= register` is also
    * added to setup.
    * NOTE(review): the register is always declared as an unsized UInt, so poking a
    * non-UInt expression (e.g. SInt) presumably fails type checking. Confirm before relying
    * on it.
    */
  private def ensurePokeable(pokeExp: Expression): ExecutionTestHelper = {
    if (pokeRegs.contains(pokeExp)) {
      this
    } else {
      val pName = DUTRules.pokeRegName(pokeExp)
      val pRef = Reference(pName, UnknownType)
      val pReg = DefRegister(NoInfo, pName, UIntType(UnknownWidth), DUTRules.clock, Utils.zero, pRef)
      val defaultConn = Connect(NoInfo, pokeExp, pRef)
      copy(setup = setup ++ Seq(pReg, defaultConn), pokeRegs = pokeRegs + (pokeExp -> pRef))
    }
  }
}
|