summaryrefslogtreecommitdiff
path: root/hdl/lib/shift_reg.py
blob: 3a217de7af3c90b2b2b3817cd9499abe93919bae (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from amaranth import *
from amaranth.sim import Simulator, Settle, Delay

from hdl.utils import cmd, sim, step
from hdl.lib.in_out_buff import InOutBuff

class ShiftReg(Elaboratable):
    """Shift register with parallel load, shifting one bit per enabled clock.

    Ports (all updates happen in the ``sync`` domain):
        load_val   -- input; value copied into ``reg`` when ``load`` is high.
        load       -- input; parallel-load strobe, takes priority over ``en``.
        en         -- input; when high (and ``load`` low) ``reg`` shifts by one.
        right_left -- input; direction select: 1 -> ``reg << 1`` (toward the
                      MSB), 0 -> ``reg >> 1`` (toward the LSB).
        reg        -- output; the register contents, ``width`` bits wide.

    Zeros are shifted in, and the bit shifted out is dropped because the
    result is truncated to ``reg``'s width on assignment. With neither
    ``load`` nor ``en`` asserted, no assignment is made and ``reg`` holds.
    """

    def __init__(self, width):
        self.name = "shift_reg"

        # Parallel-load data input; reset_less so no reset logic is attached.
        self.load_val = Signal(width, reset=0, reset_less=True)
        self.load = Signal()
        self.reg = Signal(width)
        self.en = Signal()
        self.right_left = Signal()

        # Port lists keep the same ordering as before; consumers may rely on it.
        self.ports = {
            'in': [self.load_val, self.en, self.load, self.right_left],
            'out': [self.reg],
        }

    def elaborate(self, platform):
        module = Module()

        # Priority: a parallel load wins; otherwise shift only when enabled.
        with module.If(self.load):
            module.d.sync += self.reg.eq(self.load_val)
        with module.Elif(self.en):
            with module.If(self.right_left):
                module.d.sync += self.reg.eq(self.reg << 1)
            with module.Else():
                module.d.sync += self.reg.eq(self.reg >> 1)

        return module

def test_shiftreg_right():
    """Load 0xAB into an 8-bit ShiftReg and check it shifts right each clock.

    Nine samples are taken, one per bare ``yield`` in the loop. The expected
    value is halved (``>> 1``) after each sample, so the last sample expects 0.
    """
    dut = ShiftReg(8)

    def proc():
        val = 0xAB

        yield dut.load_val.eq(val)
        yield dut.en.eq(0)
        yield dut.load.eq(1)
        # Select the right-shift direction explicitly rather than relying on
        # the signal's reset value (mirrors test_shiftreg_left).
        yield dut.right_left.eq(0)
        yield from step()
        yield dut.load.eq(0)
        yield dut.en.eq(1)

        for _ in range(9):
            yield
            reg_val = yield dut.reg
            assert reg_val == val, f"Incorrect shift ---EXPECTED: {hex(val)}   ---GOT: {hex(reg_val)}"
            val = val >> 1

    sim(dut, proc)

def test_shiftreg_left():
    """Load 0xBD with the left direction selected and check each shift step.

    Expected values are masked to 8 bits to match the register width, so
    bits shifted past the MSB are discarded.
    """
    dut = ShiftReg(8)
    # Expected register contents for each of the nine sampled clocks:
    # 0xBD shifted left by 0..8 bits, truncated to 8 bits.
    expected_vals = [(0xBD << shift) & 0xff for shift in range(9)]

    def proc():
        yield dut.load_val.eq(expected_vals[0])
        yield dut.en.eq(0)
        yield dut.load.eq(1)
        yield dut.right_left.eq(1)
        yield from step()
        yield dut.load.eq(0)
        yield dut.en.eq(1)

        for val in expected_vals:
            yield
            reg_val = yield dut.reg
            assert reg_val == val, f"Incorrect shift ---EXPECTED: {hex(val)}   ---GOT: {hex(reg_val)}"

    sim(dut, proc)



if __name__ == '__main__':
    # Hand an 8-bit instance to the project's shared command-line driver.
    cmd(ShiftReg(8))