Untitled
unknown
plain_text
18 days ago
46 kB
3
Indexable
// ============================================================
// Simple_CPU
//
// Multi-cycle (one instruction at a time) core for a small RV32I
// subset: add, sub, addi, lw, sw, beq, blt. Any other encoding
// (including an all-zero word) sends the FSM to ST_DONE -> ST_HALT,
// where it stays until reset.
//
// Instruction fetch: imem_addr = pc[11:2]; the word is captured two
// states after the address is driven (ST_FETCH_ADDR -> WAIT1 -> CAPTURE).
// Data access: dmem_addr = byte address [11:2]; loads follow the same
// address / wait / capture pattern, stores drive address+data one state
// before raising dmem_we. dmem_we is a single-cycle pulse; dmem_en is
// tied high.
//
// Register file is sparse: only x1..x13, x20 and x21 are implemented.
// x0 and all other indices read as zero and silently drop writes.
// ============================================================
module Simple_CPU (
    input  wire        CLK,
    input  wire        RSTN,
    output wire        dmem_en,
    output reg         dmem_we,
    output reg  [9:0]  dmem_addr,
    output reg  [31:0] dmem_wdata,
    input  wire [31:0] dmem_rdata
);
    assign dmem_en = 1'b1;   // data port permanently enabled

    // ------------------------------------------------------------
    // Instruction memory (word addressed)
    // ------------------------------------------------------------
    reg  [9:0]  imem_addr;
    wire [31:0] imem_rdata;

    Instruction_Memory u_Instruction_Memory (
        .clka  (CLK),
        .ena   (1'b1),
        .addra (imem_addr),
        .douta (imem_rdata)
    );

    // ------------------------------------------------------------
    // Opcodes of the supported instruction classes
    // ------------------------------------------------------------
    localparam OPCODE_RTYPE  = 7'b0110011;  // add, sub
    localparam OPCODE_ITYPE  = 7'b0010011;  // addi
    localparam OPCODE_LOAD   = 7'b0000011;  // lw
    localparam OPCODE_STORE  = 7'b0100011;  // sw
    localparam OPCODE_BRANCH = 7'b1100011;  // beq, blt

    // ------------------------------------------------------------
    // FSM state encoding
    // ------------------------------------------------------------
    localparam ST_FETCH_ADDR    = 5'd0;
    localparam ST_FETCH_WAIT1   = 5'd1;
    localparam ST_FETCH_CAPTURE = 5'd2;
    localparam ST_DECODE        = 5'd3;
    localparam ST_EXE_R         = 5'd4;   // add/sub execute + write-back
    localparam ST_EXE_I         = 5'd5;   // addi execute + write-back
    localparam ST_ADDR          = 5'd6;   // lw/sw effective address
    localparam ST_MEM_RD        = 5'd7;
    localparam ST_MEM_WAIT1     = 5'd8;
    localparam ST_MEM_CAPTURE   = 5'd9;   // lw write-back
    localparam ST_WB_LW         = 5'd10;  // unreachable fallback
    localparam ST_MEM_WR_SETUP  = 5'd11;
    localparam ST_MEM_WR_DO     = 5'd12;
    localparam ST_BRANCH        = 5'd13;
    localparam ST_DONE          = 5'd14;
    localparam ST_HALT          = 5'd15;

    reg [4:0] state;

    // ------------------------------------------------------------
    // Datapath registers
    // ------------------------------------------------------------
    reg [31:0] pc;        // address of the next instruction to fetch
    reg [31:0] pc_old;    // address of the instruction being executed
    reg [31:0] ir;        // instruction being executed
    reg [31:0] reg_a;     // rs1 value sampled in ST_DECODE
    reg [31:0] reg_b;     // rs2 value sampled in ST_DECODE
    reg [11:0] alu_addr;  // lw/sw byte address; only [11:2] reaches dmem_addr

    // Implemented architectural registers (see header).
    reg [31:0] r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, r12, r13;
    reg [31:0] r20, r21;

    // ------------------------------------------------------------
    // Instruction fields
    // ------------------------------------------------------------
    wire [6:0] opcode = ir[6:0];
    wire [4:0] rd     = ir[11:7];
    wire [2:0] funct3 = ir[14:12];
    wire [4:0] rs1    = ir[19:15];
    wire [4:0] rs2    = ir[24:20];
    wire [6:0] funct7 = ir[31:25];

    // ------------------------------------------------------------
    // Sign-extended immediates (I, S and B formats)
    // ------------------------------------------------------------
    wire signed [31:0] imm_i = {{20{ir[31]}}, ir[31:20]};
    wire signed [31:0] imm_s = {{20{ir[31]}}, ir[31:25], ir[11:7]};
    wire signed [31:0] imm_b = {{19{ir[31]}}, ir[31], ir[7],
                                ir[30:25], ir[11:8], 1'b0};

    // ------------------------------------------------------------
    // Instruction recognition. Each predicate pins a distinct opcode,
    // so at most one of them is true for any instruction word.
    // ------------------------------------------------------------
    wire is_add  = (opcode == OPCODE_RTYPE)  && (funct3 == 3'b000) && (funct7 == 7'b0000000);
    wire is_sub  = (opcode == OPCODE_RTYPE)  && (funct3 == 3'b000) && (funct7 == 7'b0100000);
    wire is_addi = (opcode == OPCODE_ITYPE)  && (funct3 == 3'b000);
    wire is_lw   = (opcode == OPCODE_LOAD)   && (funct3 == 3'b010);
    wire is_sw   = (opcode == OPCODE_STORE)  && (funct3 == 3'b010);
    wire is_beq  = (opcode == OPCODE_BRANCH) && (funct3 == 3'b000);
    wire is_blt  = (opcode == OPCODE_BRANCH) && (funct3 == 3'b100);

    // State entered after ST_DECODE; unsupported encodings halt the core.
    wire [4:0] decode_target =
        (is_add || is_sub) ? ST_EXE_R  :
        is_addi            ? ST_EXE_I  :
        (is_lw  || is_sw)  ? ST_ADDR   :
        (is_beq || is_blt) ? ST_BRANCH :
                             ST_DONE;

    // ------------------------------------------------------------
    // Combinational datapath
    // ------------------------------------------------------------
    // Full 32-bit sums; only the low 12 bits are kept in alu_addr.
    wire [31:0] mem_addr_i_full = reg_a + imm_i;   // lw: rs1 + imm_i
    wire [31:0] mem_addr_s_full = reg_a + imm_s;   // sw: rs1 + imm_s

    wire [31:0] rtype_result = is_sub ? (reg_a - reg_b) : (reg_a + reg_b);
    wire [31:0] itype_result = reg_a + imm_i;

    // Branch target is relative to the branch instruction itself.
    wire [31:0] branch_target = pc_old + imm_b;
    wire        branch_taken  = is_beq ? (reg_a == reg_b) :
                                is_blt ? ($signed(reg_a) < $signed(reg_b)) :
                                         1'b0;

    // Single register-file write port shared by add/sub, addi and lw.
    wire        rf_we    = (state == ST_EXE_R) ||
                           (state == ST_EXE_I) ||
                           (state == ST_MEM_CAPTURE);
    wire [31:0] rf_wdata = (state == ST_EXE_R) ? rtype_result :
                           (state == ST_EXE_I) ? itype_result :
                                                 dmem_rdata;

    // ------------------------------------------------------------
    // Sparse register file: read mux
    // ------------------------------------------------------------
    function [31:0] rf_read;
        input [4:0] idx;
        begin
            case (idx)
                5'd1:    rf_read = r1;
                5'd2:    rf_read = r2;
                5'd3:    rf_read = r3;
                5'd4:    rf_read = r4;
                5'd5:    rf_read = r5;
                5'd6:    rf_read = r6;
                5'd7:    rf_read = r7;
                5'd8:    rf_read = r8;
                5'd9:    rf_read = r9;
                5'd10:   rf_read = r10;
                5'd11:   rf_read = r11;
                5'd12:   rf_read = r12;
                5'd13:   rf_read = r13;
                5'd20:   rf_read = r20;
                5'd21:   rf_read = r21;
                default: rf_read = 32'd0;   // x0 and unimplemented registers
            endcase
        end
    endfunction

    // ------------------------------------------------------------
    // Sparse register file: write port
    // ------------------------------------------------------------
    always @(posedge CLK or negedge RSTN) begin
        if (!RSTN) begin
            r1  <= 32'd0;  r2  <= 32'd0;  r3  <= 32'd0;  r4  <= 32'd0;
            r5  <= 32'd0;  r6  <= 32'd0;  r7  <= 32'd0;  r8  <= 32'd0;
            r9  <= 32'd0;  r10 <= 32'd0;  r11 <= 32'd0;  r12 <= 32'd0;
            r13 <= 32'd0;  r20 <= 32'd0;  r21 <= 32'd0;
        end
        else if (rf_we) begin
            case (rd)
                5'd1:    r1  <= rf_wdata;
                5'd2:    r2  <= rf_wdata;
                5'd3:    r3  <= rf_wdata;
                5'd4:    r4  <= rf_wdata;
                5'd5:    r5  <= rf_wdata;
                5'd6:    r6  <= rf_wdata;
                5'd7:    r7  <= rf_wdata;
                5'd8:    r8  <= rf_wdata;
                5'd9:    r9  <= rf_wdata;
                5'd10:   r10 <= rf_wdata;
                5'd11:   r11 <= rf_wdata;
                5'd12:   r12 <= rf_wdata;
                5'd13:   r13 <= rf_wdata;
                5'd20:   r20 <= rf_wdata;
                5'd21:   r21 <= rf_wdata;
                default: begin end      // x0 and unimplemented registers
            endcase
        end
    end

    // ------------------------------------------------------------
    // Control FSM and memory-port drivers
    // ------------------------------------------------------------
    always @(posedge CLK or negedge RSTN) begin
        if (!RSTN) begin
            state      <= ST_FETCH_ADDR;
            pc         <= 32'd0;
            pc_old     <= 32'd0;
            ir         <= 32'd0;
            reg_a      <= 32'd0;
            reg_b      <= 32'd0;
            alu_addr   <= 12'd0;
            imem_addr  <= 10'd0;
            dmem_addr  <= 10'd0;
            dmem_we    <= 1'b0;
            dmem_wdata <= 32'd0;
        end
        else begin
            dmem_we <= 1'b0;   // store strobe defaults low: one-cycle pulse

            case (state)
                // ---- fetch: address, wait, capture ----
                ST_FETCH_ADDR: begin
                    pc_old    <= pc;
                    imem_addr <= pc[11:2];
                    state     <= ST_FETCH_WAIT1;
                end
                ST_FETCH_WAIT1:
                    state <= ST_FETCH_CAPTURE;
                ST_FETCH_CAPTURE: begin
                    ir    <= imem_rdata;
                    pc    <= pc + 32'd4;
                    state <= ST_DECODE;
                end

                // ---- decode: latch operands, dispatch ----
                ST_DECODE: begin
                    reg_a <= rf_read(rs1);
                    reg_b <= rf_read(rs2);
                    state <= decode_target;
                end

                // ---- register write-back states ----
                // The register update itself is done by the write-port block.
                ST_EXE_R, ST_EXE_I, ST_MEM_CAPTURE, ST_WB_LW:
                    state <= ST_FETCH_ADDR;

                // ---- lw / sw effective address ----
                ST_ADDR: begin
                    if (opcode == OPCODE_LOAD) begin
                        alu_addr <= mem_addr_i_full[11:0];
                        state    <= ST_MEM_RD;
                    end
                    else begin
                        alu_addr <= mem_addr_s_full[11:0];
                        state    <= ST_MEM_WR_SETUP;
                    end
                end

                // ---- lw: address, wait, then ST_MEM_CAPTURE ----
                ST_MEM_RD: begin
                    dmem_addr <= alu_addr[11:2];
                    state     <= ST_MEM_WAIT1;
                end
                ST_MEM_WAIT1:
                    state <= ST_MEM_CAPTURE;

                // ---- sw: drive address/data, then pulse the strobe ----
                ST_MEM_WR_SETUP: begin
                    dmem_addr  <= alu_addr[11:2];
                    dmem_wdata <= reg_b;
                    state      <= ST_MEM_WR_DO;
                end
                ST_MEM_WR_DO: begin
                    dmem_we <= 1'b1;
                    state   <= ST_FETCH_ADDR;
                end

                // ---- beq / blt ----
                ST_BRANCH: begin
                    if (branch_taken)
                        pc <= branch_target;
                    state <= ST_FETCH_ADDR;
                end

                // ---- stop ----
                ST_DONE:
                    state <= ST_HALT;
                ST_HALT:
                    state <= ST_HALT;

                default:
                    state <= ST_FETCH_ADDR;
            endcase
        end
    end
endmodule
// ============================================================
// CNN
//
// Two back-to-back 3x3 convolutions (no padding, stride 1) over 8-bit
// signed feature maps stored in a shared 32-bit-word memory:
//   layer 0: input (BASE_IN,  N x N)         * w0 + bias0 -> mid (BASE_MID)
//   layer 1: mid   (BASE_MID, (N-2) x (N-2)) * w1 + bias1 -> out (BASE_OUT)
// Each layer shrinks the map by 2 in both dimensions.
//
// Sequence:
//   1. Poll word ADDR_CNN_START until bit 0 is set.
//   2. Read N from ADDR_SIZE bits [6:1].
//      N outside [MIN_FMAP_SIZE, MAX_FMAP_SIZE] is rejected: the engine
//      goes straight to ST_DONE without writing memory and without
//      writing the status word.
//   3. Read bias0/bias1 (bits [7:0] of ADDR_BIAS0/ADDR_BIAS1) and the two
//      9-tap kernels (3 words each at BASE_W0 / BASE_W1).
//   4. Run layer 0, then layer 1.
//   5. Write {BASE_OUT, 1'b1} to ADDR_STATUS and hold `done` high.
//
// Byte packing: pixels and kernel taps are stored four per word, the first
// element in bits [31:24] (see get_byte / put_byte).
//
// Arithmetic: pixels, weights and biases are treated as Q6. Products and
// the accumulator are Q12, and each output is rounded back to Q6 with
// ties-to-even and saturation.
//
// Memory timing: each read drives `addr` one state before the state that
// samples `doutb` (one cycle of read latency). Writes drive addr/dinb first
// and raise `web` for one cycle in the following state.
// NOTE(review): assumes the memory port has exactly one cycle of read
// latency (no output register) -- confirm against the BRAM configuration.
// ============================================================
module CNN (
    input  wire        clk,
    input  wire        rstn,
    input  wire [31:0] doutb,
    output reg         web,
    output reg         enb,
    output reg  [31:0] dinb,
    output reg  [9:0]  addr,
    output reg         done
);
    // ============================================================
    // Fixed memory map (word addresses)
    // ============================================================
    localparam BASE_W0        = 10'd0;    // w0[0..8], 3 words
    localparam BASE_W1        = 10'd3;    // w1[0..8], 3 words
    localparam ADDR_SIZE      = 10'd12;   // N in bits [6:1]
    localparam ADDR_STATUS    = 10'd13;   // completion word
    localparam ADDR_BIAS0     = 10'd14;   // bias0 in bits [7:0]
    localparam ADDR_BIAS1     = 10'd15;   // bias1 in bits [7:0]
    localparam BASE_IN        = 10'd16;   // input map
    localparam BASE_MID       = 10'd272;  // layer-0 output
    localparam BASE_OUT       = 10'd512;  // layer-1 output
    localparam ADDR_CNN_START = 10'd900;  // start request in bit 0

    // Accepted input sizes.
    // MIN: layer 1 must produce at least a 1x1 map (N - 4 >= 1). Below
    //      this, the 6-bit size arithmetic wraps and the pixel loops run
    //      up to 64x64 with wrapping 10-bit addresses.
    // MAX: the input map (N*N bytes, 4 per word) must end before BASE_MID:
    //      N*N <= 4 * (BASE_MID - BASE_IN) = 1024  ->  N <= 32.
    localparam MIN_FMAP_SIZE  = 6'd5;
    localparam MAX_FMAP_SIZE  = 6'd32;

    // ============================================================
    // Helper functions
    // ============================================================
    // Extract byte `sel` of a word; sel 0 is the most significant byte.
    function [7:0] get_byte;
        input [31:0] word;
        input [1:0]  sel;
        begin
            case (sel)
                2'd0:    get_byte = word[31:24];
                2'd1:    get_byte = word[23:16];
                2'd2:    get_byte = word[15:8];
                2'd3:    get_byte = word[7:0];
                default: get_byte = 8'd0;
            endcase
        end
    endfunction

    // Return `old_word` with byte `sel` replaced (same ordering as get_byte).
    function [31:0] put_byte;
        input [31:0] old_word;
        input [1:0]  sel;
        input [7:0]  new_byte;
        begin
            put_byte = old_word;
            case (sel)
                2'd0:    put_byte[31:24] = new_byte;
                2'd1:    put_byte[23:16] = new_byte;
                2'd2:    put_byte[15:8]  = new_byte;
                2'd3:    put_byte[7:0]   = new_byte;
                default: put_byte = old_word;
            endcase
        end
    endfunction

    // ============================================================
    // Q12 -> Q6 round-to-nearest, ties-to-even, with saturation to int8.
    // Arithmetic shift gives floor(); the 6 discarded bits decide rounding.
    // ============================================================
    function signed [7:0] round_sat_q12_to_q6;
        input signed [19:0] in_q12;
        reg   signed [19:0] floor_q6;
        reg          [5:0]  frac;
        reg                 round_up;
        reg   signed [19:0] rounded_q6;
        begin
            floor_q6 = in_q12 >>> 6;
            frac     = in_q12[5:0];
            if (frac > 6'd32)
                round_up = 1'b1;              // above half
            else if (frac == 6'd32)
                round_up = floor_q6[0];       // exact half: round to even
            else
                round_up = 1'b0;              // below half
            rounded_q6 = floor_q6 + {{19{1'b0}}, round_up};
            if (rounded_q6 > 20'sd127)
                round_sat_q12_to_q6 = 8'sh7F;
            else if (rounded_q6 < -20'sd128)
                round_sat_q12_to_q6 = 8'sh80;
            else
                round_sat_q12_to_q6 = rounded_q6[7:0];
        end
    endfunction

    // ============================================================
    // FSM states
    // ============================================================
    reg [5:0] state;
    reg [5:0] after_wait_state;   // state resumed after ST_WAIT1

    localparam ST_IDLE               = 6'd0;
    localparam ST_WAIT1              = 6'd1;
    localparam ST_CAP_SIZE           = 6'd4;
    localparam ST_CAP_BIAS0          = 6'd5;
    localparam ST_CAP_BIAS1          = 6'd6;
    localparam ST_CAP_W0_0           = 6'd7;
    localparam ST_CAP_W0_1           = 6'd8;
    localparam ST_CAP_W0_2           = 6'd9;
    localparam ST_CAP_W1_0           = 6'd10;
    localparam ST_CAP_W1_1           = 6'd11;
    localparam ST_CAP_W1_2           = 6'd12;
    localparam ST_BEGIN_LAYER        = 6'd13;
    localparam ST_PRELOAD_INIT       = 6'd14;
    localparam ST_PRELOAD_REQ        = 6'd15;
    localparam ST_PRELOAD_WAIT       = 6'd16;
    localparam ST_PRELOAD_CAP        = 6'd17;
    localparam ST_MAC_ROW0           = 6'd18;
    localparam ST_MAC_ROW1           = 6'd19;
    localparam ST_MAC_ROW2           = 6'd20;
    localparam ST_NEXT_CAP1          = 6'd21;
    localparam ST_NEXT_CAP2          = 6'd22; // unreachable fallback
    localparam ST_ADD_BIAS           = 6'd23;
    localparam ST_ROUND              = 6'd24;
    localparam ST_PREP_STORE         = 6'd25; // unreachable fallback
    localparam ST_STORE_PACK         = 6'd26;
    localparam ST_WRITE_DO           = 6'd27;
    localparam ST_ADV_PIXEL          = 6'd28;
    localparam ST_WRITE_STATUS_SETUP = 6'd29;
    localparam ST_WRITE_STATUS_DO    = 6'd30;
    localparam ST_DONE               = 6'd31;
    localparam ST_CHECK_START        = 6'd32;

    // ============================================================
    // Configuration and per-layer registers
    // ============================================================
    reg [5:0] fmap_size;          // N
    reg [5:0] mid_size;           // N - 2
    reg [5:0] out_size;           // N - 4
    reg [5:0] run_in_size;        // input width of the current layer
    reg [5:0] run_out_size;       // output width of the current layer
    reg [9:0] run_in_base;
    reg [9:0] run_out_base;

    reg signed [7:0] bias0;
    reg signed [7:0] bias1;
    reg signed [7:0] w0 [0:8];    // row-major 3x3 taps, layer 0
    reg signed [7:0] w1 [0:8];    // row-major 3x3 taps, layer 1

    reg       layer_id;           // 0 = input->mid, 1 = mid->out
    reg [5:0] ox;                 // output column
    reg [5:0] oy;                 // output row

    reg signed [19:0] acc_q12;    // Q12 accumulator
    reg signed [7:0]  conv_out_q6;
    reg        [11:0] out_linear_idx;  // byte index into the output map
    reg        [31:0] pack_word;       // output bytes collected for one word

    // ============================================================
    // Registered index terms (keep multiplies off the address path)
    // ============================================================
    reg [11:0] row_base_idx_q;       // oy * run_in_size, built by repeated add
    reg [11:0] run_in_size_ext_q;    // run_in_size
    reg [11:0] run_in_size_x2_q;     // 2 * run_in_size
    reg [11:0] preload_row1_idx_q;   // row_base + run_in_size
    reg [11:0] preload_row2_idx_q;   // row_base + 2 * run_in_size

    // ============================================================
    // 3x3 sliding window
    // ============================================================
    reg signed [7:0] win00, win01, win02;
    reg signed [7:0] win10, win11, win12;
    reg signed [7:0] win20, win21, win22;
    reg signed [7:0] next_col0;      // prefetched column for ox + 1
    reg signed [7:0] next_col1;
    reg signed [7:0] next_col2;

    reg [3:0]  preload_cnt;          // 0..8, row-major window position
    reg [1:0]  preload_row_q;
    reg [1:0]  preload_col_q;
    reg [1:0]  preload_byte_sel_q;
    reg [11:0] preload_current_idx_q;
    reg [9:0]  preload_addr_q;
    reg [1:0]  preload_bsel_q;
    reg [11:0] pixel_base_idx_q;     // linear index of window top-left

    reg [1:0] next_byte_sel0;
    reg [1:0] next_byte_sel1;
    reg [1:0] next_byte_sel2;

    reg [1:0]         mac_row_idx;   // window row fed to the multipliers
    reg signed [15:0] mul0_q;
    reg signed [15:0] mul1_q;
    reg signed [15:0] mul2_q;

    integer i;

    // ============================================================
    // Preload position decode: cnt -> (row, col) of the window
    // ============================================================
    function [1:0] preload_row_func;
        input [3:0] cnt;
        begin
            case (cnt)
                4'd0, 4'd1, 4'd2: preload_row_func = 2'd0;
                4'd3, 4'd4, 4'd5: preload_row_func = 2'd1;
                default:          preload_row_func = 2'd2;
            endcase
        end
    endfunction

    function [1:0] preload_col_func;
        input [3:0] cnt;
        begin
            case (cnt)
                4'd0, 4'd3, 4'd6: preload_col_func = 2'd0;
                4'd1, 4'd4, 4'd7: preload_col_func = 2'd1;
                default:          preload_col_func = 2'd2;
            endcase
        end
    endfunction

    wire [1:0] preload_row_now = preload_row_func(preload_cnt);
    wire [1:0] preload_col_now = preload_col_func(preload_cnt);

    // ============================================================
    // Preload addressing
    // After the 3rd / 6th byte jump to the next image row, otherwise +1.
    // ============================================================
    wire [9:0] preload_base_addr_now =
        run_in_base + row_base_idx_q[11:2];
    wire [1:0] preload_base_bsel_now =
        row_base_idx_q[1:0];

    wire [11:0] preload_idx_plus1 =
        preload_current_idx_q + 12'd1;
    wire [11:0] preload_next_idx_calc =
        (preload_cnt == 4'd2) ? preload_row1_idx_q :
        (preload_cnt == 4'd5) ? preload_row2_idx_q :
                                preload_idx_plus1;
    wire [9:0] preload_next_addr_calc =
        run_in_base + preload_next_idx_calc[11:2];
    wire [1:0] preload_next_bsel_calc =
        preload_next_idx_calc[1:0];

    // ============================================================
    // Next-column prefetch addressing (column ox + 3 of each window row)
    // ============================================================
    wire has_next_x =
        (ox < (run_out_size - 6'd1));

    wire [11:0] next_linear0 = pixel_base_idx_q + 12'd3;
    wire [11:0] next_linear1 = pixel_base_idx_q + run_in_size_ext_q + 12'd3;
    wire [11:0] next_linear2 = pixel_base_idx_q + run_in_size_x2_q  + 12'd3;

    wire [9:0] next_addr0 = run_in_base + next_linear0[11:2];
    wire [9:0] next_addr1 = run_in_base + next_linear1[11:2];
    wire [9:0] next_addr2 = run_in_base + next_linear2[11:2];

    wire [1:0] next_bsel0 = next_linear0[1:0];
    wire [1:0] next_bsel1 = next_linear1[1:0];
    wire [1:0] next_bsel2 = next_linear2[1:0];

    // ============================================================
    // One window row per cycle through three multipliers
    // ============================================================
    wire signed [7:0] mac_pix0 =
        (mac_row_idx == 2'd0) ? win00 :
        (mac_row_idx == 2'd1) ? win10 :
                                win20;
    wire signed [7:0] mac_pix1 =
        (mac_row_idx == 2'd0) ? win01 :
        (mac_row_idx == 2'd1) ? win11 :
                                win21;
    wire signed [7:0] mac_pix2 =
        (mac_row_idx == 2'd0) ? win02 :
        (mac_row_idx == 2'd1) ? win12 :
                                win22;

    wire signed [7:0] mac_w0 =
        (layer_id == 1'b0) ?
            ((mac_row_idx == 2'd0) ? w0[0] :
             (mac_row_idx == 2'd1) ? w0[3] : w0[6]) :
            ((mac_row_idx == 2'd0) ? w1[0] :
             (mac_row_idx == 2'd1) ? w1[3] : w1[6]);
    wire signed [7:0] mac_w1 =
        (layer_id == 1'b0) ?
            ((mac_row_idx == 2'd0) ? w0[1] :
             (mac_row_idx == 2'd1) ? w0[4] : w0[7]) :
            ((mac_row_idx == 2'd0) ? w1[1] :
             (mac_row_idx == 2'd1) ? w1[4] : w1[7]);
    wire signed [7:0] mac_w2 =
        (layer_id == 1'b0) ?
            ((mac_row_idx == 2'd0) ? w0[2] :
             (mac_row_idx == 2'd1) ? w0[5] : w0[8]) :
            ((mac_row_idx == 2'd0) ? w1[2] :
             (mac_row_idx == 2'd1) ? w1[5] : w1[8]);

    (* use_dsp = "yes" *) wire signed [15:0] mul0 = mac_pix0 * mac_w0;
    (* use_dsp = "yes" *) wire signed [15:0] mul1 = mac_pix1 * mac_w1;
    (* use_dsp = "yes" *) wire signed [15:0] mul2 = mac_pix2 * mac_w2;

    // Registered products, sign-extended to accumulator width.
    wire signed [19:0] mul0_ext = {{4{mul0_q[15]}}, mul0_q};
    wire signed [19:0] mul1_ext = {{4{mul1_q[15]}}, mul1_q};
    wire signed [19:0] mul2_ext = {{4{mul2_q[15]}}, mul2_q};
    wire signed [19:0] row_sum_q12 = mul0_ext + mul1_ext + mul2_ext;

    // Q6 biases aligned to Q12.
    wire signed [19:0] bias0_q12 = {{6{bias0[7]}}, bias0, 6'b000000};
    wire signed [19:0] bias1_q12 = {{6{bias1[7]}}, bias1, 6'b000000};

    // ============================================================
    // Output packing: four bytes per word; a partial word is flushed
    // after the last pixel of a layer.
    // ============================================================
    wire [1:0] out_byte_sel  = out_linear_idx[1:0];
    wire [9:0] out_word_addr = run_out_base + out_linear_idx[11:2];
    wire is_last_pixel =
        (ox == (run_out_size - 6'd1)) &&
        (oy == (run_out_size - 6'd1));
    wire should_write_word =
        (out_byte_sel == 2'd3) || is_last_pixel;
    wire [31:0] pack_word_next =
        put_byte(pack_word, out_byte_sel, conv_out_q6);

    // ============================================================
    // Main FSM
    // ============================================================
    always @(posedge clk or negedge rstn) begin
        if (!rstn) begin
            state            <= ST_IDLE;
            after_wait_state <= ST_IDLE;
            web              <= 1'b0;
            enb              <= 1'b1;
            dinb             <= 32'd0;
            addr             <= 10'd0;
            done             <= 1'b0;
            fmap_size        <= 6'd0;
            mid_size         <= 6'd0;
            out_size         <= 6'd0;
            run_in_size      <= 6'd0;
            run_out_size     <= 6'd0;
            run_in_base      <= 10'd0;
            run_out_base     <= 10'd0;
            bias0            <= 8'sd0;
            bias1            <= 8'sd0;
            for (i = 0; i < 9; i = i + 1) begin
                w0[i] <= 8'sd0;
                w1[i] <= 8'sd0;
            end
            layer_id         <= 1'b0;
            ox               <= 6'd0;
            oy               <= 6'd0;
            acc_q12          <= 20'sd0;
            conv_out_q6      <= 8'sd0;
            out_linear_idx   <= 12'd0;
            pack_word        <= 32'd0;
            row_base_idx_q     <= 12'd0;
            run_in_size_ext_q  <= 12'd0;
            run_in_size_x2_q   <= 12'd0;
            preload_row1_idx_q <= 12'd0;
            preload_row2_idx_q <= 12'd0;
            win00 <= 8'sd0; win01 <= 8'sd0; win02 <= 8'sd0;
            win10 <= 8'sd0; win11 <= 8'sd0; win12 <= 8'sd0;
            win20 <= 8'sd0; win21 <= 8'sd0; win22 <= 8'sd0;
            next_col0 <= 8'sd0;
            next_col1 <= 8'sd0;
            next_col2 <= 8'sd0;
            preload_cnt           <= 4'd0;
            preload_row_q         <= 2'd0;
            preload_col_q         <= 2'd0;
            preload_byte_sel_q    <= 2'd0;
            preload_current_idx_q <= 12'd0;
            preload_addr_q        <= 10'd0;
            preload_bsel_q        <= 2'd0;
            pixel_base_idx_q      <= 12'd0;
            next_byte_sel0 <= 2'd0;
            next_byte_sel1 <= 2'd0;
            next_byte_sel2 <= 2'd0;
            mac_row_idx    <= 2'd0;
            mul0_q <= 16'sd0;
            mul1_q <= 16'sd0;
            mul2_q <= 16'sd0;
        end
        else begin
            web <= 1'b0;   // write strobe is a one-cycle pulse
            enb <= 1'b1;
            case (state)
                // ------------------------------------------------
                // Wait for the start request
                // ------------------------------------------------
                ST_IDLE: begin
                    done             <= 1'b0;
                    addr             <= ADDR_CNN_START;
                    after_wait_state <= ST_CHECK_START;
                    state            <= ST_WAIT1;
                end
                ST_CHECK_START: begin
                    if (doutb[0] == 1'b1) begin
                        addr             <= ADDR_SIZE;
                        after_wait_state <= ST_CAP_SIZE;
                        state            <= ST_WAIT1;
                    end
                    else begin
                        addr             <= ADDR_CNN_START;
                        after_wait_state <= ST_CHECK_START;
                        state            <= ST_WAIT1;
                    end
                end
                // One-cycle gap so doutb reflects the address driven
                // in the previous state.
                ST_WAIT1: begin
                    state <= after_wait_state;
                end

                // ------------------------------------------------
                // Configuration loads
                // ------------------------------------------------
                ST_CAP_SIZE: begin
                    fmap_size <= doutb[6:1];
                    mid_size  <= doutb[6:1] - 6'd2;
                    out_size  <= doutb[6:1] - 6'd4;
                    if ((doutb[6:1] < MIN_FMAP_SIZE) ||
                        (doutb[6:1] > MAX_FMAP_SIZE)) begin
                        // Unsupported size: stop without touching memory
                        // (no convolution output, no status word).
                        state <= ST_DONE;
                    end
                    else begin
                        addr             <= ADDR_BIAS0;
                        after_wait_state <= ST_CAP_BIAS0;
                        state            <= ST_WAIT1;
                    end
                end
                ST_CAP_BIAS0: begin
                    bias0            <= doutb[7:0];
                    addr             <= ADDR_BIAS1;
                    after_wait_state <= ST_CAP_BIAS1;
                    state            <= ST_WAIT1;
                end
                ST_CAP_BIAS1: begin
                    bias1            <= doutb[7:0];
                    addr             <= BASE_W0;
                    after_wait_state <= ST_CAP_W0_0;
                    state            <= ST_WAIT1;
                end
                ST_CAP_W0_0: begin
                    w0[0] <= get_byte(doutb, 2'd0);
                    w0[1] <= get_byte(doutb, 2'd1);
                    w0[2] <= get_byte(doutb, 2'd2);
                    w0[3] <= get_byte(doutb, 2'd3);
                    addr             <= BASE_W0 + 10'd1;
                    after_wait_state <= ST_CAP_W0_1;
                    state            <= ST_WAIT1;
                end
                ST_CAP_W0_1: begin
                    w0[4] <= get_byte(doutb, 2'd0);
                    w0[5] <= get_byte(doutb, 2'd1);
                    w0[6] <= get_byte(doutb, 2'd2);
                    w0[7] <= get_byte(doutb, 2'd3);
                    addr             <= BASE_W0 + 10'd2;
                    after_wait_state <= ST_CAP_W0_2;
                    state            <= ST_WAIT1;
                end
                ST_CAP_W0_2: begin
                    w0[8]            <= get_byte(doutb, 2'd0);
                    addr             <= BASE_W1;
                    after_wait_state <= ST_CAP_W1_0;
                    state            <= ST_WAIT1;
                end
                ST_CAP_W1_0: begin
                    w1[0] <= get_byte(doutb, 2'd0);
                    w1[1] <= get_byte(doutb, 2'd1);
                    w1[2] <= get_byte(doutb, 2'd2);
                    w1[3] <= get_byte(doutb, 2'd3);
                    addr             <= BASE_W1 + 10'd1;
                    after_wait_state <= ST_CAP_W1_1;
                    state            <= ST_WAIT1;
                end
                ST_CAP_W1_1: begin
                    w1[4] <= get_byte(doutb, 2'd0);
                    w1[5] <= get_byte(doutb, 2'd1);
                    w1[6] <= get_byte(doutb, 2'd2);
                    w1[7] <= get_byte(doutb, 2'd3);
                    addr             <= BASE_W1 + 10'd2;
                    after_wait_state <= ST_CAP_W1_2;
                    state            <= ST_WAIT1;
                end
                ST_CAP_W1_2: begin
                    w1[8]    <= get_byte(doutb, 2'd0);
                    layer_id <= 1'b0;
                    state    <= ST_BEGIN_LAYER;
                end

                // ------------------------------------------------
                // Per-layer setup
                // ------------------------------------------------
                ST_BEGIN_LAYER: begin
                    ox             <= 6'd0;
                    oy             <= 6'd0;
                    pack_word      <= 32'd0;
                    out_linear_idx <= 12'd0;
                    row_base_idx_q <= 12'd0;
                    if (layer_id == 1'b0) begin
                        run_in_size       <= fmap_size;
                        run_out_size      <= mid_size;
                        run_in_base       <= BASE_IN;
                        run_out_base      <= BASE_MID;
                        run_in_size_ext_q <= {6'd0, fmap_size};
                        run_in_size_x2_q  <= ({6'd0, fmap_size} << 1);
                    end
                    else begin
                        run_in_size       <= mid_size;
                        run_out_size      <= out_size;
                        run_in_base       <= BASE_MID;
                        run_out_base      <= BASE_OUT;
                        run_in_size_ext_q <= {6'd0, mid_size};
                        run_in_size_x2_q  <= ({6'd0, mid_size} << 1);
                    end
                    state <= ST_PRELOAD_INIT;
                end

                // ------------------------------------------------
                // Full 3x3 window load at the start of each output row
                // (9 reads: REQ -> WAIT -> CAP each)
                // ------------------------------------------------
                ST_PRELOAD_INIT: begin
                    preload_cnt           <= 4'd0;
                    preload_current_idx_q <= row_base_idx_q;
                    pixel_base_idx_q      <= row_base_idx_q;
                    preload_row1_idx_q    <= row_base_idx_q + run_in_size_ext_q;
                    preload_row2_idx_q    <= row_base_idx_q + run_in_size_x2_q;
                    preload_addr_q        <= preload_base_addr_now;
                    preload_bsel_q        <= preload_base_bsel_now;
                    state                 <= ST_PRELOAD_REQ;
                end
                ST_PRELOAD_REQ: begin
                    addr               <= preload_addr_q;
                    preload_row_q      <= preload_row_now;
                    preload_col_q      <= preload_col_now;
                    preload_byte_sel_q <= preload_bsel_q;
                    state              <= ST_PRELOAD_WAIT;
                end
                ST_PRELOAD_WAIT: begin
                    state <= ST_PRELOAD_CAP;
                end
                ST_PRELOAD_CAP: begin
                    case ({preload_row_q, preload_col_q})
                        4'b00_00: win00 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b00_01: win01 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b00_10: win02 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b01_00: win10 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b01_01: win11 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b01_10: win12 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b10_00: win20 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b10_01: win21 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        4'b10_10: win22 <= $signed(get_byte(doutb, preload_byte_sel_q));
                        default:  begin end
                    endcase
                    if (preload_cnt == 4'd8) begin
                        acc_q12     <= 20'sd0;
                        mac_row_idx <= 2'd0;
                        state       <= ST_MAC_ROW0;
                    end
                    else begin
                        preload_cnt           <= preload_cnt + 4'd1;
                        preload_current_idx_q <= preload_next_idx_calc;
                        preload_addr_q        <= preload_next_addr_calc;
                        preload_bsel_q        <= preload_next_bsel_calc;
                        state                 <= ST_PRELOAD_REQ;
                    end
                end

                // ------------------------------------------------
                // MAC pipeline: products of row r are registered in
                // ST_MAC_ROWr and summed one state later. The next
                // column's three bytes are fetched in the shadow of the
                // MAC (addresses issued in ROW0/ROW1/ROW2, data captured
                // in ROW2/NEXT_CAP1/ADD_BIAS).
                // ------------------------------------------------
                ST_MAC_ROW0: begin
                    mul0_q <= mul0;
                    mul1_q <= mul1;
                    mul2_q <= mul2;
                    if (has_next_x) begin
                        addr           <= next_addr0;
                        next_byte_sel0 <= next_bsel0;
                    end
                    mac_row_idx <= 2'd1;
                    state       <= ST_MAC_ROW1;
                end
                ST_MAC_ROW1: begin
                    acc_q12 <= row_sum_q12;                 // row 0
                    mul0_q  <= mul0;
                    mul1_q  <= mul1;
                    mul2_q  <= mul2;
                    if (has_next_x) begin
                        addr           <= next_addr1;
                        next_byte_sel1 <= next_bsel1;
                    end
                    mac_row_idx <= 2'd2;
                    state       <= ST_MAC_ROW2;
                end
                ST_MAC_ROW2: begin
                    acc_q12 <= acc_q12 + row_sum_q12;       // row 1
                    mul0_q  <= mul0;
                    mul1_q  <= mul1;
                    mul2_q  <= mul2;
                    if (has_next_x) begin
                        next_col0      <= $signed(get_byte(doutb, next_byte_sel0));
                        addr           <= next_addr2;
                        next_byte_sel2 <= next_bsel2;
                    end
                    state <= ST_NEXT_CAP1;
                end
                ST_NEXT_CAP1: begin
                    acc_q12 <= acc_q12 + row_sum_q12;       // row 2
                    if (has_next_x) begin
                        next_col1 <= $signed(get_byte(doutb, next_byte_sel1));
                    end
                    state <= ST_ADD_BIAS;
                end
                ST_NEXT_CAP2: begin
                    state <= ST_ADD_BIAS;
                end
                ST_ADD_BIAS: begin
                    if (has_next_x) begin
                        next_col2 <= $signed(get_byte(doutb, next_byte_sel2));
                    end
                    if (layer_id == 1'b0)
                        acc_q12 <= acc_q12 + bias0_q12;
                    else
                        acc_q12 <= acc_q12 + bias1_q12;
                    state <= ST_ROUND;
                end
                ST_ROUND: begin
                    conv_out_q6 <= round_sat_q12_to_q6(acc_q12);
                    state       <= ST_STORE_PACK;
                end
                ST_PREP_STORE: begin
                    state <= ST_STORE_PACK;
                end

                // ------------------------------------------------
                // Output packing / write
                // ------------------------------------------------
                ST_STORE_PACK: begin
                    pack_word <= pack_word_next;
                    if (should_write_word) begin
                        addr  <= out_word_addr;
                        dinb  <= pack_word_next;
                        web   <= 1'b0;
                        state <= ST_WRITE_DO;
                    end
                    else begin
                        state <= ST_ADV_PIXEL;
                    end
                end
                ST_WRITE_DO: begin
                    // web is high during the next state (ST_ADV_PIXEL),
                    // which leaves addr/dinb unchanged.
                    web       <= 1'b1;
                    pack_word <= 32'd0;
                    state     <= ST_ADV_PIXEL;
                end

                // ------------------------------------------------
                // Advance to the next output pixel / row / layer
                // ------------------------------------------------
                ST_ADV_PIXEL: begin
                    if (!is_last_pixel) begin
                        out_linear_idx <= out_linear_idx + 12'd1;
                    end
                    if (is_last_pixel) begin
                        if (layer_id == 1'b0) begin
                            layer_id <= 1'b1;
                            state    <= ST_BEGIN_LAYER;
                        end
                        else begin
                            state <= ST_WRITE_STATUS_SETUP;
                        end
                    end
                    else begin
                        if (ox == (run_out_size - 6'd1)) begin
                            // End of row: reload a full window one row down.
                            ox             <= 6'd0;
                            oy             <= oy + 6'd1;
                            row_base_idx_q <= row_base_idx_q + run_in_size_ext_q;
                            state          <= ST_PRELOAD_INIT;
                        end
                        else begin
                            // Slide the window right by one prefetched column.
                            win00 <= win01;
                            win01 <= win02;
                            win02 <= next_col0;
                            win10 <= win11;
                            win11 <= win12;
                            win12 <= next_col1;
                            win20 <= win21;
                            win21 <= win22;
                            win22 <= next_col2;
                            ox               <= ox + 6'd1;
                            pixel_base_idx_q <= pixel_base_idx_q + 12'd1;
                            acc_q12          <= 20'sd0;
                            mac_row_idx      <= 2'd0;
                            state            <= ST_MAC_ROW0;
                        end
                    end
                end

                // ------------------------------------------------
                // Completion
                // ------------------------------------------------
                ST_WRITE_STATUS_SETUP: begin
                    addr  <= ADDR_STATUS;
                    dinb  <= {21'd0, BASE_OUT, 1'b1};
                    web   <= 1'b0;
                    state <= ST_WRITE_STATUS_DO;
                end
                ST_WRITE_STATUS_DO: begin
                    web   <= 1'b1;
                    state <= ST_DONE;
                end
                ST_DONE: begin
                    web   <= 1'b0;
                    done  <= 1'b1;
                    state <= ST_DONE;
                end
                default: begin
                    state <= ST_IDLE;
                end
            endcase
        end
    end
endmodule
Editor is loading...
Leave a Comment