from enum import unique, Enum
from amaranth import *
from amaranth.sim import Settle
from hdl.lib.in_out_buff import InOutBuff
from hdl.utils import cmd, sim, step
from hdl.core.alu import ALUFlags


@unique
class RegAddr(Enum):
    """4-bit register addresses.

    Each value equals the register's ``idx`` attribute and its position in
    ``Reg.reg_arr``.
    """
    zx = 0
    ax = 1
    bx = 2
    cx = 3
    dx = 4
    ex = 5
    fx = 6
    gx = 7
    hx = 8
    ip = 9
    sp = 10
    flg = 11
    cs0 = 12
    cs1 = 13
    cs2 = 14
    pog = 15


@unique
class RegFLG(Enum):
    """Bit positions inside the ``flg`` register.

    The low bits reuse the ``ALUFlags`` values (presumably bit indices
    0..len(ALUFlags)-1 -- TODO confirm against hdl.core.alu). Bits 16 and up
    can only be changed in system mode.
    """
    carry = ALUFlags.carry.value
    zero = ALUFlags.zero.value
    negative = ALUFlags.negative.value
    overflow = ALUFlags.overflow.value
    int_en = 16
    user_mode = 17
    page_en = 18
    halt = 31  # halt/pause the processor, depends on whether interrupts are enabled


class Reg(Elaboratable):
    """Register file of sixteen 32-bit registers (addressed by ``RegAddr``).

    * Two combinational read ports (``rs1``/``rs2``) selected by
      ``rs1_addr``/``rs2_addr``. Reading ``flg`` in user mode returns only
      its low 16 bits.
    * One synchronous write port (``rd`` to ``rd_addr``) gated by ``wr_en``.
      ``zx`` and ``ip`` are never written through it; ``cs1``, ``pog`` and
      the upper half of ``flg`` are only writable in system mode.
    * Instruction-pointer control: ``ip`` advances by 4 every cycle unless
      ``stall`` is set or ``int_sig``/``iret``/``call``/``jump`` redirects it.
    * The ALU flags are latched into the low bits of ``flg`` every cycle
      unless ``flg`` itself is being written (direct write, interrupt entry
      or interrupt return).
    """

    def __init__(self, **kargs):  # kargs (e.g. sim) is accepted only for modularity; unused here
        ################ INPUTS ################
        self.wr_en = Signal(1)
        self.stall = Signal(1)  # stall instruction pointer increment
        self.rd = Signal(32)  # write-back data
        self.rd_addr = Signal(4)
        self.rs1_addr = Signal(4)
        self.rs2_addr = Signal(4)
        self.alu_flgs = Signal(len(ALUFlags))  # flags from alu

        # these signals should be used one hot only; if several are asserted
        # the priority is int_sig > iret > call > jump (> wr_en)
        self.int_sig = Signal(1)  # unconditional interrupt
        self.iret = Signal(1)  # return from interrupt
        self.call = Signal(1)  # call subroutine, save return address
        self.jump = Signal(1)  # jump, do not save return address

        ################ OUTPUTS ################
        self.rs1 = Signal(32)  # read data 1
        self.rs2 = Signal(32)  # read data 2
        self.int_en = Signal(1)  # interrupt enable output signal to control unit
        self.user_mode = Signal(1)  # user mode output signal to control unit

        ################ INTERNAL SIGNALS ################
        self._wr_alu_flg = Signal(1)  # latch alu flags into flg this cycle
        self._inc_ip = Signal(1)  # advance ip by 4 this cycle (unless stalled)

        self.zx = Signal(32)   # 0  zero register, never written through the write port
        self.ax = Signal(32)   # 1
        self.bx = Signal(32)   # 2
        self.cx = Signal(32)   # 3
        self.dx = Signal(32)   # 4
        self.ex = Signal(32)   # 5
        self.fx = Signal(32)   # 6
        self.gx = Signal(32)   # 7
        self.hx = Signal(32)   # 8
        self.ip = Signal(32)   # 9
        self.sp = Signal(32)   # 10
        self.flg = Signal(32)  # 11
        self.cs0 = Signal(32)  # 12
        self.cs1 = Signal(32)  # 13
        self.cs2 = Signal(32)  # 14
        self.pog = Signal(32)  # 15

        # this is a shortcut for internal testing, use enum RegFLG if using outside of this module
        setattr(self.flg, 'c', self.flg[RegFLG.carry.value])
        setattr(self.flg, 'ov', self.flg[RegFLG.overflow.value])
        setattr(self.flg, 'z', self.flg[RegFLG.zero.value])
        setattr(self.flg, 'n', self.flg[RegFLG.negative.value])
        setattr(self.flg, 'int', self.flg[RegFLG.int_en.value])
        setattr(self.flg, 'user_mode', self.flg[RegFLG.user_mode.value])
        setattr(self.flg, 'page_en', self.flg[RegFLG.page_en.value])
        setattr(self.flg, 'halt', self.flg[RegFLG.halt.value])

        # order must match RegAddr
        reg_list = [self.zx, self.ax, self.bx, self.cx, self.dx, self.ex, self.fx, self.gx,
                    self.hx, self.ip, self.sp, self.flg, self.cs0, self.cs1, self.cs2, self.pog]
        for idx, reg in enumerate(reg_list):
            setattr(reg, 'idx', idx)  # set idx attribute to each register
        self.reg_arr = Array(reg_list)

        ports_in = [self.wr_en, self.stall, self.alu_flgs, self.int_sig, self.iret, self.call,
                    self.jump, self.rd_addr, self.rd, self.rs1_addr, self.rs2_addr]
        ports_out = [self.int_en, self.user_mode, self.rs1, self.rs2, self.ip]
        self.ports = {'in': ports_in, 'out': ports_out}

    def elaborate(self, platform=None):
        m = Module()

        # output signals to control unit
        m.d.comb += self.int_en.eq(self.flg.int)
        # the control unit sees system mode during the interrupt cycle itself
        m.d.comb += self.user_mode.eq(self.flg.user_mode & ~self.int_sig)

        # default values of internal signals
        m.d.comb += self._wr_alu_flg.eq(1)
        m.d.comb += self._inc_ip.eq(1)

        with m.If(self.int_sig):
            m.d.comb += self._inc_ip.eq(0)  # do not increment instruction pointer on interrupt
            # flg is modified below; the alu flag update at the end of this
            # function assigns the whole register and, coming later, would
            # otherwise override these writes
            m.d.comb += self._wr_alu_flg.eq(0)
            m.d.sync += self.ip.eq(self.rd)  # get new ip from rd
            m.d.sync += self.cs0.eq(self.ip)  # save return address in cs0
            m.d.sync += self.cs1.eq(self.sp)  # swap sp and cs1
            m.d.sync += self.sp.eq(self.cs1)
            m.d.sync += self.cs2.eq(self.flg)  # save old flags to cs2
            m.d.sync += self.flg.user_mode.eq(0)  # set to system mode or iret cannot be used
            # clear int flag, essential because another interrupt can be triggered without this
            m.d.sync += self.flg.int.eq(0)
        with m.Elif(self.iret & ~self.flg.user_mode):  # iret is ignored in user mode
            m.d.comb += self._inc_ip.eq(0)
            m.d.comb += self._wr_alu_flg.eq(0)  # flg is restored from cs2, don't clobber it
            m.d.sync += self.ip.eq(self.cs0)  # copy cs0 to ip
            m.d.sync += self.sp.eq(self.cs1)  # swap back sp and cs1
            m.d.sync += self.cs1.eq(self.sp)
            m.d.sync += self.flg.eq(self.cs2)  # restore flags from cs2
        with m.Elif(self.call):
            m.d.comb += self._inc_ip.eq(0)
            m.d.sync += self.ip.eq(self.cs0)  # swap ip and cs0
            m.d.sync += self.cs0.eq(self.ip)
        with m.Elif(self.jump):
            m.d.comb += self._inc_ip.eq(0)
            m.d.sync += self.ip.eq(self.cs0)  # only copy cs0 to ip, do not swap
        with m.Elif(self.wr_en):
            with m.Switch(self.rd_addr):
                with m.Case(self.zx.idx):  # do not write to zero register
                    pass
                with m.Case(self.ip.idx):  # do not directly write to ip register
                    pass
                with m.Case(self.cs1.idx):
                    # do not allow writing to system stack pointer in user mode
                    with m.If(~self.flg.user_mode):
                        m.d.sync += self.cs1.eq(self.rd)
                with m.Case(self.pog.idx):  # do not write to pog register in user mode
                    with m.If(~self.flg.user_mode):
                        m.d.sync += self.pog.eq(self.rd)
                with m.Case(self.flg.idx):
                    with m.If(~self.flg.user_mode):
                        m.d.sync += self.flg.eq(self.rd)  # system mode, full control
                    with m.Else():
                        # user mode can only affect the lower 16 bits
                        m.d.sync += self.flg.eq(Cat(self.rd[:16], self.flg[16:]))
                    # don't update flags from alu
                    m.d.comb += self._wr_alu_flg.eq(0)
                with m.Case():
                    m.d.sync += self.reg_arr[self.rd_addr].eq(self.rd)

        # Latch the alu flags into the low bits of flg unless flg is written
        # above (direct write, interrupt entry, interrupt return). This is a
        # whole-register assignment placed after the branches, so it would
        # otherwise take precedence over their writes to flg.
        # NOTE(review): this happens regardless of wr_en -- confirm that is intended.
        with m.If(self._wr_alu_flg):
            m.d.sync += self.flg.eq(Cat(self.alu_flgs, self.flg[len(self.alu_flgs):]))

        # increment ip unless it was redirected above or stall is asserted
        with m.If(self._inc_ip & ~self.stall):
            m.d.sync += self.ip.eq(self.ip + 4)

        ### Combinational signal outputs ###
        # flg as seen through a read port: upper 16 bits are masked off in user mode
        flg_read = self.flg & Cat(Const(0xFFFF, 16), Repl(~self.flg.user_mode, 16))

        with m.Switch(self.rs1_addr):
            with m.Case(self.flg.idx):
                m.d.comb += self.rs1.eq(flg_read)
            with m.Case():
                m.d.comb += self.rs1.eq(self.reg_arr[self.rs1_addr])

        with m.Switch(self.rs2_addr):
            with m.Case(self.flg.idx):
                m.d.comb += self.rs2.eq(flg_read)
            with m.Case():
                m.d.comb += self.rs2.eq(self.reg_arr[self.rs2_addr])

        return m


#--------------------------------- TEST BENCH BELOW ---------------------------------#


def _init(dut):
    """Simulation helper: preload register i with the value i, then settle."""
    for i in range(16):
        yield dut.reg_arr[i].eq(i)
    yield Settle()


# test combinational output
def test_reg_comb_output():
    dut = Reg()

    def proc():
        yield from _init(dut)
        for i in range(16):
            yield dut.rs1_addr.eq(i)
            yield Settle()
            rs1 = yield dut.rs1
            assert rs1 == i, f'ERROR reading {dut.reg_arr[i].name}: expected {i} != {rs1}'
        for i in range(16):
            yield dut.rs2_addr.eq(i)
            yield Settle()
            rs2 = yield dut.rs2
            assert rs2 == i, f'ERROR reading {dut.reg_arr[i].name}: expected {i} != {rs2}'

    sim(dut, proc)


# test writeback with writeback disabled
def test_reg_writeback_dsb():
    dut = Reg()

    def proc():
        yield from _init(dut)
        for i in range(16):
            yield dut.rd_addr.eq(i)
            yield dut.rd.eq(i + 1)
            yield
            # ip advances every cycle and flg gets updated by the alu
            if (i != dut.ip.idx) and (i != dut.flg.idx):
                assert (yield dut.reg_arr[i]) == i, f'ERROR writing to {dut.reg_arr[i].name} != {i}'

    sim(dut, proc)


# test writeback with writeback enabled
def test_reg_writeback_en():
    dut = Reg()

    def proc():
        for i in range(16):
            yield from _init(dut)
            yield dut.wr_en.eq(1)
            yield dut.rd_addr.eq(i)
            yield dut.rd.eq(i - 1)
            yield from step()
            if (i != dut.zx.idx) and (i != dut.ip.idx):
                val = yield dut.reg_arr[i]
                assert val == i - 1, f'ERROR writing to {dut.reg_arr[i].name}, expected {i-1} != {val}'
            elif i == dut.zx.idx:
                zx = yield dut.zx
                assert zx == 0, f'ERROR {dut.zx.name}, expected 0 != {zx}'
            elif i == dut.ip.idx:
                # ip should be incremented and not written to
                assert (yield dut.reg_arr[i]) == dut.ip.idx + 4, \
                    f'ERROR {dut.ip.name} != {dut.ip.idx+4} should not be able to be directly written to'

    sim(dut, proc)
# check to make sure alu is writing values
def test_reg_flg_write_aluflg():
    dut = Reg()

    def proc():
        yield dut.flg.eq(0)
        yield dut.alu_flgs.eq(Repl(1, dut.alu_flgs.width))
        yield dut.wr_en.eq(1)
        yield dut.flg.user_mode.eq(0)
        yield dut.rd_addr.eq(dut.zx.idx)  # can be anything except flg register
        yield dut.rd.eq(0xFFFF0000)  # this does not matter
        yield from step()
        assert (yield dut.flg) == (yield dut.alu_flgs), 'ERROR: alu is not writing to flg register'

    sim(dut, proc)


# a direct system-mode write to flg must take precedence over the alu flags
def test_reg_flg_overwrite():
    dut = Reg()

    def proc():
        yield dut.flg.eq(0)
        yield dut.alu_flgs.eq(Repl(1, dut.alu_flgs.width))
        yield dut.wr_en.eq(1)
        yield dut.flg.user_mode.eq(0)
        yield dut.flg[15].eq(1)
        yield dut.flg[31].eq(1)
        yield dut.rd_addr.eq(dut.flg.idx)
        yield dut.rd.eq(0xFFFF0000)
        yield from step()
        assert (yield dut.flg) == 0xFFFF0000, 'ERROR: alu flags overwrote a direct write to the flg register'

    sim(dut, proc)


# test flag register security: upper 16 bits are hidden from reads in user mode
def test_reg_flg_read_usermode():
    dut = Reg()

    def proc():
        yield dut.flg.eq(0)
        yield dut.flg.user_mode.eq(1)
        yield dut.flg[15].eq(1)
        yield dut.flg[31].eq(1)
        yield dut.rs1_addr.eq(dut.flg.idx)
        yield Settle()
        assert (yield dut.rs1) == 0x00008000, 'ERROR: able to read upper 16 bits of flg reg in user mode'

    sim(dut, proc)


# test flag register security: upper 16 bits are preserved on user-mode writes
def test_reg_flg_write_usermode():
    dut = Reg()

    def proc():
        yield dut.flg.eq(0)
        yield dut.wr_en.eq(1)
        yield dut.flg.user_mode.eq(1)
        yield dut.flg[15].eq(1)
        yield dut.flg[31].eq(1)
        yield dut.rd_addr.eq(dut.flg.idx)
        yield dut.rd.eq(0xABCD5789)
        yield from step()
        # bit 31 and user_mode (bit 17) survive; the low half comes from rd
        assert (yield dut.flg) == (0x80020000 | 0x5789), \
            'ERROR: able to write to upper 16 bits of flg reg in user mode'

    sim(dut, proc)


# all 32 bits of flg are readable in system mode
def test_reg_flg_read_systemmode():
    dut = Reg()

    def proc():
        yield dut.flg.eq(0)
        yield dut.flg.user_mode.eq(0)
        yield dut.flg[15].eq(1)
        yield dut.flg[31].eq(1)
        yield dut.rs1_addr.eq(dut.flg.idx)
        yield Settle()
        assert (yield dut.rs1) == 0x80008000, 'ERROR: unable to read all bits of flg reg in system mode'

    sim(dut, proc)
# make sure not to write alu flags when directly writing to flg register
def test_reg_flg_write_systemmode():
    dut = Reg()

    def proc():
        yield dut.flg.eq(0)
        yield dut.wr_en.eq(1)
        yield dut.flg.user_mode.eq(0)
        yield dut.flg[15].eq(1)
        yield dut.flg[31].eq(1)
        yield dut.rd_addr.eq(dut.flg.idx)
        yield dut.rd.eq(0xABCD5789)  # low nibble 0x9 differs from alu_flgs (0), so an alu overwrite is detectable
        yield from step()
        assert (yield dut.flg) == 0xABCD5789, 'ERROR: unable to write to all bits in supervisor mode'

    sim(dut, proc)


if __name__ == '__main__':
    # reg = InOutBuff(Reg())
    reg = Reg()
    cmd(reg)