summaryrefslogtreecommitdiff
path: root/hdl/core
diff options
context:
space:
mode:
Diffstat (limited to 'hdl/core')
-rw-r--r--hdl/core/jump_ctl.py180
1 files changed, 157 insertions, 23 deletions
diff --git a/hdl/core/jump_ctl.py b/hdl/core/jump_ctl.py
index 4dd037b..6a7cf1b 100644
--- a/hdl/core/jump_ctl.py
+++ b/hdl/core/jump_ctl.py
@@ -2,10 +2,11 @@ from amaranth import *
from amaranth.sim import Simulator, Settle, Delay
from enum import Enum, unique
-from hdl.utils import cmd, step, sim
+from hdl.utils import cmd, step, eval, sim, rand_bits_mix
from hdl.lib.in_out_buff import InOutBuff # used for timing analysis
-from hdl.core.alu import ALUFlags, ALU
+from hdl.core.alu import ALUFlags, ALU, AluOpCodes #ALUOpCodes is for simulation only, not used in hardware
+from hdl.config import NUM_RAND_TESTS
@unique
class JumpOpCodes(Enum):
@@ -16,16 +17,17 @@ class JumpOpCodes(Enum):
j_lt_s = 4
j_lte_s = 5
-class HDL(Elaboratable):
+class JumpCtl(Elaboratable):
def __init__(self, **kargs):
self.alu_flags = Signal(ALU().alu_flags.width, reset_less=True)
- self.jump_op = Signal(3, reset_less=True)
+ self.op = Signal(3, reset_less=True)
+ self.signed_bits = Signal(2, reset_less=True)
- self.jump = Signal(reset_less=True) # true if jump condition is met
+ self.cond_true = Signal(reset_less=True) # true if jump condition is met
- ports_in = [self.alu_flags, self.jump_op]
- ports_out = [self.jump]
+ ports_in = [self.alu_flags, self.op, self.signed_bits]
+ ports_out = [self.cond_true]
self.ports = {'in': ports_in, 'out': ports_out}
self.sim = kargs["sim"] if "sim" in kargs else False
@@ -38,36 +40,168 @@ class HDL(Elaboratable):
dummy = Signal()
m.d.sync += dummy.eq(~dummy)
- with m.Switch(self.jump_op):
+ # xor the bits if both are positive or negative, this is needed to prevent problems with overflow
+ diff_sign = Signal(reset_less=True)
+ m.d.comb += diff_sign.eq(self.signed_bits.xor())
+
+ # checks conditions for ALU in1 and in2
+ # e.g. in1 less than in2
+ # this is done by in1 - in2 and reading the negative and zero flags
+ with m.Switch(self.op):
with m.Case(JumpOpCodes.j_eq.value):
- m.d.comb += self.jump.eq(self.alu_flags[ALUFlags.zero])
+ m.d.comb += self.cond_true.eq(self.alu_flags[ALUFlags.zero.value])
+
with m.Case(JumpOpCodes.j_ne.value):
- m.d.comb += self.jump.eq(~self.alu_flags[ALUFlags.zero])
+ m.d.comb += self.cond_true.eq(~self.alu_flags[ALUFlags.zero.value])
+
with m.Case(JumpOpCodes.j_lt_u.value):
- m.d.comb += self.jump.eq(self.alu_flags[ALUFlags.negative] & ~self.alu_flags[ALUFlags.zero])
+ m.d.comb += self.cond_true.eq(self.alu_flags[ALUFlags.carry.value] & ~self.alu_flags[ALUFlags.zero.value])
+
with m.Case(JumpOpCodes.j_lte_u.value):
- m.d.comb += self.jump.eq(self.alu_flags[ALUFlags.negative] | self.alu_flags[ALUFlags.zero])
+ m.d.comb += self.cond_true.eq(self.alu_flags[ALUFlags.carry.value] | self.alu_flags[ALUFlags.zero.value])
+
with m.Case(JumpOpCodes.j_lt_s.value):
- pass
+ with m.If(diff_sign):
+ m.d.comb += self.cond_true.eq(self.signed_bits[0] & ~self.alu_flags[ALUFlags.zero.value]) # signed bits are different, so use sign as condition to branch
+ with m.Else():
+ m.d.comb += self.cond_true.eq(self.alu_flags[ALUFlags.negative.value] & ~self.alu_flags[ALUFlags.zero.value])
+
with m.Case(JumpOpCodes.j_lte_s.value):
- pass
+ with m.If(diff_sign):
+ m.d.comb += self.cond_true.eq(self.signed_bits[0] | self.alu_flags[ALUFlags.zero.value])
+ with m.Else():
+ m.d.comb += self.cond_true.eq(self.alu_flags[ALUFlags.negative.value] | self.alu_flags[ALUFlags.zero.value])
+
with m.Case():
- m.d.comb += self.jump.eq(0)
+ m.d.comb += self.cond_true.eq(0)
return m
-# test addition
-def test_hdl():
- dut = HDL(sim=True)
+class DUT(Elaboratable):
+ def __init__(self, **kargs): # DUT will ONLY be used for simulation
+ self.alu = ALU(sim=True)
+ self.jump = JumpCtl(sim=True)
+
+ def elaborate(self, platform=None):
+ m = Module()
+ m.submodules.alu = self.alu
+ m.submodules.jump = self.jump
+
+ m.d.comb += self.jump.alu_flags.eq(self.alu.alu_flags)
+ m.d.comb += self.jump.signed_bits.eq(Cat(self.alu.in1[31], self.alu.in2[31]))
+ return m
+
+def _init_dut(dut):
+ yield dut.alu.in1.eq(0)
+ yield dut.alu.in2.eq(0)
+ yield dut.alu.op.eq(AluOpCodes.sub)
+ yield Settle()
+
+
+# test jump if equal
+def test_jump_eq(tests=NUM_RAND_TESTS):
+ dut = DUT(sim=True) # sim=True is not needed, but I am trying to be consistent
def proc():
- yield from step #step clock
- yield Settle() #needed if for combinatorial logic
- yield dut.something #read value
+ yield from _init_dut(dut)
+ yield dut.jump.op.eq(JumpOpCodes.j_eq.value)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32)
+ in2 = rand_bits_mix(32)
+ yield dut.alu.in1.eq(in1)
+ yield dut.alu.in2.eq(in2)
+ yield from eval()
+ assert (yield dut.jump.cond_true) == (in1 == in2), f"jump_eq failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"
+
sim(dut, proc)
+# test jump if not equal
+def test_jump_ne(tests=NUM_RAND_TESTS):
+ dut = DUT(sim=True)
+ def proc():
+ yield from _init_dut(dut)
+ yield dut.jump.op.eq(JumpOpCodes.j_ne.value)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32)
+ in2 = rand_bits_mix(32)
+ yield dut.alu.in1.eq(in1)
+ yield dut.alu.in2.eq(in2)
+ yield from eval()
+ assert (yield dut.jump.cond_true) == (in1 != in2), f"jump_ne failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"
+
+ sim(dut, proc)
+
+# test jump if less than unsigned
+def test_jump_lt_u(tests=NUM_RAND_TESTS):
+ dut = DUT(sim=True)
+ def proc():
+ yield from _init_dut(dut)
+ yield dut.jump.op.eq(JumpOpCodes.j_lt_u.value)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='u')
+ in2 = rand_bits_mix(32, sus='u')
+ yield dut.alu.in1.eq(in1)
+ yield dut.alu.in2.eq(in2)
+ yield from eval()
+ assert (yield dut.jump.cond_true) == (in1 < in2), f"jump_lt_u failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"
+
+ sim(dut, proc)
+
+# test jump if less than or equal to unsigned
+def test_jump_lte_u(tests=NUM_RAND_TESTS):
+ dut = DUT(sim=True)
+ def proc():
+ yield from _init_dut(dut)
+ yield dut.jump.op.eq(JumpOpCodes.j_lte_u.value)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='u')
+ in2 = rand_bits_mix(32, sus='u')
+ yield dut.alu.in1.eq(in1)
+ yield dut.alu.in2.eq(in2)
+ yield from eval()
+ assert (yield dut.jump.cond_true) == (in1 <= in2), f"jump_lte_u failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"
+
+ sim(dut, proc)
+
+# test jump if less than signed
+def test_jump_lt_s(tests=NUM_RAND_TESTS):
+ dut = DUT(sim=True)
+ def proc():
+ yield from _init_dut(dut)
+ yield dut.jump.op.eq(JumpOpCodes.j_lt_s.value)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='s')
+ in2 = rand_bits_mix(32, sus='s')
+ yield dut.alu.in1.eq(in1)
+ yield dut.alu.in2.eq(in2)
+ yield from eval()
+ assert (yield dut.jump.cond_true) == (in1 < in2), f"jump_lt_s failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"
+
+ sim(dut, proc)
+
+# test jump if less than or equal to signed
+def test_jump_lte_s(tests=NUM_RAND_TESTS):
+ dut = DUT(sim=True)
+ def proc():
+ yield from _init_dut(dut)
+ yield dut.jump.op.eq(JumpOpCodes.j_lte_s.value)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='s')
+ in2 = rand_bits_mix(32, sus='s')
+ yield dut.alu.in1.eq(in1)
+ yield dut.alu.in2.eq(in2)
+ yield from eval()
+ assert (yield dut.jump.cond_true) == (in1 <= in2), f"jump_lte_s failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"
+
+ sim(dut, proc)
if __name__ == '__main__':
- hdl = InOutBuff(HDL())
- cmd(hdl)
+ hdl = JumpCtl()
+ cmd(hdl) \ No newline at end of file