summaryrefslogtreecommitdiff
path: root/hdl/core/jump_ctl.py
blob: 25caf86d33d0a139d9240091e779849669772a84 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from amaranth import *
from amaranth.sim import Settle
from enum import Enum, unique

from hdl.utils import sim, cmd, e2s, rand_bits_mix
from hdl.lib.in_out_buff import InOutBuff   # used for timing analysis

from hdl.core.alu import ALUFlags, ALU, AluOpCodes  #ALUOpCodes is for simulation only, not used in hardware
from hdl.config import NUM_RAND_TESTS

@unique
class JumpOpCodes(Enum):
    """Branch conditions evaluated by JumpCtl.

    The ``_u`` variants compare the operands as unsigned numbers, the ``_s``
    variants as two's-complement signed numbers. The values are the encodings
    that JumpCtl matches on its ``op`` input.
    """
    j_t = 0         # always jump
    j_eq = 1        # in1 == in2
    j_ne = 2        # in1 != in2
    j_lt_u = 3      # in1 <  in2, unsigned
    j_lte_u = 4     # in1 <= in2, unsigned
    j_lt_s = 5      # in1 <  in2, signed
    j_lte_s = 6     # in1 <= in2, signed

class JumpCtl(Elaboratable):
    """Combinational branch-condition evaluator.

    Decodes ``op`` (a JumpOpCodes value) against the ALU flags produced by
    ``in1 - in2`` and drives ``cond_true`` high when the jump must be taken.

    Ports:
        alu_flags   (in)  ALU flag vector, indexed by ``ALUFlags.<flag>.value``.
        op          (in)  jump opcode, one of ``JumpOpCodes``.
        signed_bits (in)  operand sign bits: bit 0 is the sign of in1, bit 1
                          the sign of in2 (DUT wires ``Cat(in1[31], in2[31])``).
        cond_true   (out) 1 when the jump condition holds, else 0.
    """

    def __init__(self, **kargs):
        self.alu_flags = Signal(len(ALUFlags), reset_less=True)
        self.op = Signal(e2s(JumpOpCodes), reset_less=True)
        self.signed_bits = Signal(2, reset_less=True)

        self.cond_true = Signal(reset_less=True)     # true if jump condition is met

        ports_in = [self.alu_flags, self.op, self.signed_bits]
        ports_out = [self.cond_true]
        self.ports = {'in': ports_in, 'out': ports_out}

        # Stored for a uniform constructor signature; elaborate() does not read it.
        self.sim = kargs.get("sim", False)

    def elaborate(self, platform=None):
        m = Module()

        # Flag bits of the ALU result in1 - in2.
        zero = self.alu_flags[ALUFlags.zero.value]
        carry = self.alu_flags[ALUFlags.carry.value]
        negative = self.alu_flags[ALUFlags.negative.value]
        in1_neg = self.signed_bits[0]

        # diff_sign is 1 when the two operands have DIFFERENT sign bits.
        # Then in1 - in2 may overflow, so the negative flag cannot be trusted,
        # but the ordering is already known: the negative operand is smaller.
        diff_sign = Signal(reset_less=True)
        m.d.comb += diff_sign.eq(self.signed_bits.xor())

        with m.Switch(self.op):
            with m.Case(JumpOpCodes.j_t.value):
                m.d.comb += self.cond_true.eq(1)    # always true

            with m.Case(JumpOpCodes.j_eq.value):
                m.d.comb += self.cond_true.eq(zero)

            with m.Case(JumpOpCodes.j_ne.value):
                m.d.comb += self.cond_true.eq(~zero)

            # Unsigned compares. NOTE(review): these assume the ALU sets carry
            # when in1 < in2 (borrow) on subtraction — confirm against ALU.
            with m.Case(JumpOpCodes.j_lt_u.value):
                m.d.comb += self.cond_true.eq(carry & ~zero)

            with m.Case(JumpOpCodes.j_lte_u.value):
                m.d.comb += self.cond_true.eq(carry | zero)

            with m.Case(JumpOpCodes.j_lt_s.value):
                with m.If(diff_sign):
                    # signs differ: in1 < in2 exactly when in1 is the negative one
                    m.d.comb += self.cond_true.eq(in1_neg & ~zero)
                with m.Else():
                    # same sign: in1 - in2 cannot overflow, so the negative flag
                    # (assumed to be the result's sign bit) is exact
                    m.d.comb += self.cond_true.eq(negative & ~zero)

            with m.Case(JumpOpCodes.j_lte_s.value):
                with m.If(diff_sign):
                    m.d.comb += self.cond_true.eq(in1_neg | zero)
                with m.Else():
                    m.d.comb += self.cond_true.eq(negative | zero)

            # Any other op value never jumps. m.Default() replaces the former
            # pattern-less m.Case(): that form is deprecated in Amaranth 0.4
            # and never matches from Amaranth 0.5 on.
            with m.Default():
                m.d.comb += self.cond_true.eq(0)

        return m


class DUT(Elaboratable):
    """Simulation-only harness that feeds an ALU's outputs into a JumpCtl.

    Tests drive ``alu.in1``, ``alu.in2`` and ``alu.op``. The JumpCtl sees the
    ALU's flags plus bit 31 (the sign bit) of each operand.
    """

    def __init__(self, **kargs):    # kargs accepted for a uniform constructor; ignored
        self.alu = ALU(sim=True)
        self.jump = JumpCtl(sim=True)

    def elaborate(self, platform=None):
        m = Module()
        alu, jump = self.alu, self.jump

        m.submodules.alu = alu
        m.submodules.jump = jump

        # Cat places its first argument in the LSB: bit 0 <- in1 sign, bit 1 <- in2 sign.
        operand_signs = Cat(alu.in1[31], alu.in2[31])
        m.d.comb += [
            jump.alu_flags.eq(alu.alu_flags),
            jump.signed_bits.eq(operand_signs),
        ]
        return m

def _init_dut(dut):
    """Put the DUT's ALU into a known state for the jump tests.

    Zeroes both ALU operands and selects the subtract operation, because
    JumpCtl evaluates its conditions on the flags of ``in1 - in2``. Then it
    lets the combinational logic settle. Use with ``yield from`` inside a
    simulator process.
    """
    yield dut.alu.in1.eq(0)
    yield dut.alu.in2.eq(0)
    # .value, like every other opcode assignment in this file
    yield dut.alu.op.eq(AluOpCodes.sub.value)
    yield Settle()
    

# test jump if equal
def test_jump_eq(tests=NUM_RAND_TESTS):
    """j_eq must take the jump exactly when in1 == in2.

    Each iteration checks a random operand pair, then the pair (in1, in1).
    Independently drawn operands may rarely coincide (this depends on
    rand_bits_mix), so without the second check the taken path could go
    untested.

    Args:
        tests: number of iterations (two checks per iteration).
    """
    dut = DUT(sim=True)     # sim=True is not needed, but I am trying to be consistent

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(JumpOpCodes.j_eq.value)

        for _ in range(tests):
            in1 = rand_bits_mix(32)
            for in2 in (rand_bits_mix(32), in1):    # random pair, then equal operands
                yield dut.alu.in1.eq(in1)
                yield dut.alu.in2.eq(in2)
                yield Settle()
                assert (yield dut.jump.cond_true) == (in1 == in2), f"jump_eq failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"

    sim(dut, proc, sync=False)

# test jump if not equal
def test_jump_ne(tests=NUM_RAND_TESTS):
    """j_ne must take the jump exactly when in1 != in2.

    Each iteration checks a random operand pair, then the pair (in1, in1).
    Independently drawn operands may rarely coincide (this depends on
    rand_bits_mix), so without the second check the not-taken path could go
    untested.

    Args:
        tests: number of iterations (two checks per iteration).
    """
    dut = DUT(sim=True)

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(JumpOpCodes.j_ne.value)

        for _ in range(tests):
            in1 = rand_bits_mix(32)
            for in2 in (rand_bits_mix(32), in1):    # random pair, then equal operands
                yield dut.alu.in1.eq(in1)
                yield dut.alu.in2.eq(in2)
                yield Settle()
                assert (yield dut.jump.cond_true) == (in1 != in2), f"jump_ne failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"

    sim(dut, proc, sync=False)

# test jump if less than unsigned
def test_jump_lt_u(tests=NUM_RAND_TESTS):
    """j_lt_u must take the jump exactly when in1 < in2 (unsigned).

    Each iteration checks a random operand pair, then the pair (in1, in1).
    The equal pair is the boundary where the jump must NOT be taken, and
    random operands may rarely hit it.

    Args:
        tests: number of iterations (two checks per iteration).
    """
    dut = DUT(sim=True)

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(JumpOpCodes.j_lt_u.value)

        for _ in range(tests):
            in1 = rand_bits_mix(32, sus='u')
            for in2 in (rand_bits_mix(32, sus='u'), in1):   # random pair, then equal operands
                yield dut.alu.in1.eq(in1)
                yield dut.alu.in2.eq(in2)
                yield Settle()
                assert (yield dut.jump.cond_true) == (in1 < in2), f"jump_lt_u failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"

    sim(dut, proc, sync=False)

# test jump if less than or equal to unsigned
def test_jump_lte_u(tests=NUM_RAND_TESTS):
    """j_lte_u must take the jump exactly when in1 <= in2 (unsigned).

    Each iteration checks a random operand pair, then the pair (in1, in1),
    so the equality half of ``<=`` (the zero-flag path) is exercised
    deterministically rather than only when random operands happen to match.

    Args:
        tests: number of iterations (two checks per iteration).
    """
    dut = DUT(sim=True)

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(JumpOpCodes.j_lte_u.value)

        for _ in range(tests):
            in1 = rand_bits_mix(32, sus='u')
            for in2 in (rand_bits_mix(32, sus='u'), in1):   # random pair, then equal operands
                yield dut.alu.in1.eq(in1)
                yield dut.alu.in2.eq(in2)
                yield Settle()
                assert (yield dut.jump.cond_true) == (in1 <= in2), f"jump_lte_u failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"

    sim(dut, proc, sync=False)

# test jump if less than signed
def test_jump_lt_s(tests=NUM_RAND_TESTS):
    """j_lt_s must take the jump exactly when in1 < in2 (signed).

    Each iteration checks a random operand pair, then the pair (in1, in1).
    The equal pair is the boundary where the jump must NOT be taken, and
    random operands may rarely hit it.

    Args:
        tests: number of iterations (two checks per iteration).
    """
    dut = DUT(sim=True)

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(JumpOpCodes.j_lt_s.value)

        for _ in range(tests):
            in1 = rand_bits_mix(32, sus='s')
            for in2 in (rand_bits_mix(32, sus='s'), in1):   # random pair, then equal operands
                yield dut.alu.in1.eq(in1)
                yield dut.alu.in2.eq(in2)
                yield Settle()
                assert (yield dut.jump.cond_true) == (in1 < in2), f"jump_lt_s failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"

    sim(dut, proc, sync=False)

# test jump if less than or equal to signed
def test_jump_lte_s(tests=NUM_RAND_TESTS):
    """j_lte_s must take the jump exactly when in1 <= in2 (signed).

    Each iteration checks a random operand pair, then the pair (in1, in1),
    so the equality half of ``<=`` (the zero-flag path) is exercised
    deterministically rather than only when random operands happen to match.

    Args:
        tests: number of iterations (two checks per iteration).
    """
    dut = DUT(sim=True)

    def proc():
        yield from _init_dut(dut)
        yield dut.jump.op.eq(JumpOpCodes.j_lte_s.value)

        for _ in range(tests):
            in1 = rand_bits_mix(32, sus='s')
            for in2 in (rand_bits_mix(32, sus='s'), in1):   # random pair, then equal operands
                yield dut.alu.in1.eq(in1)
                yield dut.alu.in2.eq(in2)
                yield Settle()
                assert (yield dut.jump.cond_true) == (in1 <= in2), f"jump_lte_s failed: in1={hex(in1)}, in2={hex(in2)}, cond_true={(yield dut.jump.cond_true)}"

    sim(dut, proc, sync=False)

       
if __name__ == '__main__':
    # Hand a fresh JumpCtl (not the simulation-only DUT) to hdl.utils.cmd.
    cmd(JumpCtl())