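// NOTE(review): the following lines are residue from a paste/share page,
// not Verilog source; they are kept as comments so the file compiles.
// Untitled
//
// avatar
// unknown
// plain_text
// 20 days ago
// 40 kB
// 3
// Indexable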
// ============================================================
// Simple_CPU
//
// Multi-cycle, FSM-sequenced CPU for a small subset of the
// RISC-V RV32I encodings: add, sub, addi, lw, sw, beq, blt.
// Each instruction walks through fetch -> decode -> execute
// -> (memory) -> write-back states, one state per clock.
//
// Program termination: the first fetched instruction that is
// not one of the seven above (any other opcode / funct3 /
// funct7 combination) sends the FSM to ST_DONE. The CPU then
//   1. reads Data Memory[16], [17] and [18],
//   2. computes two bias words and writes them to
//      Data Memory[14] and [15],
//   3. writes 1 to Data Memory[900] (the word the CNN module
//      polls as its start flag, ADDR_CNN_START), and
//   4. stays in ST_HALT forever.
//
// Memory timing: addresses are registered in this module and
// read data is captured two states after the address state
// (ADDR -> WAIT1 -> CAPTURE). NOTE(review): this presumes
// memories with a one-clock read latency (no extra output
// register) -- confirm against the memory IP configuration.
// ============================================================
module Simple_CPU (
    input  wire        CLK,
    input  wire        RSTN,        // active-low asynchronous reset

    // Data-memory port. Word-addressed: instruction accesses use
    // byte address [11:2]; the bias/flag accesses use word indices.
    output wire        dmem_en,     // tied high below
    output reg         dmem_we,     // one-cycle write pulse
    output reg  [9:0]  dmem_addr,
    output reg  [31:0] dmem_wdata,
    input  wire [31:0] dmem_rdata
);

    // The data memory is always enabled; accesses are selected by
    // dmem_addr and dmem_we alone.
    assign dmem_en = 1'b1;

    // ============================================================
    // Instruction Memory
    // Word-indexed by pc[11:2]. Only the read port is connected here
    // and its enable is tied high, so douta follows imem_addr after
    // the memory's read latency (assumed 1 clock -- TODO confirm).
    // ============================================================
    reg  [9:0]  imem_addr;
    wire [31:0] imem_rdata;

Instruction_Memory u_Instruction_Memory (
    .clka  (CLK),
    .ena   (1'b1),
    .addra (imem_addr),
    .douta (imem_rdata)
);

    // ============================================================
    // Opcode
    // ============================================================
    localparam OPCODE_RTYPE  = 7'b0110011; // add, sub
    localparam OPCODE_ITYPE  = 7'b0010011; // addi
    localparam OPCODE_LOAD   = 7'b0000011; // lw
    localparam OPCODE_STORE  = 7'b0100011; // sw
    localparam OPCODE_BRANCH = 7'b1100011; // beq, blt

    // ============================================================
    // FSM states
    // NOTE: every *_WAIT2 / *_WAIT3 state below is never entered --
    // each *_WAIT1 state jumps straight to its CAPTURE state.
    // Presumably they are kept so a longer memory latency can be
    // supported by re-routing WAIT1 (TODO confirm or remove).
    // ============================================================
    localparam ST_FETCH_ADDR       = 6'd0;
    localparam ST_FETCH_WAIT1      = 6'd1;
    localparam ST_FETCH_WAIT2      = 6'd2;   // unreachable
    localparam ST_FETCH_WAIT3      = 6'd3;   // unreachable
    localparam ST_FETCH_CAPTURE    = 6'd4;
    localparam ST_DECODE           = 6'd5;

    localparam ST_EXE_R            = 6'd6;
    localparam ST_WB_R             = 6'd7;

    localparam ST_EXE_I            = 6'd8;
    localparam ST_WB_I             = 6'd9;

    localparam ST_ADDR             = 6'd10;  // lw/sw effective address

    localparam ST_MEM_RD           = 6'd11;
    localparam ST_MEM_WAIT1        = 6'd12;
    localparam ST_MEM_WAIT2        = 6'd13;  // unreachable
    localparam ST_MEM_WAIT3        = 6'd14;  // unreachable
    localparam ST_MEM_CAPTURE      = 6'd15;
    localparam ST_WB_LW            = 6'd16;

    localparam ST_MEM_WR_SETUP     = 6'd17;
    localparam ST_MEM_WR_DO        = 6'd18;

    localparam ST_BRANCH           = 6'd19;

    localparam ST_DONE             = 6'd20;  // program finished

    // Bias generation states
    localparam ST_BIAS_RD16        = 6'd21;
    localparam ST_BIAS_WAIT16_1    = 6'd22;
    localparam ST_BIAS_WAIT16_2    = 6'd23;  // unreachable
    localparam ST_BIAS_WAIT16_3    = 6'd24;  // unreachable
    localparam ST_BIAS_CAP16       = 6'd25;

    localparam ST_BIAS_RD17        = 6'd26;
    localparam ST_BIAS_WAIT17_1    = 6'd27;
    localparam ST_BIAS_WAIT17_2    = 6'd28;  // unreachable
    localparam ST_BIAS_WAIT17_3    = 6'd29;  // unreachable
    localparam ST_BIAS_CAP17       = 6'd30;

    localparam ST_BIAS_RD18        = 6'd31;
    localparam ST_BIAS_WAIT18_1    = 6'd32;
    localparam ST_BIAS_WAIT18_2    = 6'd33;  // unreachable
    localparam ST_BIAS_WAIT18_3    = 6'd34;  // unreachable
    localparam ST_BIAS_CAP18       = 6'd35;

    localparam ST_BIAS0_SETUP      = 6'd36;
    localparam ST_BIAS0_WRITE      = 6'd37;
    localparam ST_BIAS1_SETUP      = 6'd38;
    localparam ST_BIAS1_WRITE      = 6'd39;

    localparam ST_CNN_FLAG_SETUP   = 6'd40;
    localparam ST_CNN_FLAG_WRITE   = 6'd41;
    localparam ST_HALT             = 6'd42;  // terminal state

    reg [5:0] state;

    // ============================================================
    // CPU internal registers
    // ============================================================
    reg [31:0] pc;       // address of the next instruction to fetch
    reg [31:0] pc_old;   // address of the instruction being executed
    reg [31:0] ir;       // instruction register

    reg [31:0] reg_a;    // rs1 value latched in ST_DECODE
    reg [31:0] reg_b;    // rs2 value latched in ST_DECODE
    reg [31:0] alu_out;  // ALU result / lw-sw effective address
    reg [31:0] mdr;      // memory data register for lw

    // Bias generation registers
    reg [31:0] bias_src16;   // captured Data Memory[16]
    reg [31:0] bias_src17;   // captured Data Memory[17]
    reg [31:0] bias_src18;   // captured Data Memory[18]
    reg [31:0] bias0_word;   // value written to Data Memory[14]
    reg [31:0] bias1_word;   // value written to Data Memory[15]

    // ============================================================
    // Register file
    // x0 is forced to zero every clock in the main FSM, and write-back
    // is skipped when rd == 0.
    // ============================================================
    reg [31:0] regs [0:31];
    integer i;

    // ============================================================
    // Instruction fields
    // ============================================================
    wire [6:0] opcode = ir[6:0];
    wire [4:0] rd     = ir[11:7];
    wire [2:0] funct3 = ir[14:12];
    wire [4:0] rs1    = ir[19:15];
    wire [4:0] rs2    = ir[24:20];
    wire [6:0] funct7 = ir[31:25];

    // ============================================================
    // Immediate generator (all sign-extended from ir[31])
    //   I-type: ir[31:20]
    //   S-type: {ir[31:25], ir[11:7]}
    //   B-type: {ir[31], ir[7], ir[30:25], ir[11:8], 1'b0}
    // ============================================================
    wire signed [31:0] imm_i;
    wire signed [31:0] imm_s;
    wire signed [31:0] imm_b;

    assign imm_i = {{20{ir[31]}}, ir[31:20]};
    assign imm_s = {{20{ir[31]}}, ir[31:25], ir[11:7]};
    assign imm_b = {{19{ir[31]}}, ir[31], ir[7],
                    ir[30:25], ir[11:8], 1'b0};

    // ============================================================
    // Instruction decode helper
    // One flag per supported instruction; any encoding for which all
    // flags are false is treated as "end of program".
    // ============================================================
    wire is_add  = (opcode == OPCODE_RTYPE) &&
                   (funct3 == 3'b000) &&
                   (funct7 == 7'b0000000);

    wire is_sub  = (opcode == OPCODE_RTYPE) &&
                   (funct3 == 3'b000) &&
                   (funct7 == 7'b0100000);

    wire is_addi = (opcode == OPCODE_ITYPE) &&
                   (funct3 == 3'b000);

    wire is_lw   = (opcode == OPCODE_LOAD) &&
                   (funct3 == 3'b010);

    wire is_sw   = (opcode == OPCODE_STORE) &&
                   (funct3 == 3'b010);

    wire is_beq  = (opcode == OPCODE_BRANCH) &&
                   (funct3 == 3'b000);

    wire is_blt  = (opcode == OPCODE_BRANCH) &&
                   (funct3 == 3'b100);

    // ============================================================
    // Main FSM
    // Every clock (outside reset) two defaults are applied first:
    // dmem_we is dropped to 0 and regs[0] is forced to 0. States that
    // assign dmem_we <= 1 later in the same clock override the
    // default, which yields a single-cycle write pulse.
    // ============================================================
    always @(posedge CLK or negedge RSTN) begin
        if (!RSTN) begin
            state      <= ST_FETCH_ADDR;

            pc         <= 32'd0;
            pc_old     <= 32'd0;
            ir         <= 32'd0;

            reg_a      <= 32'd0;
            reg_b      <= 32'd0;
            alu_out    <= 32'd0;
            mdr        <= 32'd0;

            bias_src16 <= 32'd0;
            bias_src17 <= 32'd0;
            bias_src18 <= 32'd0;
            bias0_word <= 32'd0;
            bias1_word <= 32'd0;

            imem_addr  <= 10'd0;

            dmem_addr  <= 10'd0;
            dmem_we    <= 1'b0;
            dmem_wdata <= 32'd0;

            for (i = 0; i < 32; i = i + 1) begin
                regs[i] <= 32'd0;
            end
        end
        else begin
            dmem_we <= 1'b0;   // default: no write this cycle
            regs[0] <= 32'd0;  // keep x0 hard-wired to zero

            case (state)

                // ====================================================
                // Instruction fetch
                // ====================================================
                ST_FETCH_ADDR: begin
                    // Remember this instruction's address for branches
                    // and drive its word index to instruction memory.
                    pc_old    <= pc;
                    imem_addr <= pc[11:2];
                    state     <= ST_FETCH_WAIT1;
                end

                // One cycle for the memory to respond, then capture.
                ST_FETCH_WAIT1: begin
                    state <= ST_FETCH_CAPTURE;
                end

                // Unreachable: no state transitions here.
                ST_FETCH_WAIT2: begin
                    state <= ST_FETCH_WAIT3;
                end

                // Unreachable: only entered from ST_FETCH_WAIT2.
                ST_FETCH_WAIT3: begin
                    state <= ST_FETCH_CAPTURE;
                end

                // Latch the instruction and advance pc to the next word.
                ST_FETCH_CAPTURE: begin
                    ir <= imem_rdata;
                    pc <= pc + 32'd4;
                    state <= ST_DECODE;
                end

                // ====================================================
                // Decode
                // Latch rs1/rs2 (x0 reads as zero) and dispatch on the
                // opcode. Every unsupported encoding goes to ST_DONE,
                // which ends the program.
                // ====================================================
                ST_DECODE: begin
                    reg_a <= (rs1 == 5'd0) ? 32'd0 : regs[rs1];
                    reg_b <= (rs2 == 5'd0) ? 32'd0 : regs[rs2];

                    case (opcode)
                        OPCODE_RTYPE: begin
                            if (is_add || is_sub)
                                state <= ST_EXE_R;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_ITYPE: begin
                            if (is_addi)
                                state <= ST_EXE_I;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_LOAD: begin
                            if (is_lw)
                                state <= ST_ADDR;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_STORE: begin
                            if (is_sw)
                                state <= ST_ADDR;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_BRANCH: begin
                            if (is_beq || is_blt)
                                state <= ST_BRANCH;
                            else
                                state <= ST_DONE;
                        end

                        default: begin
                            state <= ST_DONE;
                        end
                    endcase
                end

                // ====================================================
                // R-type (only add/sub reach this state)
                // ====================================================
                ST_EXE_R: begin
                    if (is_sub)
                        alu_out <= reg_a - reg_b;
                    else
                        alu_out <= reg_a + reg_b;

                    state <= ST_WB_R;
                end

                ST_WB_R: begin
                    if (rd != 5'd0)
                        regs[rd] <= alu_out;

                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // I-type addi
                // ====================================================
                ST_EXE_I: begin
                    alu_out <= reg_a + imm_i;
                    state <= ST_WB_I;
                end

                ST_WB_I: begin
                    if (rd != 5'd0)
                        regs[rd] <= alu_out;

                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // lw / sw address calculation
                // Only lw (LOAD) and sw (STORE) reach this state, so
                // "not LOAD" means sw and uses the S-type immediate.
                // ====================================================
                ST_ADDR: begin
                    if (opcode == OPCODE_LOAD) begin
                        alu_out <= reg_a + imm_i;
                        state   <= ST_MEM_RD;
                    end
                    else begin
                        alu_out <= reg_a + imm_s;
                        state   <= ST_MEM_WR_SETUP;
                    end
                end

                // lw: drive the word address of the effective address.
                ST_MEM_RD: begin
                    dmem_addr <= alu_out[11:2];
                    state <= ST_MEM_WAIT1;
                end

                // One cycle for the memory to respond, then capture.
                ST_MEM_WAIT1: begin
                    state <= ST_MEM_CAPTURE;
                end

                // Unreachable: no state transitions here.
                ST_MEM_WAIT2: begin
                    state <= ST_MEM_CAPTURE;
                end

                // Unreachable: no state transitions here.
                ST_MEM_WAIT3: begin
                    state <= ST_MEM_CAPTURE;
                end

                ST_MEM_CAPTURE: begin
                    mdr <= dmem_rdata;
                    state <= ST_WB_LW;
                end

                ST_WB_LW: begin
                    if (rd != 5'd0)
                        regs[rd] <= mdr;

                    state <= ST_FETCH_ADDR;
                end

                // sw: present address and data with the write enable low...
                ST_MEM_WR_SETUP: begin
                    dmem_addr  <= alu_out[11:2];
                    dmem_wdata <= reg_b;
                    dmem_we    <= 1'b0;
                    state      <= ST_MEM_WR_DO;
                end

                // ...then raise dmem_we for one cycle. The pulse is high
                // during the following ST_FETCH_ADDR cycle, and
                // dmem_addr/dmem_wdata are unchanged during that cycle.
                ST_MEM_WR_DO: begin
                    dmem_we <= 1'b1;
                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // Branch
                // pc already holds pc_old + 4 (set in fetch); a taken
                // branch overwrites it with pc_old + imm_b. blt compares
                // the operands as signed values.
                // ====================================================
                ST_BRANCH: begin
                    if (is_beq) begin
                        if (reg_a == reg_b)
                            pc <= pc_old + imm_b;
                    end
                    else if (is_blt) begin
                        if ($signed(reg_a) < $signed(reg_b))
                            pc <= pc_old + imm_b;
                    end

                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // Program done, then generate bias before CNN
                // ====================================================
                ST_DONE: begin
                    dmem_addr  <= 10'd16;
                    dmem_wdata <= 32'd0;
                    dmem_we    <= 1'b0;
                    state      <= ST_BIAS_RD16;
                end

                // ====================================================
                // Bias generation (as implemented below)
                // Bias0 = Data Memory[16] + [17] + [18]
                // Bias1 = if Data Memory[17] < Data Memory[16] (signed)
                //         then Data Memory[16] - Data Memory[17] + 1
                //         else -15
                // NOTE(review): this comment previously described
                // Bias1 as "[16] - [17], else 1", which does not match
                // the code. Confirm which formula the spec intends.
                // Both results are written as full 32-bit words; the
                // CNN module reads only bits [7:0] of each.
                // ====================================================
                ST_BIAS_RD16: begin
                    dmem_addr <= 10'd16;
                    dmem_we   <= 1'b0;
                    state     <= ST_BIAS_WAIT16_1;
                end

               ST_BIAS_WAIT16_1: begin
                    state <= ST_BIAS_CAP16;
                end

                // Unreachable: no state transitions here.
                ST_BIAS_WAIT16_2: begin
                    state <= ST_BIAS_CAP16;
                end

                // Unreachable: no state transitions here.
                ST_BIAS_WAIT16_3: begin
                    state <= ST_BIAS_CAP16;
                end

                // Capture [16] and already point the address at [17].
                ST_BIAS_CAP16: begin
                    bias_src16 <= dmem_rdata;

                    dmem_addr <= 10'd17;
                    dmem_we   <= 1'b0;
                    state     <= ST_BIAS_RD17;
                end

                ST_BIAS_RD17: begin
                    dmem_addr <= 10'd17;
                    dmem_we   <= 1'b0;
                    state     <= ST_BIAS_WAIT17_1;
                end

               ST_BIAS_WAIT17_1: begin
                    state <= ST_BIAS_CAP17;
                end

                // Unreachable: no state transitions here.
                ST_BIAS_WAIT17_2: begin
                    state <= ST_BIAS_CAP17;
                end

                // Unreachable: no state transitions here.
                ST_BIAS_WAIT17_3: begin
                    state <= ST_BIAS_CAP17;
                end

                // Capture [17] and already point the address at [18].
                ST_BIAS_CAP17: begin
                    bias_src17 <= dmem_rdata;

                    dmem_addr <= 10'd18;
                    dmem_we   <= 1'b0;
                    state     <= ST_BIAS_RD18;
                end

                ST_BIAS_RD18: begin
                    dmem_addr <= 10'd18;
                    dmem_we   <= 1'b0;
                    state     <= ST_BIAS_WAIT18_1;
                end

               ST_BIAS_WAIT18_1: begin
                    state <= ST_BIAS_CAP18;
                end

                // Unreachable: no state transitions here.
                ST_BIAS_WAIT18_2: begin
                    state <= ST_BIAS_CAP18;
                end

                // Unreachable: no state transitions here.
                ST_BIAS_WAIT18_3: begin
                    state <= ST_BIAS_CAP18;
                end

                // Capture [18] and compute both bias words. dmem_rdata is
                // used directly for [18] because bias_src18 only takes
                // its new value at the end of this clock.
                ST_BIAS_CAP18: begin
                    bias_src18 <= dmem_rdata;

                    bias0_word <= bias_src16 + bias_src17 + dmem_rdata;

                    if ($signed(bias_src17) < $signed(bias_src16))
                        bias1_word <= bias_src16 - bias_src17+ 32'd1;
                    else
                        bias1_word <= -32'sd15;

                    state <= ST_BIAS0_SETUP;
                end

                // Each write uses a SETUP state (we low) followed by a
                // WRITE state that raises we for one cycle.
                ST_BIAS0_SETUP: begin
                    dmem_addr  <= 10'd14;
                    dmem_wdata <= bias0_word;
                    dmem_we    <= 1'b0;
                    state      <= ST_BIAS0_WRITE;
                end

                ST_BIAS0_WRITE: begin
                    dmem_addr  <= 10'd14;
                    dmem_wdata <= bias0_word;
                    dmem_we    <= 1'b1;
                    state      <= ST_BIAS1_SETUP;
                end

                ST_BIAS1_SETUP: begin
                    dmem_addr  <= 10'd15;
                    dmem_wdata <= bias1_word;
                    dmem_we    <= 1'b0;
                    state      <= ST_BIAS1_WRITE;
                end

                ST_BIAS1_WRITE: begin
                    dmem_addr  <= 10'd15;
                    dmem_wdata <= bias1_word;
                    dmem_we    <= 1'b1;
                    state      <= ST_CNN_FLAG_SETUP;
                end

                // ====================================================
                // Start CNN after bias is ready
                // Data Memory[900] is the word the CNN polls for a 1
                // (its ADDR_CNN_START).
                // ====================================================
                ST_CNN_FLAG_SETUP: begin
                    dmem_addr  <= 10'd900;
                    dmem_wdata <= 32'd1;
                    dmem_we    <= 1'b0;
                    state      <= ST_CNN_FLAG_WRITE;
                end

                ST_CNN_FLAG_WRITE: begin
                    dmem_addr  <= 10'd900;
                    dmem_wdata <= 32'd1;
                    dmem_we    <= 1'b1;
                    state      <= ST_HALT;
                end

                // Terminal state: only a reset leaves it.
                ST_HALT: begin
                    dmem_we <= 1'b0;
                    state   <= ST_HALT;
                end

                // Any unused state encoding restarts instruction fetch.
                default: begin
                    state <= ST_FETCH_ADDR;
                end

            endcase
        end
    end

endmodule

module CNN (
    input  wire        clk,
    input  wire        rstn,

    input  wire [31:0] doutb,
    output reg         web,
    output reg         enb,
    output reg  [31:0] dinb,
    output reg  [9:0]  addr,
    output reg         done
);

    // ============================================================
    // Fixed memory map
    // ============================================================
    localparam BASE_W0        = 10'd0;
    localparam BASE_W1        = 10'd3;
    localparam ADDR_SIZE      = 10'd12;
    localparam ADDR_STATUS    = 10'd13;
    localparam ADDR_BIAS0     = 10'd14;
    localparam ADDR_BIAS1     = 10'd15;
    localparam BASE_IN        = 10'd16;

    localparam BASE_MID       = 10'd272;
    localparam BASE_OUT       = 10'd512;
    localparam ADDR_CNN_START = 10'd900;

    // ============================================================
    // Helper functions
    // ============================================================
    // Extract one byte lane from a 32-bit word, big-endian lane order:
    //   sel = 0 -> word[31:24] (most significant byte)
    //   sel = 1 -> word[23:16]
    //   sel = 2 -> word[15:8]
    //   sel = 3 -> word[7:0]   (least significant byte)
    // A sel that matches none of these (e.g. X/Z in simulation) yields 8'd0.
    function [7:0] get_byte;
        input [31:0] word;
        input [1:0] sel;
        begin
            if (sel == 2'd3)
                get_byte = word[7:0];
            else if (sel == 2'd2)
                get_byte = word[15:8];
            else if (sel == 2'd1)
                get_byte = word[23:16];
            else if (sel == 2'd0)
                get_byte = word[31:24];
            else
                get_byte = 8'd0;
        end
    endfunction

    // Return old_word with a single byte lane replaced by new_byte.
    // Lane numbering is big-endian, matching get_byte:
    //   sel = 0 -> bits [31:24], sel = 1 -> [23:16],
    //   sel = 2 -> bits [15:8],  sel = 3 -> [7:0].
    // All other lanes pass through unchanged; a sel that matches none of
    // 0..3 (e.g. X/Z in simulation) returns old_word untouched.
    function [31:0] put_byte;
        input [31:0] old_word;
        input [1:0] sel;
        input [7:0] new_byte;
        begin
            if (sel == 2'd0)
                put_byte = {new_byte, old_word[23:0]};
            else if (sel == 2'd1)
                put_byte = {old_word[31:24], new_byte, old_word[15:0]};
            else if (sel == 2'd2)
                put_byte = {old_word[31:16], new_byte, old_word[7:0]};
            else if (sel == 2'd3)
                put_byte = {old_word[31:8], new_byte};
            else
                put_byte = old_word;
        end
    endfunction

    // Convert a signed Q12 value (e.g. a sum of Q6*Q6 products) to a
    // saturated signed Q6 byte:
    //   1. Split into sign and magnitude. The magnitude is held unsigned,
    //      so the most negative 32-bit input still maps to 2^31.
    //   2. Drop the 6 low fraction bits, rounding to nearest with ties
    //      to even on the magnitude (symmetric around zero).
    //   3. Re-apply the sign and clamp to the int8 range [-128, 127].
    function signed [7:0] round_sat_q12_to_q6;
        input signed [31:0] in_q12;

        reg               is_neg;
        reg        [31:0] mag;
        reg        [31:0] mag_q6;
        reg        [5:0]  dropped;
        reg               bump;
        reg        [31:0] mag_rounded;
        reg signed [31:0] result_q6;

        begin
            // Sign / magnitude split (default: non-negative input).
            is_neg = 1'b0;
            mag    = in_q12;
            if (in_q12 < 0) begin
                is_neg = 1'b1;
                mag    = -in_q12;
            end

            mag_q6  = mag >> 6;   // integer part in Q6 units
            dropped = mag[5:0];   // fraction being discarded

            // 6'd32 is exactly one half of the discarded LSB.
            if (dropped > 6'd32)
                bump = 1'b1;
            else if (dropped == 6'd32)
                bump = mag_q6[0];  // tie: round toward the even neighbour
            else
                bump = 1'b0;

            mag_rounded = mag_q6 + {31'd0, bump};

            result_q6 = is_neg ? -$signed(mag_rounded) : $signed(mag_rounded);

            // Clamp to the signed 8-bit range.
            if (result_q6 < -32'sd128)
                round_sat_q12_to_q6 = 8'sh80;
            else if (result_q6 > 32'sd127)
                round_sat_q12_to_q6 = 8'sh7F;
            else
                round_sat_q12_to_q6 = result_q6[7:0];
        end
    endfunction

    // ============================================================
    // Kernel coordinate functions
    // k order:
    // 0 1 2
    // 3 4 5
    // 6 7 8
    // ============================================================
    // Column offset (dx) of kernel tap k in a 3x3 window whose taps are
    // numbered row-major 0..8:  0 1 2 / 3 4 5 / 6 7 8.
    // Taps in the right column give 2, the middle column gives 1, and
    // everything else (left column, k > 8, or X/Z k) gives 0.
    function [5:0] k_dx_func;
        input [3:0] k;
        begin
            if (k == 4'd2 || k == 4'd5 || k == 4'd8)
                k_dx_func = 6'd2;
            else if (k == 4'd1 || k == 4'd4 || k == 4'd7)
                k_dx_func = 6'd1;
            else
                k_dx_func = 6'd0;
        end
    endfunction

    // Row offset (dy) of kernel tap k in a 3x3 window whose taps are
    // numbered row-major 0..8:  0 1 2 / 3 4 5 / 6 7 8.
    // Taps 6..8 give 2, taps 3..5 give 1, and everything else
    // (taps 0..2, k > 8, or X/Z k) gives 0.
    function [5:0] k_dy_func;
        input [3:0] k;
        begin
            if ((k >= 4'd6) && (k <= 4'd8))
                k_dy_func = 6'd2;
            else if ((k >= 4'd3) && (k <= 4'd5))
                k_dy_func = 6'd1;
            else
                k_dy_func = 6'd0;
        end
    endfunction

    // ============================================================
    // FSM states
    // ============================================================
    reg [5:0] state;
    reg [5:0] after_wait_state;

    localparam ST_IDLE          = 6'd0;
    localparam ST_WAIT1         = 6'd1;
    localparam ST_WAIT2         = 6'd2;
    localparam ST_WAIT3         = 6'd3;

    localparam ST_CAP_SIZE      = 6'd4;
    localparam ST_CAP_BIAS0     = 6'd5;
    localparam ST_CAP_BIAS1     = 6'd6;

    localparam ST_CAP_W0_0      = 6'd7;
    localparam ST_CAP_W0_1      = 6'd8;
    localparam ST_CAP_W0_2      = 6'd9;

    localparam ST_CAP_W1_0      = 6'd10;
    localparam ST_CAP_W1_1      = 6'd11;
    localparam ST_CAP_W1_2      = 6'd12;

    localparam ST_BEGIN_LAYER   = 6'd13;
    localparam ST_START_PIXEL   = 6'd14;
    localparam ST_PIPE_MAC      = 6'd15;

    localparam ST_ADD_BIAS      = 6'd17;
    localparam ST_ROUND         = 6'd18;
    localparam ST_PREP_STORE    = 6'd19;
    localparam ST_STORE_PACK    = 6'd20;
    localparam ST_WRITE_DO      = 6'd21;
    localparam ST_ADV_PIXEL     = 6'd22;

    localparam ST_WRITE_STATUS_SETUP = 6'd23;
    localparam ST_WRITE_STATUS_DO    = 6'd24;
    localparam ST_DONE               = 6'd25;
    localparam ST_CHECK_START        = 6'd26;

    // ============================================================
    // Registers
    // ============================================================
    reg [5:0] fmap_size;
    reg [5:0] mid_size;
    reg [5:0] out_size;

    reg [5:0] run_in_size;
    reg [5:0] run_out_size;
    reg [9:0] run_in_base;
    reg [9:0] run_out_base;

    reg signed [7:0] bias0;
    reg signed [7:0] bias1;

    reg signed [7:0] w0 [0:8];
    reg signed [7:0] w1 [0:8];

    reg layer_id; // 0 = layer1, 1 = layer2

    reg [5:0] ox;
    reg [5:0] oy;

    // k_idx is kept for debug visibility.
    // Actual pipeline uses issue_k / pipe_k0 / pipe_k1.
    reg [3:0] k_idx;
    reg [3:0] issue_k;

    reg [3:0] pipe_k0;
    reg [3:0] pipe_k1;

    reg [1:0] pipe_byte0;
    reg [1:0] pipe_byte1;

    reg pipe_v0;
    reg pipe_v1;
    wire [1:0] read_byte_sel = pipe_byte1;

    // Keep these for debug testbench
    reg signed [7:0] pix_q6;

    reg signed [31:0] acc_q12;
    reg signed [7:0]  conv_out_q6;

    reg [11:0] out_linear_idx;
    reg [31:0] pack_word;

    integer i;

    // ============================================================
    // Issue address for current issue_k
    // ============================================================
    wire [5:0] issue_x = ox + k_dx_func(issue_k);
    wire [5:0] issue_y = oy + k_dy_func(issue_k);

    wire [11:0] issue_linear_idx =
        ({6'd0, issue_y} * {6'd0, run_in_size}) + {6'd0, issue_x};

    wire [9:0] issue_word_addr =
        run_in_base + issue_linear_idx[11:2];

    wire [1:0] issue_byte_now =
        issue_linear_idx[1:0];

    // ============================================================
    // Output packing
    // ============================================================
    wire [11:0] out_linear_now =
        ({6'd0, oy} * {6'd0, run_out_size}) + {6'd0, ox};

    wire [1:0] out_byte_sel =
        out_linear_idx[1:0];

    wire [9:0] out_word_addr =
        run_out_base + out_linear_idx[11:2];

    wire is_last_pixel =
        (ox == (run_out_size - 6'd1)) &&
        (oy == (run_out_size - 6'd1));

    wire should_write_word =
        (out_byte_sel == 2'd3) || is_last_pixel;

    wire [31:0] pack_word_next =
        put_byte(pack_word, out_byte_sel, conv_out_q6);

    // ============================================================
    // MAC data aligned with pipe_k1 / pipe_byte1
    // doutb corresponds to address issued two CNN cycles earlier.
    // ============================================================
    wire signed [7:0] mac_pixel =
        $signed(get_byte(doutb, pipe_byte1));

    wire signed [7:0] mac_weight =
        (layer_id == 1'b0) ? w0[pipe_k1] : w1[pipe_k1];

    wire signed [31:0] mac_mul =
        $signed({{24{mac_pixel[7]}}, mac_pixel}) *
        $signed({{24{mac_weight[7]}}, mac_weight});

    // ============================================================
    // Main FSM
    // ============================================================
    always @(posedge clk or negedge rstn) begin
        if (!rstn) begin
            state <= ST_IDLE;
            after_wait_state <= ST_IDLE;

            web  <= 1'b0;
            enb  <= 1'b1;
            dinb <= 32'd0;
            addr <= 10'd0;
            done <= 1'b0;

            fmap_size <= 6'd0;
            mid_size  <= 6'd0;
            out_size  <= 6'd0;

            run_in_size  <= 6'd0;
            run_out_size <= 6'd0;
            run_in_base  <= 10'd0;
            run_out_base <= 10'd0;

            bias0 <= 8'sd0;
            bias1 <= 8'sd0;

            for (i = 0; i < 9; i = i + 1) begin
                w0[i] <= 8'sd0;
                w1[i] <= 8'sd0;
            end

            layer_id <= 1'b0;

            ox <= 6'd0;
            oy <= 6'd0;

            k_idx   <= 4'd0;
            issue_k <= 4'd0;

            pipe_k0 <= 4'd0;
            pipe_k1 <= 4'd0;

            pipe_byte0 <= 2'd0;
            pipe_byte1 <= 2'd0;

            pipe_v0 <= 1'b0;
            pipe_v1 <= 1'b0;

            pix_q6 <= 8'sd0;

            acc_q12 <= 32'sd0;
            conv_out_q6 <= 8'sd0;

            out_linear_idx <= 12'd0;
            pack_word <= 32'd0;
        end
        else begin
            web <= 1'b0;
            enb <= 1'b1;

            case (state)

                // ====================================================
                // Wait for CPU flag Data Memory[900] = 1
                // ====================================================
                ST_IDLE: begin
                    done <= 1'b0;

                    addr <= ADDR_CNN_START;
                    after_wait_state <= ST_CHECK_START;
                    state <= ST_WAIT1;
                end

                ST_CHECK_START: begin
                    if (doutb[0] == 1'b1) begin
                        addr <= ADDR_SIZE;
                        after_wait_state <= ST_CAP_SIZE;
                        state <= ST_WAIT1;
                    end
                    else begin
                        addr <= ADDR_CNN_START;
                        after_wait_state <= ST_CHECK_START;
                        state <= ST_WAIT1;
                    end
                end

                // ====================================================
                // Config / weight / polling read wait.
                // addr is registered in CNN, and BRAM latency is 1,
                // so one WAIT state before capture is still needed.
                // ====================================================
                ST_WAIT1: begin
                    state <= after_wait_state;
                end

                ST_WAIT2: begin
                    state <= after_wait_state;
                end

                ST_WAIT3: begin
                    state <= after_wait_state;
                end

                // ====================================================
                // Load size / bias / weights
                // ====================================================
                ST_CAP_SIZE: begin
                    fmap_size <= doutb[6:1];
                    mid_size  <= doutb[6:1] - 6'd2;
                    out_size  <= doutb[6:1] - 6'd4;

                    addr <= ADDR_BIAS0;
                    after_wait_state <= ST_CAP_BIAS0;
                    state <= ST_WAIT1;
                end

                ST_CAP_BIAS0: begin
                    bias0 <= doutb[7:0];

                    addr <= ADDR_BIAS1;
                    after_wait_state <= ST_CAP_BIAS1;
                    state <= ST_WAIT1;
                end

                ST_CAP_BIAS1: begin
                    bias1 <= doutb[7:0];

                    addr <= BASE_W0;
                    after_wait_state <= ST_CAP_W0_0;
                    state <= ST_WAIT1;
                end

                // Weight packing:
                // address0: w00 w01 w02 w10
                // address1: w11 w12 w20 w21
                // address2: w22 blank blank blank
                ST_CAP_W0_0: begin
                    w0[0] <= get_byte(doutb, 2'd0);
                    w0[1] <= get_byte(doutb, 2'd1);
                    w0[2] <= get_byte(doutb, 2'd2);
                    w0[3] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W0 + 10'd1;
                    after_wait_state <= ST_CAP_W0_1;
                    state <= ST_WAIT1;
                end

                ST_CAP_W0_1: begin
                    w0[4] <= get_byte(doutb, 2'd0);
                    w0[5] <= get_byte(doutb, 2'd1);
                    w0[6] <= get_byte(doutb, 2'd2);
                    w0[7] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W0 + 10'd2;
                    after_wait_state <= ST_CAP_W0_2;
                    state <= ST_WAIT1;
                end

                ST_CAP_W0_2: begin
                    w0[8] <= get_byte(doutb, 2'd0);

                    addr <= BASE_W1;
                    after_wait_state <= ST_CAP_W1_0;
                    state <= ST_WAIT1;
                end

                ST_CAP_W1_0: begin
                    w1[0] <= get_byte(doutb, 2'd0);
                    w1[1] <= get_byte(doutb, 2'd1);
                    w1[2] <= get_byte(doutb, 2'd2);
                    w1[3] <= get_byte(doutb, 2'd3);

                    addr <= BASE_W1 + 10'd1;
                    after_wait_state <= ST_CAP_W1_1;
                    state <= ST_WAIT1;
                end

                ST_CAP_W1_1: begin
                    // Queue the read of kernel-1 word +2, resuming in ST_CAP_W1_2.
                    state            <= ST_WAIT1;
                    after_wait_state <= ST_CAP_W1_2;
                    addr             <= BASE_W1 + 10'd2;

                    // Latch w1[4..7] from kernel-1 word +1.
                    w1[7] <= get_byte(doutb, 2'd3);
                    w1[6] <= get_byte(doutb, 2'd2);
                    w1[5] <= get_byte(doutb, 2'd1);
                    w1[4] <= get_byte(doutb, 2'd0);
                end

                ST_CAP_W1_2: begin
                    // All weights loaded: begin convolution with layer 0.
                    state    <= ST_BEGIN_LAYER;
                    layer_id <= 1'b0;

                    // Final weight of kernel 1 sits in lane 0 of word +2.
                    w1[8] <= get_byte(doutb, 2'd0);
                end

                // ====================================================
                // Begin layer: pick input/output geometry for layer_id
                // and reset the output-pixel cursor and packing state.
                // ====================================================
                ST_BEGIN_LAYER: begin
                    case (layer_id)
                        1'b0: begin
                            // Layer 0: BASE_IN (fmap_size) -> BASE_MID (mid_size)
                            run_in_base  <= BASE_IN;
                            run_in_size  <= fmap_size;
                            run_out_base <= BASE_MID;
                            run_out_size <= mid_size;
                        end
                        default: begin
                            // Layer 1: BASE_MID (mid_size) -> BASE_OUT (out_size)
                            run_in_base  <= BASE_MID;
                            run_in_size  <= mid_size;
                            run_out_base <= BASE_OUT;
                            run_out_size <= out_size;
                        end
                    endcase

                    out_linear_idx <= 12'd0;
                    pack_word      <= 32'd0;
                    oy             <= 6'd0;
                    ox             <= 6'd0;

                    state <= ST_START_PIXEL;
                end

                // ====================================================
                // Start one output pixel: clear the accumulator and
                // empty both stages of the issue/MAC pipeline.
                // ====================================================
                ST_START_PIXEL: begin
                    // Stage 0 (just issued) and stage 1 (ready to MAC) start empty.
                    pipe_v0    <= 1'b0;
                    pipe_k0    <= 4'd0;
                    pipe_byte0 <= 2'd0;
                    pipe_v1    <= 1'b0;
                    pipe_k1    <= 4'd0;
                    pipe_byte1 <= 2'd0;

                    // Both the issue counter and the MAC tap index restart at 0.
                    k_idx   <= 4'd0;
                    issue_k <= 4'd0;

                    acc_q12 <= 32'sd0;

                    state <= ST_PIPE_MAC;
                end

                // ====================================================
                // Pipelined MAC loop over the 9 taps (k = 0..8) of one
                // output pixel.
                //
                // One read address is issued per cycle. The metadata for
                // each tap (index k, byte lane, valid) travels through two
                // register stages, pipe_*0 then pipe_*1. The original design
                // note attributes that two-cycle delay to the registered addr
                // plus a BRAM read latency of 1. A tap is therefore
                // accumulated two cycles after it was issued:
                // cycle 0: issue k0
                // cycle 1: issue k1
                // cycle 2: MAC k0, issue k2
                // cycle 3: MAC k1, issue k3
                // ...
                // cycle 10: MAC k8, then go add bias
                //
                // mac_pixel / mac_mul are defined outside this view.
                // NOTE(review): presumably they select lane pipe_byte1 of
                // doutb and multiply it by the weight for tap pipe_k1; confirm.
                // ====================================================
                ST_PIPE_MAC: begin
                    // 1. MAC data issued two CNN cycles ago
                    //    (pix_q6 / k_idx keep registered copies of the
                    //    accumulated pixel and tap; their readers are not
                    //    in this view)
                    if (pipe_v1) begin
                        pix_q6  <= mac_pixel;
                        acc_q12 <= acc_q12 + mac_mul;
                        k_idx   <= pipe_k1;
                    end

                    // 2. Issue new address every cycle while issue_k <= 8.
                    //    After k8 is issued, issue_k parks at 9 and bubbles
                    //    (pipe_v0 = 0) are fed in to drain the pipeline.
                    if (issue_k <= 4'd8) begin
                        addr <= issue_word_addr;

                        pipe_k0    <= issue_k;
                        pipe_byte0 <= issue_byte_now;
                        pipe_v0    <= 1'b1;

                        issue_k <= issue_k + 4'd1;
                    end
                    else begin
                        pipe_k0    <= 4'd0;
                        pipe_byte0 <= 2'd0;
                        pipe_v0    <= 1'b0;
                    end

                    // 3. Shift metadata pipeline (stage 0 -> stage 1)
                    pipe_k1    <= pipe_k0;
                    pipe_byte1 <= pipe_byte0;
                    pipe_v1    <= pipe_v0;

                    // 4. If k8 has been MACed, finish this pixel.
                    //    k8's acc_q12 update (step 1) and this transition take
                    //    effect on the same clock edge, so ST_ADD_BIAS sees
                    //    the complete sum.
                    if (pipe_v1 && pipe_k1 == 4'd8) begin
                        state <= ST_ADD_BIAS;
                    end
                    else begin
                        state <= ST_PIPE_MAC;
                    end
                end

                ST_ADD_BIAS: begin
                    // Sign-extend the layer's bias byte to 32 bits and shift it
                    // left by 6 to line it up with acc_q12.
                    // NOTE(review): assumes bias0/bias1 are 8-bit Q6 values; confirm.
                    case (layer_id)
                        1'b0:    acc_q12 <= acc_q12 + ($signed({{24{bias0[7]}}, bias0}) <<< 6);
                        default: acc_q12 <= acc_q12 + ($signed({{24{bias1[7]}}, bias1}) <<< 6);
                    endcase

                    state <= ST_ROUND;
                end

                ST_ROUND: begin
                    // Reduce the Q12 accumulator to the Q6 output value using
                    // round_sat_q12_to_q6 (defined outside this view).
                    state       <= ST_PREP_STORE;
                    conv_out_q6 <= round_sat_q12_to_q6(acc_q12);
                end

                ST_PREP_STORE: begin
                    // Register the current pixel's linear output index before
                    // it is used for packing/addressing in ST_STORE_PACK.
                    state          <= ST_STORE_PACK;
                    out_linear_idx <= out_linear_now;
                end

                ST_STORE_PACK: begin
                    if (should_write_word) begin
                        // Present the packed word at its output address with the
                        // write enable low; ST_WRITE_DO raises it.
                        state <= ST_WRITE_DO;
                        web   <= 1'b0;
                        dinb  <= pack_word_next;
                        addr  <= out_word_addr;
                    end
                    else begin
                        // Word not ready to flush: go straight to the next pixel.
                        state <= ST_ADV_PIXEL;
                    end

                    // Always fold this pixel into the packing register.
                    pack_word <= pack_word_next;
                end

                ST_WRITE_DO: begin
                    // Raise the write enable for the word set up in ST_STORE_PACK
                    // and start a fresh packing word.
                    state     <= ST_ADV_PIXEL;
                    pack_word <= 32'd0;
                    web       <= 1'b1;
                end

                ST_ADV_PIXEL: begin
                    // Drop the write enable raised in ST_WRITE_DO, mirroring the
                    // clear in ST_DONE. The BRAM still samples web = 1 on this
                    // cycle's edge, so the pending output write completes;
                    // the nonblocking clear only takes effect afterwards.
                    // Without it, web would stay high through ST_START_PIXEL and
                    // ST_PIPE_MAC. There addr walks the input map while dinb still
                    // holds the packed output word, which overwrites input data.
                    // Redundant but harmless if web also has a default assignment
                    // outside this view.
                    // NOTE(review): assumes web is an active-high write enable, as
                    // the SETUP (web=0) / DO (web=1) / DONE (web=0) status-write
                    // sequence suggests; confirm against the BRAM instance.
                    web <= 1'b0;

                    if (is_last_pixel) begin
                        // Layer finished: run layer 1 next, or publish the
                        // status word once layer 1 is done.
                        if (layer_id == 1'b0) begin
                            layer_id <= 1'b1;
                            state <= ST_BEGIN_LAYER;
                        end
                        else begin
                            state <= ST_WRITE_STATUS_SETUP;
                        end
                    end
                    else begin
                        // Raster order: advance ox, wrapping to the start of the
                        // next row at the last output column.
                        if (ox == (run_out_size - 6'd1)) begin
                            ox <= 6'd0;
                            oy <= oy + 6'd1;
                        end
                        else begin
                            ox <= ox + 6'd1;
                        end

                        state <= ST_START_PIXEL;
                    end
                end

                // ====================================================
                // Finish: publish the status word.
                //   Data Memory[13] <= 0x00000401
                //   bits [10:1] : BASE_OUT (512)
                //   bit  [0]    : done flag
                // ====================================================
                ST_WRITE_STATUS_SETUP: begin
                    // Present address and data with the write enable low;
                    // ST_WRITE_STATUS_DO raises it.
                    state <= ST_WRITE_STATUS_DO;
                    web   <= 1'b0;
                    dinb  <= {21'd0, BASE_OUT, 1'b1};
                    addr  <= ADDR_STATUS;
                end

                ST_WRITE_STATUS_DO: begin
                    // Raise the write enable for the status word, then finish.
                    state <= ST_DONE;
                    web   <= 1'b1;
                end

                ST_DONE: begin
                    // Terminal state: hold done high with the write enable low.
                    state <= ST_DONE;
                    done  <= 1'b1;
                    web   <= 1'b0;
                end

                default: begin
                    // Unrecognized state encoding: recover by returning to idle.
                    state <= ST_IDLE;
                end

            endcase
        end
    end

endmodule


Editor is loading...
Leave a Comment