"""Jump-condition control logic (``JumpCtl``) and its simulation tests.

``JumpCtl`` turns ALU status flags into a single "take the jump" bit for the
opcodes listed in ``JumpOpCodes``.  ``DUT`` wires an ``ALU`` to a ``JumpCtl``
so the ``test_jump_*`` functions can check the result against plain Python
integer comparisons.
"""
# NOTE: the relative order of the original imports is preserved on purpose;
# the star-imports can shadow one another, so reordering them could change
# which object a name resolves to.
import operator

from amaranth import *
from amaranth.sim import Simulator, Settle, Delay
from enum import Enum, unique

from hdl.utils import *
from hdl.lib.in_out_buff import InOutBuff  # used for timing analysis
from hdl.core.alu import ALUFlags, ALU, AluOpCodes  # AluOpCodes is for simulation only, not used in hardware
from hdl.config import NUM_RAND_TESTS


@unique
class JumpOpCodes(Enum):
    """Opcodes understood by ``JumpCtl`` (``_u`` = unsigned, ``_s`` = signed)."""

    j_eq = 0
    j_ne = 1
    j_lt_u = 2
    j_lte_u = 3
    j_lt_s = 4
    j_lte_s = 5


class JumpCtl(Elaboratable):
    """Combinational decoder from ALU flags to a jump-taken bit.

    Attributes:
        alu_flags: input, ``len(ALUFlags)`` bits wide; indexed with
            ``ALUFlags.<member>.value``.
        op: input selecting the condition; compared against
            ``JumpOpCodes.<member>.value``.
        signed_bits: input, 2 bits.  ``DUT`` drives it with
            ``Cat(in1[31], in2[31])``, so bit 0 is the sign of in1 and bit 1
            the sign of in2.
        cond_true: output, high when the selected jump condition is met.
        ports: ``{'in': [...], 'out': [...]}`` lists of the signals above.
        sim: value of the optional ``sim`` keyword (default ``False``); when
            truthy a dummy synchronous register is added during elaboration.
    """

    def __init__(self, **kargs):
        self.alu_flags = Signal(len(ALUFlags), reset_less=True)
        # e2s is star-imported (presumably from hdl.utils) and is assumed to
        # turn the enum into a signal shape -- TODO confirm.
        self.op = Signal(e2s(JumpOpCodes), reset_less=True)
        self.signed_bits = Signal(2, reset_less=True)

        self.cond_true = Signal(reset_less=True)  # true if jump condition is met

        ports_in = [self.alu_flags, self.op, self.signed_bits]
        ports_out = [self.cond_true]
        self.ports = {'in': ports_in, 'out': ports_out}

        self.sim = kargs.get("sim", False)

    def elaborate(self, platform=None):
        """Build and return the module implementing the jump conditions."""
        m = Module()

        # dummy sync for simulation only; needed if there is no other
        # sequential logic
        if self.sim:
            dummy = Signal()
            m.d.sync += dummy.eq(~dummy)

        # XOR-reduce the two sign bits: diff_sign is high when exactly one of
        # the operands has its sign bit set.  In that case the operand sign is
        # used directly instead of the ALU's negative flag, which prevents
        # problems with overflow of in1 - in2.
        diff_sign = Signal(reset_less=True)
        m.d.comb += diff_sign.eq(self.signed_bits.xor())

        zero = self.alu_flags[ALUFlags.zero.value]
        carry = self.alu_flags[ALUFlags.carry.value]
        negative = self.alu_flags[ALUFlags.negative.value]
        in1_sign = self.signed_bits[0]

        # Checks conditions for ALU in1 and in2, e.g. in1 less than in2.
        # This is done by the ALU computing in1 - in2 and reading the
        # resulting flags here.
        with m.Switch(self.op):
            with m.Case(JumpOpCodes.j_eq.value):
                m.d.comb += self.cond_true.eq(zero)
            with m.Case(JumpOpCodes.j_ne.value):
                m.d.comb += self.cond_true.eq(~zero)
            # NOTE(review): the unsigned cases assume the ALU raises carry
            # when in1 < in2 for a subtraction -- confirm against ALU.
            with m.Case(JumpOpCodes.j_lt_u.value):
                m.d.comb += self.cond_true.eq(carry & ~zero)
            with m.Case(JumpOpCodes.j_lte_u.value):
                m.d.comb += self.cond_true.eq(carry | zero)
            with m.Case(JumpOpCodes.j_lt_s.value):
                with m.If(diff_sign):
                    # signed bits are different, so use sign as condition to branch
                    m.d.comb += self.cond_true.eq(in1_sign & ~zero)
                with m.Else():
                    m.d.comb += self.cond_true.eq(negative & ~zero)
            with m.Case(JumpOpCodes.j_lte_s.value):
                with m.If(diff_sign):
                    m.d.comb += self.cond_true.eq(in1_sign | zero)
                with m.Else():
                    m.d.comb += self.cond_true.eq(negative | zero)
            # m.Default() rather than a pattern-less m.Case(): the latter is
            # deprecated and, in newer Amaranth releases, never matches.
            with m.Default():
                m.d.comb += self.cond_true.eq(0)

        return m


class DUT(Elaboratable):
    """Simulation-only harness connecting an ``ALU`` to a ``JumpCtl``."""

    def __init__(self, **kargs):
        # DUT will ONLY be used for simulation; kargs is accepted but unused.
        self.alu = ALU(sim=True)
        self.jump = JumpCtl(sim=True)

    def elaborate(self, platform=None):
        """Wire the ALU flags and operand sign bits into the jump controller."""
        m = Module()
        m.submodules.alu = self.alu
        m.submodules.jump = self.jump

        m.d.comb += self.jump.alu_flags.eq(self.alu.alu_flags)
        # bit 0 = sign of in1, bit 1 = sign of in2
        m.d.comb += self.jump.signed_bits.eq(Cat(self.alu.in1[31], self.alu.in2[31]))

        return m


def _init_dut(dut):
    """Zero both ALU inputs and select the subtract operation."""
    yield dut.alu.in1.eq(0)
    yield dut.alu.in2.eq(0)
    yield dut.alu.op.eq(AluOpCodes.sub)
    yield Settle()


def _run_jump_test(op, compare, name, tests, sus=None):
    """Drive random operand pairs through a DUT and check ``cond_true``.

    Args:
        op: ``JumpOpCodes`` member to apply to the jump controller.
        compare: two-argument callable giving the expected result for
            ``(in1, in2)`` as Python integers.
        name: label used as the prefix of the assertion message.
        tests: number of random operand pairs to try.
        sus: forwarded to ``rand_bits_mix`` as its ``sus`` keyword when not
            ``None``; omitted otherwise so the helper's own default applies.
    """
    rand_kwargs = {} if sus is None else {"sus": sus}
    dut = DUT(sim=True)  # sim=True is not needed, but kept for consistency

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(op.value)
        for _ in range(tests):
            in1 = rand_bits_mix(32, **rand_kwargs)
            in2 = rand_bits_mix(32, **rand_kwargs)
            yield dut.alu.in1.eq(in1)
            yield dut.alu.in2.eq(in2)
            # `eval` is not the builtin here (the builtin cannot be called
            # without arguments); it is presumably a simulation helper
            # star-imported from hdl.utils -- TODO confirm.
            yield from eval()
            cond_true = yield dut.jump.cond_true
            assert cond_true == compare(in1, in2), (
                f"{name} failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={cond_true}"
            )

    sim(dut, proc)


# test jump if equal
def test_jump_eq(tests=NUM_RAND_TESTS):
    """Check ``j_eq`` against ``in1 == in2``."""
    _run_jump_test(JumpOpCodes.j_eq, operator.eq, "jump_eq", tests)


# test jump if not equal
def test_jump_ne(tests=NUM_RAND_TESTS):
    """Check ``j_ne`` against ``in1 != in2``."""
    _run_jump_test(JumpOpCodes.j_ne, operator.ne, "jump_ne", tests)


# test jump if less than unsigned
def test_jump_lt_u(tests=NUM_RAND_TESTS):
    """Check ``j_lt_u`` against ``in1 < in2`` with ``sus='u'`` operands."""
    _run_jump_test(JumpOpCodes.j_lt_u, operator.lt, "jump_lt_u", tests, sus='u')


# test jump if less than or equal to unsigned
def test_jump_lte_u(tests=NUM_RAND_TESTS):
    """Check ``j_lte_u`` against ``in1 <= in2`` with ``sus='u'`` operands."""
    _run_jump_test(JumpOpCodes.j_lte_u, operator.le, "jump_lte_u", tests, sus='u')


# test jump if less than signed
def test_jump_lt_s(tests=NUM_RAND_TESTS):
    """Check ``j_lt_s`` against ``in1 < in2`` with ``sus='s'`` operands."""
    _run_jump_test(JumpOpCodes.j_lt_s, operator.lt, "jump_lt_s", tests, sus='s')


# test jump if less than or equal to signed
def test_jump_lte_s(tests=NUM_RAND_TESTS):
    """Check ``j_lte_s`` against ``in1 <= in2`` with ``sus='s'`` operands."""
    _run_jump_test(JumpOpCodes.j_lte_s, operator.le, "jump_lte_s", tests, sus='s')


if __name__ == '__main__':
    hdl = JumpCtl()
    cmd(hdl)