summaryrefslogtreecommitdiff
path: root/hdl/core/alu.py
diff options
context:
space:
mode:
Diffstat (limited to 'hdl/core/alu.py')
-rw-r--r--hdl/core/alu.py97
1 files changed, 89 insertions, 8 deletions
diff --git a/hdl/core/alu.py b/hdl/core/alu.py
index f8a1a25..8693ba7 100644
--- a/hdl/core/alu.py
+++ b/hdl/core/alu.py
@@ -1,9 +1,11 @@
+from cmath import exp
from amaranth import *
from amaranth.sim import Simulator, Settle, Delay
from enum import Enum, unique
from hdl.utils import *
from hdl.lib.in_out_buff import InOutBuff
+from hdl.config import NUM_RAND_TESTS
@unique
class AluOpCodes(Enum):
@@ -18,10 +20,10 @@ class AluOpCodes(Enum):
lleft = 8
lright = 9
aright = 10
- umult = 11
- smult = 12
- # udiv = 13
- # sdiv = 14
+ multul = 11 # low 32 bits of unsigned multiplication
+ multuh = 12 # high 32 bits of unsigned multiplication
+ multsl = 13 # low 32 bits of signed multiplication
+ multsh = 14 # high 32 bits of signed multiplication
@unique
class ALUFlags(Enum):
@@ -100,11 +102,17 @@ class ALU(Elaboratable):
m.d.comb += tmp2.eq(Cat(0, self.in1).as_signed() >> self.in2[0:5])
m.d.comb += self.tmp.eq(Cat(tmp2[1:33], tmp2[0])) # move shifted bit to carry bit
- with m.Case(AluOpCodes.umult.value):
- m.d.comb += self.tmp.eq(Cat(self.in1[0:16] * self.in2[0:16], 0))
+ with m.Case(AluOpCodes.multul.value):
+ m.d.comb += self.tmp.eq(Cat((self.in1 * self.in2)[:32], 0))
+
+ with m.Case(AluOpCodes.multuh.value):
+ m.d.comb += self.tmp.eq(Cat((self.in1 * self.in2)[32:], 0))
- with m.Case(AluOpCodes.smult.value):
- m.d.comb += self.tmp.eq(Cat(self.in1[0:16].as_signed() * self.in2[0:16].as_signed(), 0))
+ with m.Case(AluOpCodes.multsl.value):
+ m.d.comb += self.tmp.eq(Cat((self.in1.as_signed() * self.in2.as_signed())[:32], 0))
+
+ with m.Case(AluOpCodes.multsh.value):
+ m.d.comb += self.tmp.eq(Cat((self.in1.as_signed() * self.in2.as_signed())[32:], 0))
# bad juju,
@@ -264,6 +272,79 @@ def test_alu_arith_shift_right():
assert 0 == out, f'ERROR: {out} != {0}'
sim(dut, proc)
+# test low unsigned multiply
+def test_alu_mul_low_u(tests=NUM_RAND_TESTS):
+ dut = ALU(sim=True)
+ def proc():
+ yield dut.op.eq(AluOpCodes.multul.value)
+ yield dut.c_in.eq(0)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='u')
+ in2 = rand_bits_mix(32, sus='u')
+ yield dut.in1.eq(in1)
+ yield dut.in2.eq(in2)
+ yield from eval()
+ expected = (in1 * in2) & 0xFFFFFFFF
+ assert (yield dut.out) == expected, f"mul_low_u failed: in1={hex(in1)}, in2={hex(in2)}, out={hex((yield dut.out))}, expected={hex(expected)}"
+
+ sim(dut, proc)
+
+# test high unsigned multiply
+def test_alu_mul_high_u(tests=NUM_RAND_TESTS):
+ dut = ALU(sim=True)
+ def proc():
+ yield dut.op.eq(AluOpCodes.multuh.value)
+ yield dut.c_in.eq(0)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='u')
+ in2 = rand_bits_mix(32, sus='u')
+ yield dut.in1.eq(in1)
+ yield dut.in2.eq(in2)
+ yield from eval()
+ expected = ((in1 * in2) >> 32) & 0xFFFFFFFF
+ assert (yield dut.out) == expected, f"mul_high_u failed: in1={hex(in1)}, in2={hex(in2)}, out={hex((yield dut.out))}, expected={hex(expected)}"
+
+ sim(dut, proc)
+
+# test low signed multiply
+def test_alu_mul_low_s(tests=NUM_RAND_TESTS):
+ dut = ALU(sim=True)
+ def proc():
+ yield dut.op.eq(AluOpCodes.multsl.value)
+ yield dut.c_in.eq(0)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='s')
+ in2 = rand_bits_mix(32, sus='s')
+ yield dut.in1.eq(in1)
+ yield dut.in2.eq(in2)
+ yield from eval()
+ expected = (in1 * in2) & 0xFFFFFFFF
+ assert (yield dut.out) == expected, f"mul_low_s failed: in1={hex(in1)}, in2={hex(in2)}, out={hex((yield dut.out))}, expected={hex(expected)}"
+
+ sim(dut, proc)
+
+# test high signed multiply
+def test_alu_mul_high_s(tests=NUM_RAND_TESTS):
+ dut = ALU(sim=True)
+ def proc():
+ yield dut.op.eq(AluOpCodes.multsh.value)
+ yield dut.c_in.eq(0)
+
+ for _ in range(tests):
+ in1 = rand_bits_mix(32, sus='s')
+ in2 = rand_bits_mix(32, sus='s')
+ yield dut.in1.eq(in1)
+ yield dut.in2.eq(in2)
+ yield from eval()
+ expected = ((in1 * in2) >> 32) & 0xFFFFFFFF
+ assert (yield dut.out) == expected, f"mul_high_s failed: in1={hex(in1)}, in2={hex(in2)}, out={hex((yield dut.out))}, expected={hex(expected)}"
+
+ sim(dut, proc)
+
+
# test unsigned overflow
def test_alu_unsigned_overflow():
dut = ALU(sim=True)