Untitled

 avatar
unknown
plain_text
18 days ago
46 kB
3
Indexable
module Simple_CPU (
    input  wire        CLK,
    input  wire        RSTN,

    output wire        dmem_en,
    output reg         dmem_we,
    output reg  [9:0]  dmem_addr,
    output reg  [31:0] dmem_wdata,
    input  wire [31:0] dmem_rdata
);

    // ------------------------------------------------------------------
    // Simple_CPU: multi-cycle FSM core for a small RV32I subset
    //   add, sub, addi, lw, sw, beq, blt
    //
    // * Instructions come from Instruction_Memory, addressed by pc[11:2].
    // * Data memory is driven through the dmem_* ports. dmem_en is tied
    //   high; dmem_addr is the word address taken from byte address [11:2].
    // * Both memories are read with one wait state between loading the
    //   address register and sampling the read data.
    // * Any encoding outside the subset sends the FSM to ST_DONE, after
    //   which it stays in ST_HALT until RSTN is asserted.
    // ------------------------------------------------------------------

    assign dmem_en = 1'b1;

    // ------------------------------------------------------------------
    // Instruction memory
    // ------------------------------------------------------------------
    reg  [9:0]  imem_addr;
    wire [31:0] imem_rdata;

    Instruction_Memory u_Instruction_Memory (
        .clka  (CLK),
        .ena   (1'b1),
        .addra (imem_addr),
        .douta (imem_rdata)
    );

    // ------------------------------------------------------------------
    // Opcodes of the supported instructions
    // ------------------------------------------------------------------
    localparam OPCODE_RTYPE  = 7'b0110011; // add, sub
    localparam OPCODE_ITYPE  = 7'b0010011; // addi
    localparam OPCODE_LOAD   = 7'b0000011; // lw
    localparam OPCODE_STORE  = 7'b0100011; // sw
    localparam OPCODE_BRANCH = 7'b1100011; // beq, blt

    // ------------------------------------------------------------------
    // FSM state encoding
    // ------------------------------------------------------------------
    localparam ST_FETCH_ADDR    = 5'd0;
    localparam ST_FETCH_WAIT1   = 5'd1;
    localparam ST_FETCH_CAPTURE = 5'd2;
    localparam ST_DECODE        = 5'd3;
    localparam ST_EXE_R         = 5'd4;
    localparam ST_EXE_I         = 5'd5;
    localparam ST_ADDR          = 5'd6;
    localparam ST_MEM_RD        = 5'd7;
    localparam ST_MEM_WAIT1     = 5'd8;
    localparam ST_MEM_CAPTURE   = 5'd9;
    localparam ST_WB_LW         = 5'd10; // never targeted; just returns to fetch
    localparam ST_MEM_WR_SETUP  = 5'd11;
    localparam ST_MEM_WR_DO     = 5'd12;
    localparam ST_BRANCH        = 5'd13;
    localparam ST_DONE          = 5'd14;
    localparam ST_HALT          = 5'd15;

    reg [4:0] state;

    // ------------------------------------------------------------------
    // Architectural / pipeline-less internal registers
    // ------------------------------------------------------------------
    reg [31:0] pc;        // address of the next instruction to fetch
    reg [31:0] pc_old;    // address of the instruction currently in ir
    reg [31:0] ir;        // current instruction word

    reg [31:0] reg_a;     // rs1 operand latched in ST_DECODE
    reg [31:0] reg_b;     // rs2 operand latched in ST_DECODE

    // Only byte-address bits [11:0] are kept; dmem_addr uses [11:2].
    reg [11:0] alu_addr;

    // ------------------------------------------------------------------
    // Sparse register file: only x1..x13, x20 and x21 are stored.
    // x0 always reads as zero. Reads of any other index return zero and
    // writes to them are dropped (see read_reg and the write-back case).
    // ------------------------------------------------------------------
    reg [31:0] r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13;
    reg [31:0] r20, r21;

    // ------------------------------------------------------------------
    // Instruction fields (standard RV32I bit positions)
    // ------------------------------------------------------------------
    wire [6:0] opcode = ir[6:0];
    wire [4:0] rd     = ir[11:7];
    wire [2:0] funct3 = ir[14:12];
    wire [4:0] rs1    = ir[19:15];
    wire [4:0] rs2    = ir[24:20];
    wire [6:0] funct7 = ir[31:25];

    // ------------------------------------------------------------------
    // Sign-extended immediates
    // ------------------------------------------------------------------
    wire signed [31:0] imm_i = {{20{ir[31]}}, ir[31:20]};
    wire signed [31:0] imm_s = {{20{ir[31]}}, ir[31:25], ir[11:7]};
    wire signed [31:0] imm_b = {{19{ir[31]}}, ir[31], ir[7],
                                ir[30:25], ir[11:8], 1'b0};

    // ------------------------------------------------------------------
    // Instruction recognition
    // ------------------------------------------------------------------
    wire f3_is_000 = (funct3 == 3'b000);
    wire f3_is_010 = (funct3 == 3'b010);
    wire f3_is_100 = (funct3 == 3'b100);

    wire is_add  = (opcode == OPCODE_RTYPE)  && f3_is_000 && (funct7 == 7'b0000000);
    wire is_sub  = (opcode == OPCODE_RTYPE)  && f3_is_000 && (funct7 == 7'b0100000);
    wire is_addi = (opcode == OPCODE_ITYPE)  && f3_is_000;
    wire is_lw   = (opcode == OPCODE_LOAD)   && f3_is_010;
    wire is_sw   = (opcode == OPCODE_STORE)  && f3_is_010;
    wire is_beq  = (opcode == OPCODE_BRANCH) && f3_is_000;
    wire is_blt  = (opcode == OPCODE_BRANCH) && f3_is_100;

    // Where ST_DECODE goes next. The is_* flags are mutually exclusive
    // (each one checks its own opcode), so mux priority does not matter.
    // Anything unsupported falls through to ST_DONE.
    wire [4:0] decode_next_state =
        (is_add || is_sub) ? ST_EXE_R  :
        (is_addi)          ? ST_EXE_I  :
        (is_lw  || is_sw)  ? ST_ADDR   :
        (is_beq || is_blt) ? ST_BRANCH :
                             ST_DONE;

    // ------------------------------------------------------------------
    // Datapath
    // Sums are kept as named 32-bit wires so that only named signals are
    // part-selected (no slicing of expressions).
    // ------------------------------------------------------------------
    wire [31:0] sum_rs1_imm_i = reg_a + imm_i;   // addi result, lw byte address
    wire [31:0] sum_rs1_imm_s = reg_a + imm_s;   // sw byte address
    wire [31:0] r_type_result = is_sub ? (reg_a - reg_b) : (reg_a + reg_b);

    wire        is_load_op    = (opcode == OPCODE_LOAD);
    wire [11:0] eff_addr_next = is_load_op ? sum_rs1_imm_i[11:0]
                                           : sum_rs1_imm_s[11:0];

    // Branch decision and target (target is relative to the branch itself).
    wire        branch_taken  = (is_beq && (reg_a == reg_b)) ||
                                (is_blt && ($signed(reg_a) < $signed(reg_b)));
    wire [31:0] branch_target = pc_old + imm_b;

    // ------------------------------------------------------------------
    // Register write-back port
    // One write per instruction, performed in the state that produces
    // the result: ST_EXE_R (add/sub), ST_EXE_I (addi) or ST_MEM_CAPTURE
    // (lw, taking dmem_rdata directly). Writes to x0 are discarded.
    // ------------------------------------------------------------------
    wire        wb_from_alu_r = (state == ST_EXE_R);
    wire        wb_from_alu_i = (state == ST_EXE_I);
    wire        wb_from_mem   = (state == ST_MEM_CAPTURE);

    wire        wb_valid = (wb_from_alu_r || wb_from_alu_i || wb_from_mem) &&
                           (rd != 5'd0);

    wire [31:0] wb_value = wb_from_alu_r ? r_type_result :
                           wb_from_alu_i ? sum_rs1_imm_i :
                                           dmem_rdata;

    // ------------------------------------------------------------------
    // Sparse register read: unstored indices (and x0) read as zero.
    // ------------------------------------------------------------------
    function [31:0] read_reg;
        input [4:0] raddr;
        begin
            read_reg = 32'd0;
            case (raddr)
                5'd1:  read_reg = r1;
                5'd2:  read_reg = r2;
                5'd3:  read_reg = r3;
                5'd4:  read_reg = r4;
                5'd5:  read_reg = r5;
                5'd6:  read_reg = r6;
                5'd7:  read_reg = r7;
                5'd8:  read_reg = r8;
                5'd9:  read_reg = r9;
                5'd10: read_reg = r10;
                5'd11: read_reg = r11;
                5'd12: read_reg = r12;
                5'd13: read_reg = r13;
                5'd20: read_reg = r20;
                5'd21: read_reg = r21;
                default: read_reg = 32'd0;
            endcase
        end
    endfunction

    // ------------------------------------------------------------------
    // Control FSM + register-file write port
    // ------------------------------------------------------------------
    always @(posedge CLK or negedge RSTN) begin
        if (!RSTN) begin
            state      <= ST_FETCH_ADDR;

            pc         <= 32'd0;
            pc_old     <= 32'd0;
            ir         <= 32'd0;
            reg_a      <= 32'd0;
            reg_b      <= 32'd0;
            alu_addr   <= 12'd0;

            imem_addr  <= 10'd0;
            dmem_addr  <= 10'd0;
            dmem_we    <= 1'b0;
            dmem_wdata <= 32'd0;

            r1  <= 32'd0;  r2  <= 32'd0;  r3  <= 32'd0;  r4  <= 32'd0;
            r5  <= 32'd0;  r6  <= 32'd0;  r7  <= 32'd0;  r8  <= 32'd0;
            r9  <= 32'd0;  r10 <= 32'd0;  r11 <= 32'd0;  r12 <= 32'd0;
            r13 <= 32'd0;  r20 <= 32'd0;  r21 <= 32'd0;
        end
        else begin
            // Write enable is low unless ST_MEM_WR_DO raises it below,
            // so a store produces a single-cycle dmem_we pulse.
            dmem_we <= 1'b0;

            case (state)
                // ---- fetch: address -> wait -> capture ----------------
                ST_FETCH_ADDR: begin
                    pc_old    <= pc;
                    imem_addr <= pc[11:2];
                    state     <= ST_FETCH_WAIT1;
                end

                ST_FETCH_WAIT1:
                    state <= ST_FETCH_CAPTURE;

                ST_FETCH_CAPTURE: begin
                    ir    <= imem_rdata;
                    pc    <= pc + 32'd4;
                    state <= ST_DECODE;
                end

                // ---- decode: latch operands, dispatch -----------------
                ST_DECODE: begin
                    reg_a <= read_reg(rs1);
                    reg_b <= read_reg(rs2);
                    state <= decode_next_state;
                end

                // ---- lw / sw effective address ------------------------
                ST_ADDR: begin
                    alu_addr <= eff_addr_next;
                    state    <= is_load_op ? ST_MEM_RD : ST_MEM_WR_SETUP;
                end

                // ---- lw: address -> wait -> capture (write-back) -------
                ST_MEM_RD: begin
                    dmem_addr <= alu_addr[11:2];
                    state     <= ST_MEM_WAIT1;
                end

                ST_MEM_WAIT1:
                    state <= ST_MEM_CAPTURE;

                // ---- sw: present addr/data, then pulse dmem_we --------
                ST_MEM_WR_SETUP: begin
                    dmem_addr  <= alu_addr[11:2];
                    dmem_wdata <= reg_b;
                    state      <= ST_MEM_WR_DO;
                end

                ST_MEM_WR_DO: begin
                    dmem_we <= 1'b1;
                    state   <= ST_FETCH_ADDR;
                end

                // ---- beq / blt ----------------------------------------
                ST_BRANCH: begin
                    if (branch_taken)
                        pc <= branch_target;
                    state <= ST_FETCH_ADDR;
                end

                // ---- termination: park forever ------------------------
                ST_DONE, ST_HALT:
                    state <= ST_HALT;

                // Result-producing states commit through the write-back
                // port below; ST_WB_LW is an unused pass-through.
                ST_EXE_R, ST_EXE_I, ST_MEM_CAPTURE, ST_WB_LW:
                    state <= ST_FETCH_ADDR;

                default:
                    state <= ST_FETCH_ADDR;
            endcase

            // Single register-file write port (x1..x13, x20, x21 only).
            if (wb_valid) begin
                case (rd)
                    5'd1:  r1  <= wb_value;
                    5'd2:  r2  <= wb_value;
                    5'd3:  r3  <= wb_value;
                    5'd4:  r4  <= wb_value;
                    5'd5:  r5  <= wb_value;
                    5'd6:  r6  <= wb_value;
                    5'd7:  r7  <= wb_value;
                    5'd8:  r8  <= wb_value;
                    5'd9:  r9  <= wb_value;
                    5'd10: r10 <= wb_value;
                    5'd11: r11 <= wb_value;
                    5'd12: r12 <= wb_value;
                    5'd13: r13 <= wb_value;
                    5'd20: r20 <= wb_value;
                    5'd21: r21 <= wb_value;
                    default: ;
                endcase
            end
        end
    end

endmodule





module CNN (
    input  wire        clk,
    input  wire        rstn,

    input  wire [31:0] doutb,
    output reg         web,
    output reg         enb,
    output reg  [31:0] dinb,
    output reg  [9:0]  addr,
    output reg         done
);

    // ============================================================
    // Fixed memory map
    // ============================================================
    localparam BASE_W0        = 10'd0;
    localparam BASE_W1        = 10'd3;
    localparam ADDR_SIZE      = 10'd12;
    localparam ADDR_STATUS    = 10'd13;
    localparam ADDR_BIAS0     = 10'd14;
    localparam ADDR_BIAS1     = 10'd15;
    localparam BASE_IN        = 10'd16;

    localparam BASE_MID       = 10'd272;
    localparam BASE_OUT       = 10'd512;
    localparam ADDR_CNN_START = 10'd900;

    // ============================================================
    // Helper functions
    // ============================================================
    function [7:0] get_byte;
        input [31:0] word;
        input [1:0] sel;
        begin
            case (sel)
                2'd0: get_byte = word[31:24];
                2'd1: get_byte = word[23:16];
                2'd2: get_byte = word[15:8];
                2'd3: get_byte = word[7:0];
                default: get_byte = 8'd0;
            endcase
        end
    endfunction

    function [31:0] put_byte;
        input [31:0] old_word;
        input [1:0] sel;
        input [7:0] new_byte;
        begin
            put_byte = old_word;
            case (sel)
                2'd0: put_byte[31:24] = new_byte;
                2'd1: put_byte[23:16] = new_byte;
                2'd2: put_byte[15:8]  = new_byte;
                2'd3: put_byte[7:0]   = new_byte;
                default: put_byte = old_word;
            endcase
        end
    endfunction

    // ============================================================
    // Q12 -> Q6 round-to-nearest ties-to-even + saturation
    // Timing optimized signed-floor version
    // ============================================================
    function signed [7:0] round_sat_q12_to_q6;
        input signed [19:0] in_q12;

        reg signed [19:0] floor_q6;
        reg [5:0]         frac;
        reg               round_up;
        reg signed [19:0] rounded_q6;

        begin
            floor_q6 = in_q12 >>> 6;
            frac     = in_q12[5:0];

            if (frac > 6'd32)
                round_up = 1'b1;
            else if (frac == 6'd32)
                round_up = floor_q6[0];
            else
                round_up = 1'b0;

            rounded_q6 = floor_q6 + {{19{1'b0}}, round_up};

            if (rounded_q6 > 20'sd127)
                round_sat_q12_to_q6 = 8'sh7F;
            else if (rounded_q6 < -20'sd128)
                round_sat_q12_to_q6 = 8'sh80;
            else
                round_sat_q12_to_q6 = rounded_q6[7:0];
        end
    endfunction

    // ============================================================
    // FSM states
    // ============================================================
    reg [5:0] state;
    reg [5:0] after_wait_state;

    localparam ST_IDLE          = 6'd0;
    localparam ST_WAIT1         = 6'd1;

    localparam ST_CAP_SIZE      = 6'd4;
    localparam ST_CAP_BIAS0     = 6'd5;
    localparam ST_CAP_BIAS1     = 6'd6;

    localparam ST_CAP_W0_0      = 6'd7;
    localparam ST_CAP_W0_1      = 6'd8;
    localparam ST_CAP_W0_2      = 6'd9;

    localparam ST_CAP_W1_0      = 6'd10;
    localparam ST_CAP_W1_1      = 6'd11;
    localparam ST_CAP_W1_2      = 6'd12;

    localparam ST_BEGIN_LAYER   = 6'd13;

    localparam ST_PRELOAD_INIT  = 6'd14;
    localparam ST_PRELOAD_REQ   = 6'd15;
    localparam ST_PRELOAD_WAIT  = 6'd16;
    localparam ST_PRELOAD_CAP   = 6'd17;

    localparam ST_MAC_ROW0      = 6'd18;
    localparam ST_MAC_ROW1      = 6'd19;
    localparam ST_MAC_ROW2      = 6'd20;
    localparam ST_NEXT_CAP1     = 6'd21;
    localparam ST_NEXT_CAP2     = 6'd22; // kept as fallback only

    localparam ST_ADD_BIAS      = 6'd23;
    localparam ST_ROUND         = 6'd24;
    localparam ST_PREP_STORE    = 6'd25;
    localparam ST_STORE_PACK    = 6'd26;
    localparam ST_WRITE_DO      = 6'd27;
    localparam ST_ADV_PIXEL     = 6'd28;

    localparam ST_WRITE_STATUS_SETUP = 6'd29;
    localparam ST_WRITE_STATUS_DO    = 6'd30;
    localparam ST_DONE               = 6'd31;
    localparam ST_CHECK_START        = 6'd32;

    // ============================================================
    // Basic registers
    // ============================================================
    reg [5:0] fmap_size;
    reg [5:0] mid_size;
    reg [5:0] out_size;

    reg [5:0] run_in_size;
    reg [5:0] run_out_size;
    reg [9:0] run_in_base;
    reg [9:0] run_out_base;

    reg signed [7:0] bias0;
    reg signed [7:0] bias1;

    reg signed [7:0] w0 [0:8];
    reg signed [7:0] w1 [0:8];

    reg layer_id;

    reg [5:0] ox;
    reg [5:0] oy;

    // 20-bit accumulator version
    reg signed [19:0] acc_q12;
    reg signed [7:0]  conv_out_q6;

    reg [11:0] out_linear_idx;
    reg [31:0] pack_word;

    // ============================================================
    // Timing optimization for 115 MHz address generation
    // ============================================================
    reg [11:0] row_base_idx_q;
    reg [11:0] run_in_size_ext_q;
    reg [11:0] run_in_size_x2_q;

    reg [11:0] preload_row1_idx_q;
    reg [11:0] preload_row2_idx_q;

    // ============================================================
    // 3x3 sliding window registers
    // ============================================================
    reg signed [7:0] win00, win01, win02;
    reg signed [7:0] win10, win11, win12;
    reg signed [7:0] win20, win21, win22;

    reg signed [7:0] next_col0;
    reg signed [7:0] next_col1;
    reg signed [7:0] next_col2;

    reg [3:0] preload_cnt;
    reg [1:0] preload_row_q;
    reg [1:0] preload_col_q;
    reg [1:0] preload_byte_sel_q;

    reg [11:0] preload_base_idx_q;
    reg [11:0] preload_current_idx_q;
    reg [9:0]  preload_addr_q;
    reg [1:0]  preload_bsel_q;

    reg [11:0] pixel_base_idx_q;

    reg [1:0] next_byte_sel0;
    reg [1:0] next_byte_sel1;
    reg [1:0] next_byte_sel2;

    reg [1:0] mac_row_idx;

    reg signed [15:0] mul0_q;
    reg signed [15:0] mul1_q;
    reg signed [15:0] mul2_q;

    integer i;

    // ============================================================
    // Preload row/col decode
    // ============================================================
    function [1:0] preload_row_func;
        input [3:0] cnt;
        begin
            case (cnt)
                4'd0, 4'd1, 4'd2: preload_row_func = 2'd0;
                4'd3, 4'd4, 4'd5: preload_row_func = 2'd1;
                default:          preload_row_func = 2'd2;
            endcase
        end
    endfunction

    function [1:0] preload_col_func;
        input [3:0] cnt;
        begin
            case (cnt)
                4'd0, 4'd3, 4'd6: preload_col_func = 2'd0;
                4'd1, 4'd4, 4'd7: preload_col_func = 2'd1;
                default:          preload_col_func = 2'd2;
            endcase
        end
    endfunction

    wire [1:0] preload_row_now = preload_row_func(preload_cnt);
    wire [1:0] preload_col_now = preload_col_func(preload_cnt);

    // ============================================================
    // Timing-optimized preload address calculation
    // No oy * run_in_size and no direct run_in_size -> preload_addr_q path
    // ============================================================
    wire [11:0] preload_base_idx_now =
        row_base_idx_q;

    wire [9:0] preload_base_addr_now =
        run_in_base + row_base_idx_q[11:2];

    wire [1:0] preload_base_bsel_now =
        row_base_idx_q[1:0];

    wire [11:0] preload_idx_plus1 =
        preload_current_idx_q + 12'd1;

    wire [11:0] preload_idx_row1 =
        preload_row1_idx_q;

    wire [11:0] preload_idx_row2 =
        preload_row2_idx_q;

    wire [11:0] preload_next_idx_calc =
        (preload_cnt == 4'd2) ? preload_idx_row1 :
        (preload_cnt == 4'd5) ? preload_idx_row2 :
                                preload_idx_plus1;

    wire [9:0] preload_next_addr_calc =
        run_in_base + preload_next_idx_calc[11:2];

    wire [1:0] preload_next_bsel_calc =
        preload_next_idx_calc[1:0];

    // ============================================================
    // Next column address calculation
    // Use registered run_in_size_ext_q, not run_in_size directly
    // ============================================================
    wire has_next_x =
        (ox < (run_out_size - 6'd1));

    wire [11:0] next_linear0 =
        pixel_base_idx_q + 12'd3;

    wire [11:0] next_linear1 =
        pixel_base_idx_q + run_in_size_ext_q + 12'd3;

    wire [11:0] next_linear2 =
        pixel_base_idx_q + run_in_size_x2_q + 12'd3;

    wire [9:0] next_addr0 =
        run_in_base + next_linear0[11:2];

    wire [9:0] next_addr1 =
        run_in_base + next_linear1[11:2];

    wire [9:0] next_addr2 =
        run_in_base + next_linear2[11:2];

    wire [1:0] next_bsel0 = next_linear0[1:0];
    wire [1:0] next_bsel1 = next_linear1[1:0];
    wire [1:0] next_bsel2 = next_linear2[1:0];

    // ============================================================
    // 3 DSP row MAC
    // ============================================================
    wire signed [7:0] mac_pix0 =
        (mac_row_idx == 2'd0) ? win00 :
        (mac_row_idx == 2'd1) ? win10 :
                                win20;

    wire signed [7:0] mac_pix1 =
        (mac_row_idx == 2'd0) ? win01 :
        (mac_row_idx == 2'd1) ? win11 :
                                win21;

    wire signed [7:0] mac_pix2 =
        (mac_row_idx == 2'd0) ? win02 :
        (mac_row_idx == 2'd1) ? win12 :
                                win22;

    wire signed [7:0] mac_w0 =
        (layer_id == 1'b0) ?
            ((mac_row_idx == 2'd0) ? w0[0] :
             (mac_row_idx == 2'd1) ? w0[3] : w0[6]) :
            ((mac_row_idx == 2'd0) ? w1[0] :
             (mac_row_idx == 2'd1) ? w1[3] : w1[6]);

    wire signed [7:0] mac_w1 =
        (layer_id == 1'b0) ?
            ((mac_row_idx == 2'd0) ? w0[1] :
             (mac_row_idx == 2'd1) ? w0[4] : w0[7]) :
            ((mac_row_idx == 2'd0) ? w1[1] :
             (mac_row_idx == 2'd1) ? w1[4] : w1[7]);

    wire signed [7:0] mac_w2 =
        (layer_id == 1'b0) ?
            ((mac_row_idx == 2'd0) ? w0[2] :
             (mac_row_idx == 2'd1) ? w0[5] : w0[8]) :
            ((mac_row_idx == 2'd0) ? w1[2] :
             (mac_row_idx == 2'd1) ? w1[5] : w1[8]);

    (* use_dsp = "yes" *) wire signed [15:0] mul0 =
        mac_pix0 * mac_w0;

    (* use_dsp = "yes" *) wire signed [15:0] mul1 =
        mac_pix1 * mac_w1;

    (* use_dsp = "yes" *) wire signed [15:0] mul2 =
        mac_pix2 * mac_w2;

    wire signed [19:0] mul0_ext = {{4{mul0_q[15]}}, mul0_q};
    wire signed [19:0] mul1_ext = {{4{mul1_q[15]}}, mul1_q};
    wire signed [19:0] mul2_ext = {{4{mul2_q[15]}}, mul2_q};

    wire signed [19:0] row_sum_q12 =
        mul0_ext + mul1_ext + mul2_ext;

    wire signed [19:0] bias0_q12 =
        {{6{bias0[7]}}, bias0, 6'b000000};

    wire signed [19:0] bias1_q12 =
        {{6{bias1[7]}}, bias1, 6'b000000};

    // ============================================================
    // Output packing
    // ============================================================
    wire [1:0] out_byte_sel =
        out_linear_idx[1:0];

    wire [9:0] out_word_addr =
        run_out_base + out_linear_idx[11:2];

    wire is_last_pixel =
        (ox == (run_out_size - 6'd1)) &&
        (oy == (run_out_size - 6'd1));

    wire should_write_word =
        (out_byte_sel == 2'd3) || is_last_pixel;

    wire [31:0] pack_word_next =
        put_byte(pack_word, out_byte_sel, conv_out_q6);

    // ============================================================
    // Main FSM
    // ============================================================
    always @(posedge clk or negedge rstn) begin
        if (!rstn) begin
            state <= ST_IDLE;
            after_wait_state <= ST_IDLE;

            web  <= 1'b0;
            enb  <= 1'b1;
            dinb <= 32'd0;
            addr <= 10'd0;
            done <= 1'b0;

            fmap_size <= 6'd0;
            mid_size  <= 6'd0;
            out_size  <= 6'd0;

            run_in_size  <= 6'd0;
            run_out_size <= 6'd0;
            run_in_base  <= 10'd0;
            run_out_base <= 10'd0;

            bias0 <= 8'sd0;
            bias1 <= 8'sd0;

            for (i = 0; i < 9; i = i + 1) begin
                w0[i] <= 8'sd0;
                w1[i] <= 8'sd0;
            end

            layer_id <= 1'b0;
            ox <= 6'd0;
            oy <= 6'd0;

            acc_q12 <= 20'sd0;
            conv_out_q6 <= 8'sd0;

            out_linear_idx <= 12'd0;
            pack_word <= 32'd0;

            row_base_idx_q    <= 12'd0;
            run_in_size_ext_q <= 12'd0;
            run_in_size_x2_q  <= 12'd0;

            preload_row1_idx_q <= 12'd0;
            preload_row2_idx_q <= 12'd0;

            win00 <= 8'sd0; win01 <= 8'sd0; win02 <= 8'sd0;
            win10 <= 8'sd0; win11 <= 8'sd0; win12 <= 8'sd0;
            win20 <= 8'sd0; win21 <= 8'sd0; win22 <= 8'sd0;

            next_col0 <= 8'sd0;
            next_col1 <= 8'sd0;
            next_col2 <= 8'sd0;

            preload_cnt <= 4'd0;
            preload_row_q <= 2'd0;
            preload_col_q <= 2'd0;
            preload_byte_sel_q <= 2'd0;

            preload_base_idx_q <= 12'd0;
            preload_current_idx_q <= 12'd0;
            preload_addr_q <= 10'd0;
            preload_bsel_q <= 2'd0;

            pixel_base_idx_q <= 12'd0;

            next_byte_sel0 <= 2'd0;
            next_byte_sel1 <= 2'd0;
            next_byte_sel2 <= 2'd0;

            mac_row_idx <= 2'd0;

            mul0_q <= 16'sd0;
            mul1_q <= 16'sd0;
            mul2_q <= 16'sd0;
        end
        else begin
            web <= 1'b0;
            enb <= 1'b1;

            case (state)

                // --------------------------------------------------------
                // Start-up: poll the start flag, then read the run
                // configuration (feature-map size, two biases, two 3x3
                // kernels).
                //
                // Every read in this section uses the same pattern: the
                // issuing state sets `addr` and `after_wait_state`,
                // ST_WAIT1 spends one cycle, and the target state consumes
                // `doutb`. This presumably matches one cycle of read
                // latency on the memory port — TODO confirm against the
                // BRAM configuration.
                // --------------------------------------------------------

                // Clear `done` and request the start/control word.
                ST_IDLE: begin
                    done <= 1'b0;
                    addr <= ADDR_CNN_START;
                    after_wait_state <= ST_CHECK_START;
                    state <= ST_WAIT1;
                end

                // Bit 0 of the word at ADDR_CNN_START is the start flag.
                // Set: go read the feature-map size.
                // Clear: re-read the same word (busy-poll until it is set).
                ST_CHECK_START: begin
                    if (doutb[0] == 1'b1) begin
                        addr <= ADDR_SIZE;
                        after_wait_state <= ST_CAP_SIZE;
                        state <= ST_WAIT1;
                    end
                    else begin
                        addr <= ADDR_CNN_START;
                        after_wait_state <= ST_CHECK_START;
                        state <= ST_WAIT1;
                    end
                end

                // Shared one-cycle read wait; resumes at the state the
                // requester stored in `after_wait_state`.
                ST_WAIT1: begin
                    state <= after_wait_state;
                end

                // Input size is bits [6:1] of the word. Each 3x3 layer
                // shrinks the map by 2, so layer 0 outputs size-2 and
                // layer 1 outputs size-4.
                // `doutb` is used directly for mid_size/out_size because
                // fmap_size is assigned non-blocking in this same cycle.
                // NOTE(review): nothing here rejects sizes < 5, for which
                // out_size becomes 0 or wraps around (unsigned subtract) —
                // confirm the host never writes such a size.
                ST_CAP_SIZE: begin
                    fmap_size <= doutb[6:1];
                    mid_size  <= doutb[6:1] - 6'd2;
                    out_size  <= doutb[6:1] - 6'd4;

                    addr <= ADDR_BIAS0;
                    after_wait_state <= ST_CAP_BIAS0;
                    state <= ST_WAIT1;
                end

                // Layer-0 bias: low byte of the word at ADDR_BIAS0.
                ST_CAP_BIAS0: begin
                    bias0 <= doutb[7:0];

                    addr <= ADDR_BIAS1;
                    after_wait_state <= ST_CAP_BIAS1;
                    state <= ST_WAIT1;
                end

                // Layer-1 bias: low byte of the word at ADDR_BIAS1.
                ST_CAP_BIAS1: begin
                    bias1 <= doutb[7:0];

                    addr <= BASE_W0;
                    after_wait_state <= ST_CAP_W0_0;
                    state <= ST_WAIT1;
                end

                // Layer-0 kernel: 9 bytes packed into three consecutive
                // words starting at BASE_W0. Words 0 and 1 supply four
                // weights each (byte lanes 0..3); word 2 supplies only
                // w0[8] from lane 0. get_byte() is defined elsewhere —
                // presumably it extracts byte lane k of the word.
                ST_CAP_W0_0: begin
                    w0[0] <= get_byte(doutb, 2'd0);
                    w0[1] <= get_byte(doutb, 2'd1);
                    w0[2] <= get_byte(doutb, 2'd2);
                    w0[3] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W0 + 10'd1;
                    after_wait_state <= ST_CAP_W0_1;
                    state <= ST_WAIT1;
                end

                ST_CAP_W0_1: begin
                    w0[4] <= get_byte(doutb, 2'd0);
                    w0[5] <= get_byte(doutb, 2'd1);
                    w0[6] <= get_byte(doutb, 2'd2);
                    w0[7] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W0 + 10'd2;
                    after_wait_state <= ST_CAP_W0_2;
                    state <= ST_WAIT1;
                end

                ST_CAP_W0_2: begin
                    w0[8] <= get_byte(doutb, 2'd0);

                    addr <= BASE_W1;
                    after_wait_state <= ST_CAP_W1_0;
                    state <= ST_WAIT1;
                end

                // Layer-1 kernel: same 3-word / 9-byte layout at BASE_W1.
                ST_CAP_W1_0: begin
                    w1[0] <= get_byte(doutb, 2'd0);
                    w1[1] <= get_byte(doutb, 2'd1);
                    w1[2] <= get_byte(doutb, 2'd2);
                    w1[3] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W1 + 10'd1;
                    after_wait_state <= ST_CAP_W1_1;
                    state <= ST_WAIT1;
                end

                ST_CAP_W1_1: begin
                    w1[4] <= get_byte(doutb, 2'd0);
                    w1[5] <= get_byte(doutb, 2'd1);
                    w1[6] <= get_byte(doutb, 2'd2);
                    w1[7] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W1 + 10'd2;
                    after_wait_state <= ST_CAP_W1_2;
                    state <= ST_WAIT1;
                end

                // Last weight captured. No further read is issued, so go
                // straight to layer-0 setup (no ST_WAIT1 needed).
                ST_CAP_W1_2: begin
                    w1[8] <= get_byte(doutb, 2'd0);
                    layer_id <= 1'b0;
                    state <= ST_BEGIN_LAYER;
                end

                // Per-layer setup. Resets the output coordinates, the
                // output linear index, the pack buffer and the input row
                // base, then selects the active layer's geometry:
                //   layer 0: fmap_size -> mid_size, BASE_IN  -> BASE_MID
                //   layer 1: mid_size  -> out_size, BASE_MID -> BASE_OUT
                //            (layer 1 reads what layer 0 wrote)
                // The zero-extended input width and twice that width are
                // cached as 12-bit index offsets for window rows 1 and 2.
                ST_BEGIN_LAYER: begin
                    ox <= 6'd0;
                    oy <= 6'd0;
                    pack_word <= 32'd0;
                    out_linear_idx <= 12'd0;

                    row_base_idx_q <= 12'd0;

                    if (layer_id == 1'b0) begin
                        run_in_size  <= fmap_size;
                        run_out_size <= mid_size;
                        run_in_base  <= BASE_IN;
                        run_out_base <= BASE_MID;

                        run_in_size_ext_q <= {6'd0, fmap_size};
                        run_in_size_x2_q  <= ({6'd0, fmap_size} << 1);
                    end
                    else begin
                        run_in_size  <= mid_size;
                        run_out_size <= out_size;
                        run_in_base  <= BASE_MID;
                        run_out_base <= BASE_OUT;

                        run_in_size_ext_q <= {6'd0, mid_size};
                        run_in_size_x2_q  <= ({6'd0, mid_size} << 1);
                    end

                    state <= ST_PRELOAD_INIT;
                end

                // Start of an output row: the 3x3 window is reloaded from
                // scratch with nine single-byte reads.
                // row_base_idx_q is the linear input index of (row oy,
                // col 0); window rows 1 and 2 start one and two input rows
                // later. pixel_base_idx_q then tracks the window's
                // top-left index as the window slides (+1 per ox step in
                // ST_ADV_PIXEL).
                // preload_base_addr_now / preload_base_bsel_now are
                // computed elsewhere — presumably the word address and
                // byte lane of row_base_idx_q; TODO confirm.
                ST_PRELOAD_INIT: begin
                    preload_cnt <= 4'd0;

                    preload_base_idx_q    <= row_base_idx_q;
                    preload_current_idx_q <= row_base_idx_q;
                    pixel_base_idx_q      <= row_base_idx_q;

                    preload_row1_idx_q <= row_base_idx_q + run_in_size_ext_q;
                    preload_row2_idx_q <= row_base_idx_q + run_in_size_x2_q;

                    preload_addr_q <= preload_base_addr_now;
                    preload_bsel_q <= preload_base_bsel_now;

                    state <= ST_PRELOAD_REQ;
                end

                // Issue the read for the current window cell. Latch which
                // cell (preload_row_now / preload_col_now, computed
                // elsewhere) and byte lane it fills, so ST_PRELOAD_CAP can
                // route the byte once it returns.
                ST_PRELOAD_REQ: begin
                    addr <= preload_addr_q;

                    preload_row_q <= preload_row_now;
                    preload_col_q <= preload_col_now;
                    preload_byte_sel_q <= preload_bsel_q;

                    state <= ST_PRELOAD_WAIT;
                end

                // One-cycle read wait (same role as ST_WAIT1).
                ST_PRELOAD_WAIT: begin
                    state <= ST_PRELOAD_CAP;
                end

                // Store the returned byte into window cell {row, col}.
                // preload_cnt counts completed reads, 0..8: the capture
                // that sees preload_cnt == 8 is the ninth, and the MAC
                // then starts. Otherwise the counter advances, and the
                // index/address/lane for the next cell (from the
                // *_next_*_calc signals, computed elsewhere) are loaded
                // before another read is issued.
                // acc_q12 is cleared here, but ST_MAC_ROW1 overwrites it
                // anyway, so the clear is only defensive.
                ST_PRELOAD_CAP: begin
                    case ({preload_row_q, preload_col_q})
                        4'b00_00: win00 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b00_01: win01 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b00_10: win02 <= $signed(get_byte(doutb, preload_byte_sel_q));

                        4'b01_00: win10 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b01_01: win11 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b01_10: win12 <= $signed(get_byte(doutb, preload_byte_sel_q));

                        4'b10_00: win20 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b10_01: win21 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b10_10: win22 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        default: begin end
                    endcase

                    if (preload_cnt == 4'd8) begin
                        acc_q12 <= 20'sd0;
                        mac_row_idx <= 2'd0;
                        state <= ST_MAC_ROW0;
                    end
                    else begin
                        preload_cnt <= preload_cnt + 4'd1;

                        preload_current_idx_q <= preload_next_idx_calc;
                        preload_addr_q <= preload_next_addr_calc;
                        preload_bsel_q <= preload_next_bsel_calc;

                        state <= ST_PRELOAD_REQ;
                    end
                end

                // --------------------------------------------------------
                // Pipelined 3x3 MAC, one kernel row per cycle.
                // Each ST_MAC_ROWk registers three products into mul*_q.
                // mul0..mul2 are computed elsewhere — presumably window
                // row `mac_row_idx` times the matching kernel row; TODO
                // confirm. The next state folds those registered products
                // into acc_q12 via row_sum_q12, presumably
                // mul0_q + mul1_q + mul2_q.
                //
                //   ROW0      : register row 0
                //   ROW1      : acc  = row 0 sum; register row 1
                //   ROW2      : acc += row 1 sum; register row 2
                //   NEXT_CAP1 : acc += row 2 sum
                //
                // Overlapped with the MAC: when has_next_x is set, the
                // three pixels of the next input column are fetched for
                // the window slide in ST_ADV_PIXEL. Each is captured two
                // states after its address is issued:
                //   next_addr0: issued ROW0, captured ROW2      -> next_col0
                //   next_addr1: issued ROW1, captured NEXT_CAP1 -> next_col1
                //   next_addr2: issued ROW2, captured ADD_BIAS  -> next_col2
                // --------------------------------------------------------
                ST_MAC_ROW0: begin
                    mul0_q <= mul0;
                    mul1_q <= mul1;
                    mul2_q <= mul2;

                    if (has_next_x) begin
                        addr <= next_addr0;
                        next_byte_sel0 <= next_bsel0;
                    end

                    mac_row_idx <= 2'd1;
                    state <= ST_MAC_ROW1;
                end

                // First row sum overwrites the accumulator; no separate
                // clear is needed.
                ST_MAC_ROW1: begin
                    acc_q12 <= row_sum_q12;

                    mul0_q <= mul0;
                    mul1_q <= mul1;
                    mul2_q <= mul2;

                    if (has_next_x) begin
                        addr <= next_addr1;
                        next_byte_sel1 <= next_bsel1;
                    end

                    mac_row_idx <= 2'd2;
                    state <= ST_MAC_ROW2;
                end

                // mac_row_idx is left at 2 here; the products registered
                // in this state are for row 2.
                ST_MAC_ROW2: begin
                    acc_q12 <= acc_q12 + row_sum_q12;

                    mul0_q <= mul0;
                    mul1_q <= mul1;
                    mul2_q <= mul2;

                    if (has_next_x) begin
                        next_col0 <= $signed(get_byte(doutb, next_byte_sel0));

                        addr <= next_addr2;
                        next_byte_sel2 <= next_bsel2;
                    end

                    state <= ST_NEXT_CAP1;
                end

                // Fold in row 2 and capture the row-1 pixel of the next
                // column. This state goes directly to ST_ADD_BIAS.
                ST_NEXT_CAP1: begin
                    acc_q12 <= acc_q12 + row_sum_q12;

                    if (has_next_x) begin
                        next_col1 <= $signed(get_byte(doutb, next_byte_sel1));
                    end

                    state <= ST_ADD_BIAS;
                end

                // NOTE(review): no transition in the visible code targets
                // this state; it only forwards to ST_ADD_BIAS.
                ST_NEXT_CAP2: begin
                    state <= ST_ADD_BIAS;
                end

                // Capture the row-2 pixel of the next column, and add the
                // active layer's bias. bias0_q12 / bias1_q12 are computed
                // elsewhere — presumably the 8-bit bias aligned to Q12.
                ST_ADD_BIAS: begin
                    if (has_next_x) begin
                        next_col2 <= $signed(get_byte(doutb, next_byte_sel2));
                    end

                    if (layer_id == 1'b0)
                        acc_q12 <= acc_q12 + bias0_q12;
                    else
                        acc_q12 <= acc_q12 + bias1_q12;

                    state <= ST_ROUND;
                end

                // Reduce the Q12 accumulator to the 8-bit signed Q6 output
                // via round_sat_q12_to_q6 (defined elsewhere; by its name,
                // round then saturate).
                ST_ROUND: begin
                    conv_out_q6 <= round_sat_q12_to_q6(acc_q12);
                    state <= ST_STORE_PACK;
                end

                // NOTE(review): no transition in the visible code targets
                // this state; it only forwards to ST_STORE_PACK.
                ST_PREP_STORE: begin
                    state <= ST_STORE_PACK;
                end

                // Update the 32-bit pack buffer with pack_word_next, which
                // is computed elsewhere — presumably pack_word with
                // conv_out_q6 inserted at the byte lane for out_linear_idx.
                // When should_write_word (also computed elsewhere) is set,
                // set up the write here: address and data now, the
                // write-enable pulse in ST_WRITE_DO.
                // The `web <= 0` matches the cycle default; it is kept low
                // while addr/dinb settle.
                ST_STORE_PACK: begin
                    pack_word <= pack_word_next;

                    if (should_write_word) begin
                        addr <= out_word_addr;
                        dinb <= pack_word_next;
                        web  <= 1'b0;
                        state <= ST_WRITE_DO;
                    end
                    else begin
                        state <= ST_ADV_PIXEL;
                    end
                end

                // Raise web for exactly one cycle; the default assignment
                // at the top of this branch drops it again next cycle.
                // During that cycle the FSM is in ST_ADV_PIXEL, which does
                // not touch addr/dinb, so address and data stay stable
                // while the memory presumably samples the write.
                // Clear the buffer for the next word.
                ST_WRITE_DO: begin
                    web <= 1'b1;
                    pack_word <= 32'd0;
                    state <= ST_ADV_PIXEL;
                end

                // Advance to the next output pixel.
                //  - Last pixel (is_last_pixel, computed elsewhere):
                //    after layer 0, start layer 1, which reads the map
                //    just written at BASE_MID. After layer 1, write the
                //    status word.
                //  - End of an output row (ox == run_out_size - 1):
                //    wrap ox, step oy, move the input row base down one
                //    row, and reload the whole window.
                //  - Otherwise: slide the window one column right. Each
                //    row shifts left and takes the next_col* byte
                //    prefetched during the MAC as its new right-hand
                //    column, then the MAC reruns without a preload.
                // out_linear_idx is not advanced on the last pixel;
                // ST_BEGIN_LAYER resets it for the next layer.
                ST_ADV_PIXEL: begin
                    if (!is_last_pixel) begin
                        out_linear_idx <= out_linear_idx + 12'd1;
                    end

                    if (is_last_pixel) begin
                        if (layer_id == 1'b0) begin
                            layer_id <= 1'b1;
                            state <= ST_BEGIN_LAYER;
                        end
                        else begin
                            state <= ST_WRITE_STATUS_SETUP;
                        end
                    end
                    else begin
                        if (ox == (run_out_size - 6'd1)) begin
                            ox <= 6'd0;
                            oy <= oy + 6'd1;

                            row_base_idx_q <= row_base_idx_q + run_in_size_ext_q;

                            state <= ST_PRELOAD_INIT;
                        end
                        else begin
                            // Non-blocking assignments: each shift reads
                            // the pre-update window values, so this is a
                            // true parallel shift.
                            win00 <= win01;
                            win01 <= win02;
                            win02 <= next_col0;

                            win10 <= win11;
                            win11 <= win12;
                            win12 <= next_col1;

                            win20 <= win21;
                            win21 <= win22;
                            win22 <= next_col2;

                            ox <= ox + 6'd1;
                            pixel_base_idx_q <= pixel_base_idx_q + 12'd1;

                            acc_q12 <= 20'sd0;
                            mac_row_idx <= 2'd0;
                            state <= ST_MAC_ROW0;
                        end
                    end
                end

                // Completion status word at ADDR_STATUS: bit 0 = 1, with
                // BASE_OUT placed in the bits above it (presumably so the
                // reader can locate the final output; the 21 zero bits
                // pad to 32 if BASE_OUT is 10 bits wide — confirm).
                // Same two-step write protocol as ST_STORE_PACK /
                // ST_WRITE_DO.
                ST_WRITE_STATUS_SETUP: begin
                    addr <= ADDR_STATUS;
                    dinb <= {21'd0, BASE_OUT, 1'b1};
                    web  <= 1'b0;
                    state <= ST_WRITE_STATUS_DO;
                end

                // One-cycle write-enable pulse; ST_DONE drives it low again.
                ST_WRITE_STATUS_DO: begin
                    web <= 1'b1;
                    state <= ST_DONE;
                end

                // Terminal state: assert `done` and self-loop. No visible
                // transition leaves it, so a new run needs a reset.
                ST_DONE: begin
                    web  <= 1'b0;
                    done <= 1'b1;
                    state <= ST_DONE;
                end

                // Any unlisted state encoding recovers to ST_IDLE.
                default: begin
                    state <= ST_IDLE;
                end

            endcase
        end
    end

endmodule

Editor is loading...
Leave a Comment