summaryrefslogtreecommitdiff
path: root/hdl/core/alu.py
diff options
context:
space:
mode:
Diffstat (limited to 'hdl/core/alu.py')
-rw-r--r--hdl/core/alu.py269
1 files changed, 269 insertions, 0 deletions
diff --git a/hdl/core/alu.py b/hdl/core/alu.py
new file mode 100644
index 0000000..049c8af
--- /dev/null
+++ b/hdl/core/alu.py
@@ -0,0 +1,269 @@
+from amaranth import *
+from amaranth.sim import Simulator, Settle, Delay
+from enum import Enum, unique
+
+from hdl.utils import cmd, DubbleBuff
+
+@unique
+class AluOpCodes(Enum):
+ add = 0
+ addc = 1
+ sub = 2
+ subc = 3
+ bit_and = 4
+ bit_or = 5
+ bit_xor = 6
+ bit_nor = 7
+ lleft = 8
+ lright = 9
+ aright = 10
+ set_bit = 11
+ clear_bit = 12
+ umult = 13
+ smult = 14
+ udiv = 15
+ sdiv = 16
+
+class ALU(Elaboratable):
+ def __init__(self, **kargs):
+ self.in1 = Signal(32, reset_less=True)
+ self.in2 = Signal(32, reset_less=True)
+ self.c_in = Signal(1)
+ self.op = Signal(4, reset_less=True)
+
+ self.tmp = Signal(33, reset_less=True)
+
+ self.c_out = Signal(1, reset_less=True)
+ self.overflow = Signal(1, reset_less=True)
+ self.zero = Signal(1, reset_less=True)
+ self.neg = Signal(1, reset_less=True)
+ self.odd = Signal(1, reset_less=True)
+
+ self.out = Signal(32, reset_less=True)
+
+ self.sim = kargs["sim"] if "sim" in kargs else None
+
+ ports_in = [self.in1, self.in2, self.op, self.c_in]
+ ports_out = [self.c_in, self.out, self.c_out, self.overflow, self.zero, self.neg, self.odd]
+ self.ports = {'in': ports_in, 'out': ports_out}
+
+ def elaborate(self, platform=None):
+ m = Module()
+
+ # dummy sync for simulation only
+ if self.sim == True:
+ dummy = Signal()
+ m.d.sync += dummy.eq(~dummy)
+
+ with m.Switch(self.op):
+ with m.Case(AluOpCodes.add.value):
+ m.d.comb += self.tmp.eq(self.in1 + self.in2)
+
+ with m.Case(AluOpCodes.addc.value):
+ m.d.comb += self.tmp.eq(self.in1 + self.in2 + self.c_in)
+
+ with m.Case(AluOpCodes.sub.value):
+ m.d.comb += self.tmp.eq(self.in1 - self.in2)
+
+ with m.Case(AluOpCodes.subc.value):
+ m.d.comb += self.tmp.eq(self.in1 + (~self.in2 + self.c_in))
+
+ with m.Case(AluOpCodes.bit_and.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1 & self.in2, 0))
+
+ with m.Case(AluOpCodes.bit_or.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1 | self.in2, 0))
+
+ with m.Case(AluOpCodes.bit_xor.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1 ^ self.in2, 0))
+
+ with m.Case(AluOpCodes.bit_nor.value):
+ m.d.comb += self.tmp.eq(Cat(~(self.in1 | self.in2), 0))
+
+ with m.Case(AluOpCodes.lleft.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1, 0) << self.in2[0:5])
+
+ with m.Case(AluOpCodes.lright.value):
+ tmp2 = Signal(33)
+ m.d.comb += tmp2.eq(Cat(0, self.in1) >> self.in2[0:5])
+ m.d.comb += self.tmp.eq(Cat(tmp2[1:33], tmp2[0])) # move shifted bit to carry bit
+
+ with m.Case(AluOpCodes.aright.value):
+ tmp2 = Signal(33)
+ m.d.comb += tmp2.eq(Cat(0, self.in1).as_signed() >> self.in2[0:5])
+ m.d.comb += self.tmp.eq(Cat(tmp2[1:33], tmp2[0])) # move shifted bit to carry bit
+
+ with m.Case(AluOpCodes.set_bit.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1 | (1 << self.in2[0:5]), 0))
+
+ with m.Case(AluOpCodes.clear_bit.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1 & ~(1 << self.in2[0:5]), 0))
+
+ with m.Case(AluOpCodes.umult.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1[0:16] * self.in2[0:16], 0))
+
+ with m.Case(AluOpCodes.smult.value):
+ m.d.comb += self.tmp.eq(Cat(self.in1[0:16].as_signed() * self.in2[0:16].as_signed(), 0))
+
+
+ # bad juju,
+ # TODO: come back and check this will work
+ # with m.Case(AluOpCodes.udiv.value):
+ # m.d.comb += self.tmp.eq(Cat(self.in1 // self.in2, 0))
+
+ # with m.Case(AluOpCodes.sdiv.value):
+ # m.d.comb += self.tmp.eq(self.in1.as_signed() // self.in2.as_signed()) # for some reason I have not confirmed, signed div can yield a 33 bit number, acording to amaranth
+
+ with m.Case():
+ m.d.comb += self.tmp.eq(0)
+
+ m.d.comb += self.c_out.eq(self.tmp[32])
+ m.d.comb += self.overflow.eq(self.tmp[32] ^ self.tmp[31])
+ m.d.comb += self.out.eq(self.tmp[0:32])
+ m.d.comb += self.neg.eq(self.out[31])
+ m.d.comb += self.zero.eq(self.out == 0)
+ m.d.comb += self.odd.eq(self.out.xor()) # 1 if odd number of bits, 0 if even
+
+ return m
+
+def test_alu(filename="alu.vcd"):
+ dut = ALU(sim=True)
+
+ def proc1():
+ def sub_proc(val1, val2, c_in=0):
+ yield dut.in1.eq(val1)
+ yield dut.in2.eq(val2)
+ yield dut.c_in.eq(c_in)
+ yield
+ yield Settle()
+
+ # test addition
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(27, 13)
+ out = yield dut.out
+ assert 27 + 13 == (out), f'ERROR: {out} != {27 + 13}'
+
+ # test addition with carry
+ yield dut.op.eq(AluOpCodes.addc.value)
+ yield from sub_proc(11, 43, 1)
+ out = yield dut.out.as_signed()
+ assert 11 + 43 + 1 == out, f'ERROR: {out} != {11 + 43 + 1}'
+
+ # test subtraction
+ yield dut.op.eq(AluOpCodes.sub.value)
+ yield from sub_proc(25, 13)
+ out = yield dut.out
+ assert 25 - 13 == out, f'ERROR: {out} != {25 - 13}'
+
+ # test subtraction with carry
+ yield dut.op.eq(AluOpCodes.subc.value)
+ yield from sub_proc(25, -13, 0)
+ out = yield dut.out.as_signed()
+ assert 25 + 13 -1 +0 == out, f'ERROR: {out} != {25 + 13 -1 +0}'
+
+ # test subtraction with carry
+ yield dut.op.eq(AluOpCodes.subc.value)
+ yield from sub_proc(25, -13, 1)
+ out = yield dut.out.as_signed()
+ assert 25 + 13 -1 +1 == out, f'ERROR: {out} != {25 + 13 -1 +1}'
+
+ # test binary and
+ yield dut.op.eq(AluOpCodes.bit_and.value)
+ yield from sub_proc(0b10101011, 0b01010101)
+ out = yield dut.out
+ assert 0b00000001 == out, f'ERROR: {out} != {0b00000001}'
+
+ # test binary or
+ yield dut.op.eq(AluOpCodes.bit_or.value)
+ yield from sub_proc(0b10101011, 0b01000101)
+ out = yield dut.out
+ assert 0b11101111 == out, f'ERROR: {out} != {0b11101111}'
+
+ # test binary nor
+ yield dut.op.eq(AluOpCodes.bit_nor.value)
+ yield from sub_proc(0b10001011, 0b01000101)
+ out = yield dut.out
+ assert 0b11111111111111111111111100110000 == out, f'ERROR: {bin(out)} != {bin(0b11111111111111111111111100110000)}'
+
+ # test binary xor
+ yield dut.op.eq(AluOpCodes.bit_xor.value)
+ yield from sub_proc(0b10001011, 0b01000101)
+ out = yield dut.out
+ assert 0b11001110 == out, f'ERROR: {out} != {0b11001110}'
+
+ # test logical shift left
+ yield dut.op.eq(AluOpCodes.lleft.value)
+ yield from sub_proc(0b10001011, 25) # shift left by 5
+ out = yield dut.out
+ assert 0b00010110000000000000000000000000 == out, f'ERROR: {bin(out)} != {bin(0b00010110000000000000000000000000)}'
+ out = yield dut.c_out
+ assert 1 == out, f'ERROR: {out} != {1}'
+
+ # test logical shift right
+ yield dut.op.eq(AluOpCodes.lright.value)
+ yield from sub_proc(0b10001011, 4) # shift right by 5
+ out = yield dut.out
+ assert 0b1000 == out, f'ERROR: {bin(out)} != {bin(0b1000)}'
+ out = yield dut.c_out
+ assert 1 == out, f'ERROR: {out} != {1}'
+
+ # test aligned shift right
+ yield dut.op.eq(AluOpCodes.aright.value)
+ yield from sub_proc(0x80001234, 4) # shift right by 4
+ out = yield dut.out
+ assert 0xF8000123 == out, f'ERROR: {out} != {0xF8000123}'
+ out = yield dut.c_out
+ assert 0 == out, f'ERROR: {out} != {0}'
+
+ # test unsigned overflow
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(0xFFFFFFFF, 1) # add 1 to 0xFFFFFFFF
+ out = yield dut.overflow
+ assert out == 1, f'ERROR: {out} != {1}'
+ out = yield dut.c_out
+ assert out == 1, f'ERROR: {out} != {1}'
+
+ # test unsigned underflow
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(0, -1) # subtract 1 from 0
+ out = yield dut.overflow
+ assert out == 1, f'ERROR: {out} != {1}'
+ out = yield dut.c_out
+ assert out == 0, f'ERROR: {out} != {0}'
+
+ # test zero
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(0, 0) # add 0 to 0
+ out = yield dut.zero
+ assert out == 1, f'ERROR: {out} != {1}'
+
+ # test zero
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(0, 1) # add 0 to 0
+ out = yield dut.zero
+ assert out == 0, f'ERROR: {out} != {0}'
+
+ # test odd
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(0, 0xAAAAAAAA) # add 0 to 0
+ out = yield dut.odd
+ assert out == 0, f'ERROR: {out} != {0}'
+
+ # test odd
+ yield dut.op.eq(AluOpCodes.add.value)
+ yield from sub_proc(0, 0xAAAAAAAB) # add 0 to 0
+ out = yield dut.odd
+ assert out == 1, f'ERROR: {out} != {1}'
+
+
+ sim = Simulator(dut)
+ sim.add_clock(1e-6)
+ sim.add_sync_process(proc1)
+
+ with sim.write_vcd(filename):
+ sim.run()
+
+
+if __name__ == '__main__':
+ hdl = DubbleBuff(ALU())
+ cmd(hdl, test_alu)