"""32-bit combinational ALU (Amaranth HDL) and its simulation tests."""
from amaranth import *
from amaranth.sim import Simulator, Settle, Delay
from enum import Enum, unique
from hdl.utils import sim, e2s, cmd, rand_bits_mix
from hdl.lib.in_out_buff import InOutBuff
from hdl.config import NUM_RAND_TESTS


@unique
class AluOpCodes(Enum):
    """Operation selector values driven onto ``ALU.op``."""
    add = 0
    addc = 1
    sub = 2
    subc = 3
    bit_and = 4
    bit_or = 5
    bit_xor = 6
    bit_nor = 7
    lleft = 8
    lright = 9
    aright = 10
    multul = 11  # low 32 bits of unsigned multiplication
    multuh = 12  # high 32 bits of unsigned multiplication
    multsl = 13  # low 32 bits of signed multiplication
    multsh = 14  # high 32 bits of signed multiplication


@unique
class ALUFlags(Enum):
    """Bit index of each flag inside ``ALU.alu_flags``."""
    zero = 0
    carry = 1
    overflow = 2
    negative = 3


class ALU(Elaboratable):
    """Purely combinational 32-bit ALU.

    Inputs:  ``in1``, ``in2`` (32-bit operands), ``c_in`` (1-bit carry in) and
    ``op`` (an :class:`AluOpCodes` value).
    Outputs: ``out`` (32-bit result), ``c_out``, ``overflow``, ``zero``,
    ``neg`` and the packed ``alu_flags`` vector indexed by :class:`ALUFlags`.

    Every operation writes a 33-bit intermediate ``tmp``: bits 0..31 become
    ``out`` and bit 32 becomes the carry flag.

    Subtraction carry convention: for both ``sub`` and ``subc`` ``c_out == 1``
    means "no borrow" (``in1 + ~in2 + carry``), so the carry out of one word
    can be fed into ``c_in`` of a ``subc`` on the next word.

    Shift ops use only ``in2[0:5]`` as the shift amount (shift by n % 32).
    """

    def __init__(self, **kargs):
        """Create the ALU's signals.

        Keyword args:
            sim: stored as ``self.sim``; not read by :meth:`elaborate`
                (presumably consumed by the ``hdl.utils`` helpers — verify).
        """
        self.in1 = Signal(32, reset_less=True)
        self.in2 = Signal(32, reset_less=True)
        self.c_in = Signal(1)
        self.op = Signal(e2s(AluOpCodes), reset_less=True)

        # 33-bit intermediate result; bit 32 is the carry.
        self.tmp = Signal(33, reset_less=True)

        self.c_out = Signal(1, reset_less=True)
        self.overflow = Signal(1, reset_less=True)
        self.zero = Signal(1, reset_less=True)
        self.neg = Signal(1, reset_less=True)
        # One bit per ALUFlags member; several flags may be set at once.
        self.alu_flags = Signal(len(ALUFlags), reset_less=True)

        self.out = Signal(32, reset_less=True)

        self.sim = kargs.get('sim', False)

        ports_in = [self.in1, self.in2, self.op, self.c_in]
        # The carry *output* is c_out; c_in is already an input port.
        ports_out = [self.c_out, self.out, self.alu_flags]
        self.ports = {'in': ports_in, 'out': ports_out}

    def elaborate(self, platform=None):
        """Build the combinational datapath selected by ``op``."""
        m = Module()

        # All shifts use the low 5 bits of in2.  For the left shift this also
        # keeps the shifter's result width bounded (in1 << 32-bit would be huge).
        shamt = self.in2[0:5]
        product_u = self.in1 * self.in2                          # 64-bit unsigned
        product_s = self.in1.as_signed() * self.in2.as_signed()  # 64-bit signed

        with m.Switch(self.op):
            with m.Case(AluOpCodes.add.value):
                m.d.comb += self.tmp.eq(self.in1 + self.in2)
            with m.Case(AluOpCodes.addc.value):
                m.d.comb += self.tmp.eq(self.in1 + self.in2 + self.c_in)
            with m.Case(AluOpCodes.sub.value):
                # in1 + ~in2 + 1 == in1 - in2 (mod 2**32).  Written this way so
                # tmp[32] is "no borrow", the same convention subc uses.
                m.d.comb += self.tmp.eq(self.in1 + ~self.in2 + 1)
            with m.Case(AluOpCodes.subc.value):
                m.d.comb += self.tmp.eq(self.in1 + ~self.in2 + self.c_in)

            # Bitwise ops: the unsigned 32-bit result is zero-extended into
            # tmp by .eq(), so the carry bit is cleared.
            with m.Case(AluOpCodes.bit_and.value):
                m.d.comb += self.tmp.eq(self.in1 & self.in2)
            with m.Case(AluOpCodes.bit_or.value):
                m.d.comb += self.tmp.eq(self.in1 | self.in2)
            with m.Case(AluOpCodes.bit_xor.value):
                m.d.comb += self.tmp.eq(self.in1 ^ self.in2)
            with m.Case(AluOpCodes.bit_nor.value):
                m.d.comb += self.tmp.eq(~(self.in1 | self.in2))

            with m.Case(AluOpCodes.lleft.value):
                # tmp keeps bits 0..32 of the shifted value, so tmp[32] is the
                # last bit shifted out of in1.
                m.d.comb += self.tmp.eq(Cat(self.in1, Const(0, 1)) << shamt)
            with m.Case(AluOpCodes.lright.value):
                # A guard bit below in1 catches the last bit shifted out; it is
                # then moved up into the carry position.
                shifted = Cat(Const(0, 1), self.in1) >> shamt
                m.d.comb += self.tmp.eq(Cat(shifted[1:33], shifted[0]))
            with m.Case(AluOpCodes.aright.value):
                # Same as lright, but the 33-bit value is treated as signed so
                # in1[31] is replicated into the vacated high bits.
                shifted = Cat(Const(0, 1), self.in1).as_signed() >> shamt
                m.d.comb += self.tmp.eq(Cat(shifted[1:33], shifted[0]))

            # Multiplies: slices are unsigned, so the carry bit is cleared.
            with m.Case(AluOpCodes.multul.value):
                m.d.comb += self.tmp.eq(product_u[:32])
            with m.Case(AluOpCodes.multuh.value):
                m.d.comb += self.tmp.eq(product_u[32:64])
            with m.Case(AluOpCodes.multsl.value):
                m.d.comb += self.tmp.eq(product_s[:32])
            with m.Case(AluOpCodes.multsh.value):
                m.d.comb += self.tmp.eq(product_s[32:64])

            # TODO: udiv/sdiv are not implemented.  Note that signed floor
            # division can produce +2**31 (-2**31 // -1), which needs 33 signed
            # bits, which is why Amaranth widens the signed quotient.
            with m.Case():
                m.d.comb += self.tmp.eq(0)

        m.d.comb += self.out.eq(self.tmp[0:32])
        m.d.comb += self.c_out.eq(self.tmp[32])
        # NOTE(review): this is carry XOR result-sign, not the textbook
        # two's-complement overflow (carry into bit 31 XOR carry out of
        # bit 31); e.g. 0xFFFFFFFF + 1 sets it.  test_alu_unsigned_overflow /
        # test_alu_unsigned_underflow pin this behaviour — confirm intent.
        m.d.comb += self.overflow.eq(self.tmp[32] ^ self.tmp[31])
        m.d.comb += self.neg.eq(self.out[31])
        m.d.comb += self.zero.eq(self.out == 0)

        m.d.comb += [
            self.alu_flags[ALUFlags.zero.value].eq(self.zero),
            self.alu_flags[ALUFlags.carry.value].eq(self.c_out),
            self.alu_flags[ALUFlags.overflow.value].eq(self.overflow),
            self.alu_flags[ALUFlags.negative.value].eq(self.neg),
        ]
        return m


def sub_proc(dut, val1, val2, c_in=0):
    """Drive both operands and the carry-in, then let comb logic settle."""
    yield dut.in1.eq(val1)
    yield dut.in2.eq(val2)
    yield dut.c_in.eq(c_in)
    yield Settle()


# test addition
def test_alu_add():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.add.value)
        yield from sub_proc(dut, 27, 13)
        out = yield dut.out
        assert 27 + 13 == (out), f'ERROR: {out} != {27 + 13}'
    sim(dut, proc, sync=False)


# test addition with carry
def test_alu_addc():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.addc.value)
        yield from sub_proc(dut, 11, 43, 1)
        out = yield dut.out.as_signed()
        assert 11 + 43 + 1 == out, f'ERROR: {out} != {11 + 43 + 1}'
    sim(dut, proc, sync=False)


# test subtraction
def test_alu_sub():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.sub.value)
        yield from sub_proc(dut, 25, 13)
        out = yield dut.out
        assert 25 - 13 == out, f'ERROR: {out} != {25 - 13}'
    sim(dut, proc, sync=False)


# test subtraction carry: 1 means "no borrow", 0 means "borrow"
def test_alu_sub_carry():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.sub.value)
        yield from sub_proc(dut, 25, 13)
        out = yield dut.c_out
        assert out == 1, f'ERROR: {out} != {1}'
        yield from sub_proc(dut, 13, 25)
        out = yield dut.out
        assert out == (13 - 25) & 0xFFFFFFFF, f'ERROR: {hex(out)} != {hex((13 - 25) & 0xFFFFFFFF)}'
        out = yield dut.c_out
        assert out == 0, f'ERROR: {out} != {0}'
    sim(dut, proc, sync=False)


# test 64-bit subtraction 0x1_00000000 - 1 chained as sub (low) then subc (high)
def test_alu_sub_chain():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.sub.value)
        yield from sub_proc(dut, 0x00000000, 0x00000001)
        lo = yield dut.out
        carry = yield dut.c_out
        yield dut.op.eq(AluOpCodes.subc.value)
        yield from sub_proc(dut, 0x00000001, 0x00000000, carry)
        hi = yield dut.out
        result = (hi << 32) | lo
        assert result == 0xFFFFFFFF, f'ERROR: {hex(result)} != {hex(0xFFFFFFFF)}'
    sim(dut, proc, sync=False)


# test subtraction with carry
def test_alu_subc_0():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.subc.value)
        yield from sub_proc(dut, 25, -13, 0)
        out = yield dut.out.as_signed()
        assert 25 + 13 -1 +0 == out, f'ERROR: {out} != {25 + 13 -1 +0}'
    sim(dut, proc, sync=False)


# test subtraction with carry
def test_alu_subc_1():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.subc.value)
        yield from sub_proc(dut, 25, -13, 1)
        out = yield dut.out.as_signed()
        assert 25 + 13 -1 +1 == out, f'ERROR: {out} != {25 + 13 -1 +1}'
    sim(dut, proc, sync=False)


# test binary and
def test_alu_and():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.bit_and.value)
        yield from sub_proc(dut, 0b10101011, 0b01010101)
        out = yield dut.out
        assert 0b00000001 == out, f'ERROR: {out} != {0b00000001}'
    sim(dut, proc, sync=False)


# test binary or
def test_alu_or():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.bit_or.value)
        yield from sub_proc(dut, 0b10101011, 0b01000101)
        out = yield dut.out
        assert 0b11101111 == out, f'ERROR: {out} != {0b11101111}'
    sim(dut, proc, sync=False)


# test binary nor
def test_alu_nor():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.bit_nor.value)
        yield from sub_proc(dut, 0b10001011, 0b01000101)
        out = yield dut.out
        assert 0b11111111111111111111111100110000 == out, f'ERROR: {bin(out)} != {bin(0b11111111111111111111111100110000)}'
    sim(dut, proc, sync=False)


# test binary xor
def test_alu_xor():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.bit_xor.value)
        yield from sub_proc(dut, 0b10001011, 0b01000101)
        out = yield dut.out
        assert 0b11001110 == out, f'ERROR: {out} != {0b11001110}'
    sim(dut, proc, sync=False)


# test logical shift left
def test_alu_logic_shift_left():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.lleft.value)
        yield from sub_proc(dut, 0b10001011, 25)  # shift left by 25
        out = yield dut.out
        assert 0b00010110000000000000000000000000 == out, f'ERROR: {bin(out)} != {bin(0b00010110000000000000000000000000)}'
        out = yield dut.c_out
        assert 1 == out, f'ERROR: {out} != {1}'
    sim(dut, proc, sync=False)


# test logical shift right
def test_alu_logic_shift_right():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.lright.value)
        yield from sub_proc(dut, 0b10001011, 4)  # shift right by 4
        out = yield dut.out
        assert 0b1000 == out, f'ERROR: {bin(out)} != {bin(0b1000)}'
        out = yield dut.c_out
        assert 1 == out, f'ERROR: {out} != {1}'
    sim(dut, proc, sync=False)


# test that every shift uses only the low 5 bits of in2 (36 behaves like 4)
def test_alu_shift_amount_masked():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.lright.value)
        yield from sub_proc(dut, 0b10001011, 36)
        out = yield dut.out
        assert 0b1000 == out, f'ERROR: {bin(out)} != {bin(0b1000)}'
        yield dut.op.eq(AluOpCodes.aright.value)
        yield from sub_proc(dut, 0x80001234, 36)
        out = yield dut.out
        assert 0xF8000123 == out, f'ERROR: {out} != {0xF8000123}'
    sim(dut, proc, sync=False)


# test arithmetic shift right
def test_alu_arith_shift_right():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.aright.value)
        yield from sub_proc(dut, 0x80001234, 4)  # shift right by 4
        out = yield dut.out
        assert 0xF8000123 == out, f'ERROR: {out} != {0xF8000123}'
        out = yield dut.c_out
        assert 0 == out, f'ERROR: {out} != {0}'
    sim(dut, proc, sync=False)


def _check_mul(op, sus, expected_fn, label, tests):
    """Run ``tests`` random multiplications with ``op`` against ``expected_fn``.

    ``sus`` is passed to ``rand_bits_mix`` to pick unsigned/signed operands;
    ``expected_fn(in1, in2)`` returns the expected 32-bit ``out`` value.
    """
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(op.value)
        for _ in range(tests):
            in1 = rand_bits_mix(32, sus=sus)
            in2 = rand_bits_mix(32, sus=sus)
            yield from sub_proc(dut, in1, in2)
            out = yield dut.out
            expected = expected_fn(in1, in2)
            assert out == expected, f"{label} failed: in1={hex(in1)}, in2={hex(in2)}, out={hex(out)}, expected={hex(expected)}"
    sim(dut, proc, sync=False)


# test low unsigned multiply
def test_alu_mul_low_u(tests=NUM_RAND_TESTS):
    _check_mul(AluOpCodes.multul, 'u',
               lambda a, b: (a * b) & 0xFFFFFFFF, 'mul_low_u', tests)


# test high unsigned multiply
def test_alu_mul_high_u(tests=NUM_RAND_TESTS):
    _check_mul(AluOpCodes.multuh, 'u',
               lambda a, b: ((a * b) >> 32) & 0xFFFFFFFF, 'mul_high_u', tests)


# test low signed multiply
def test_alu_mul_low_s(tests=NUM_RAND_TESTS):
    _check_mul(AluOpCodes.multsl, 's',
               lambda a, b: (a * b) & 0xFFFFFFFF, 'mul_low_s', tests)


# test high signed multiply
def test_alu_mul_high_s(tests=NUM_RAND_TESTS):
    _check_mul(AluOpCodes.multsh, 's',
               lambda a, b: ((a * b) >> 32) & 0xFFFFFFFF, 'mul_high_s', tests)


# test unsigned overflow
def test_alu_unsigned_overflow():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.add.value)
        yield from sub_proc(dut, 0xFFFFFFFF, 1)  # add 1 to 0xFFFFFFFF
        out = yield dut.overflow
        assert out == 1, f'ERROR: {out} != {1}'
        out = yield dut.c_out
        assert out == 1, f'ERROR: {out} != {1}'
    sim(dut, proc, sync=False)


# test unsigned underflow
def test_alu_unsigned_underflow():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.add.value)
        yield from sub_proc(dut, 0, -1)  # 0 + 0xFFFFFFFF (-1 truncated to 32 bits)
        out = yield dut.overflow
        assert out == 1, f'ERROR: {out} != {1}'
        out = yield dut.c_out
        assert out == 0, f'ERROR: {out} != {0}'
    sim(dut, proc, sync=False)


# test zero flag clear
def test_alu_zero_0():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.add.value)
        yield from sub_proc(dut, 0, 1)  # 0 + 1 is non-zero
        out = yield dut.zero
        assert out == 0, f'ERROR: {out} != {0}'
    sim(dut, proc, sync=False)


# test zero flag set
def test_alu_zero_1():
    dut = ALU(sim=True)

    def proc():
        yield dut.op.eq(AluOpCodes.add.value)
        yield from sub_proc(dut, 0, 0)  # add 0 to 0
        out = yield dut.zero
        assert out == 1, f'ERROR: {out} != {1}'
    sim(dut, proc, sync=False)


# test that the carry output (not the carry input) is exported as an out port
def test_alu_ports():
    dut = ALU()
    # identity checks: `in` would call Value.__eq__, which is not a Python bool
    assert any(p is dut.c_out for p in dut.ports['out'])
    assert not any(p is dut.c_in for p in dut.ports['out'])


if __name__ == '__main__':
    hdl = InOutBuff(ALU())
    cmd(hdl)