// NOTE: paste-site header residue (title "Untitled", type plain_text, 41 kB) - not part of the Verilog source.
// ============================================================
// Simple_CPU
//
// Multi-cycle processor core driven by a single FSM. It decodes the
// opcodes listed below (add, sub, addi, lw, sw, beq, blt). Any other
// instruction sends the FSM to ST_DONE and then ST_HALT, where it stays
// until reset.
//
// Ports
//   CLK        : clock; all state is updated on the rising edge
//   RSTN       : asynchronous reset, active low
//   dmem_en    : data-memory enable, tied high
//   dmem_we    : data-memory write enable, high for one cycle per sw
//   dmem_addr  : data-memory word address (byte address bits [11:2])
//   dmem_wdata : data-memory write data
//   dmem_rdata : data-memory read data
//
// Register file
//   Only NUM_REGS registers (x0 .. x21) are implemented to save area.
//   Instructions naming x22 .. x31 index outside the array: per Verilog
//   semantics such reads return X and such writes are dropped in
//   simulation. Programs run on this core must stay within x0 .. x21.
// ============================================================
module Simple_CPU (
    input  wire        CLK,
    input  wire        RSTN,
    output wire        dmem_en,
    output reg         dmem_we,
    output reg  [9:0]  dmem_addr,
    output reg  [31:0] dmem_wdata,
    input  wire [31:0] dmem_rdata
);

    assign dmem_en = 1'b1;

    // ============================================================
    // Instruction Memory
    // ============================================================
    reg  [9:0]  imem_addr;
    wire [31:0] imem_rdata;

    Instruction_Memory u_Instruction_Memory (
        .clka  (CLK),
        .ena   (1'b1),
        .addra (imem_addr),
        .douta (imem_rdata)
    );

    // ============================================================
    // Opcode
    // ============================================================
    localparam OPCODE_RTYPE  = 7'b0110011; // add, sub
    localparam OPCODE_ITYPE  = 7'b0010011; // addi
    localparam OPCODE_LOAD   = 7'b0000011; // lw
    localparam OPCODE_STORE  = 7'b0100011; // sw
    localparam OPCODE_BRANCH = 7'b1100011; // beq, blt

    // ============================================================
    // FSM states
    // Removed unused ST_WB_R / ST_WB_I for area cleanup
    // ============================================================
    localparam ST_FETCH_ADDR    = 5'd0;
    localparam ST_FETCH_WAIT1   = 5'd1;
    localparam ST_FETCH_CAPTURE = 5'd2;
    localparam ST_DECODE        = 5'd3;
    localparam ST_EXE_R         = 5'd4;
    localparam ST_EXE_I         = 5'd5;
    localparam ST_ADDR          = 5'd6;
    localparam ST_MEM_RD        = 5'd7;
    localparam ST_MEM_WAIT1     = 5'd8;
    localparam ST_MEM_CAPTURE   = 5'd9;
    localparam ST_WB_LW         = 5'd10;
    localparam ST_MEM_WR_SETUP  = 5'd11;
    localparam ST_MEM_WR_DO     = 5'd12;
    localparam ST_BRANCH        = 5'd13;
    localparam ST_DONE          = 5'd14;
    localparam ST_HALT          = 5'd15;

    reg [4:0] state;

    // ============================================================
    // CPU internal registers
    // ============================================================
    reg [31:0] pc;      // address of the next instruction to fetch
    reg [31:0] pc_old;  // address of the instruction being executed (branch base)
    reg [31:0] ir;      // instruction register
    reg [31:0] reg_a;   // rs1 operand latched in ST_DECODE
    reg [31:0] reg_b;   // rs2 operand latched in ST_DECODE

    // Area optimization:
    // Original version used 32-bit alu_out.
    // Now lw/sw only need byte address [11:0], because dmem_addr uses [11:2].
    reg [11:0] alu_addr;
    reg [31:0] mdr;     // memory data register for lw

    // ============================================================
    // Register file
    // NUM_REGS sizes both the array and the reset loop so the two can
    // never disagree (the reset loop previously ran to 32 and wrote
    // past the end of the 22-entry array).
    // ============================================================
    localparam NUM_REGS = 22;

    reg [31:0] regs [0:NUM_REGS-1];
    integer i;

    // ============================================================
    // Instruction fields
    // ============================================================
    wire [6:0] opcode = ir[6:0];
    wire [4:0] rd     = ir[11:7];
    wire [2:0] funct3 = ir[14:12];
    wire [4:0] rs1    = ir[19:15];
    wire [4:0] rs2    = ir[24:20];
    wire [6:0] funct7 = ir[31:25];

    // ============================================================
    // Immediate generator (sign-extended I / S / B immediates)
    // ============================================================
    wire signed [31:0] imm_i;
    wire signed [31:0] imm_s;
    wire signed [31:0] imm_b;

    assign imm_i = {{20{ir[31]}}, ir[31:20]};
    assign imm_s = {{20{ir[31]}}, ir[31:25], ir[11:7]};
    assign imm_b = {{19{ir[31]}}, ir[31], ir[7],
                    ir[30:25], ir[11:8], 1'b0};

    // ============================================================
    // Instruction decode helper
    // ============================================================
    wire is_add  = (opcode == OPCODE_RTYPE) &&
                   (funct3 == 3'b000) &&
                   (funct7 == 7'b0000000);
    wire is_sub  = (opcode == OPCODE_RTYPE) &&
                   (funct3 == 3'b000) &&
                   (funct7 == 7'b0100000);
    wire is_addi = (opcode == OPCODE_ITYPE) &&
                   (funct3 == 3'b000);
    wire is_lw   = (opcode == OPCODE_LOAD) &&
                   (funct3 == 3'b010);
    wire is_sw   = (opcode == OPCODE_STORE) &&
                   (funct3 == 3'b010);
    wire is_beq  = (opcode == OPCODE_BRANCH) &&
                   (funct3 == 3'b000);
    wire is_blt  = (opcode == OPCODE_BRANCH) &&
                   (funct3 == 3'b100);

    // Full address calculation wires.
    // Do not write (reg_a + imm_i)[11:0], because some Verilog parsers dislike slicing expressions.
    wire [31:0] mem_addr_i_full = reg_a + imm_i;
    wire [31:0] mem_addr_s_full = reg_a + imm_s;

    // ============================================================
    // Main FSM
    // ============================================================
    always @(posedge CLK or negedge RSTN) begin
        if (!RSTN) begin
            state      <= ST_FETCH_ADDR;
            pc         <= 32'd0;
            pc_old     <= 32'd0;
            ir         <= 32'd0;
            reg_a      <= 32'd0;
            reg_b      <= 32'd0;
            alu_addr   <= 12'd0;
            mdr        <= 32'd0;
            imem_addr  <= 10'd0;
            dmem_addr  <= 10'd0;
            dmem_we    <= 1'b0;
            dmem_wdata <= 32'd0;
            // Clear exactly the registers that exist.
            for (i = 0; i < NUM_REGS; i = i + 1) begin
                regs[i] <= 32'd0;
            end
        end
        else begin
            // Defaults; a later non-blocking assignment in the case
            // statement overrides these for the same cycle.
            dmem_we <= 1'b0;
            regs[0] <= 32'd0;

            case (state)
                // ====================================================
                // Instruction fetch
                // Keep instruction memory latency unchanged
                // ====================================================
                ST_FETCH_ADDR: begin
                    pc_old    <= pc;
                    imem_addr <= pc[11:2];
                    state     <= ST_FETCH_WAIT1;
                end

                ST_FETCH_WAIT1: begin
                    state <= ST_FETCH_CAPTURE;
                end

                ST_FETCH_CAPTURE: begin
                    ir    <= imem_rdata;
                    pc    <= pc + 32'd4;
                    state <= ST_DECODE;
                end

                // ====================================================
                // Decode
                // Latch both source operands and pick the execute state.
                // Unsupported encodings go to ST_DONE (halt).
                // ====================================================
                ST_DECODE: begin
                    reg_a <= (rs1 == 5'd0) ? 32'd0 : regs[rs1];
                    reg_b <= (rs2 == 5'd0) ? 32'd0 : regs[rs2];

                    case (opcode)
                        OPCODE_RTYPE: begin
                            if (is_add || is_sub)
                                state <= ST_EXE_R;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_ITYPE: begin
                            if (is_addi)
                                state <= ST_EXE_I;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_LOAD: begin
                            if (is_lw)
                                state <= ST_ADDR;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_STORE: begin
                            if (is_sw)
                                state <= ST_ADDR;
                            else
                                state <= ST_DONE;
                        end

                        OPCODE_BRANCH: begin
                            if (is_beq || is_blt)
                                state <= ST_BRANCH;
                            else
                                state <= ST_DONE;
                        end

                        default: begin
                            state <= ST_DONE;
                        end
                    endcase
                end

                // ====================================================
                // R-type: add / sub
                // EXE + WB merged, same as your current faster version
                // ====================================================
                ST_EXE_R: begin
                    if (rd != 5'd0) begin
                        if (is_sub)
                            regs[rd] <= reg_a - reg_b;
                        else
                            regs[rd] <= reg_a + reg_b;
                    end
                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // I-type: addi
                // EXE + WB merged, same as your current faster version
                // ====================================================
                ST_EXE_I: begin
                    if (rd != 5'd0)
                        regs[rd] <= reg_a + imm_i;
                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // lw / sw address calculation
                // Area optimization:
                // 32-bit alu_out -> 12-bit alu_addr
                // ====================================================
                ST_ADDR: begin
                    if (opcode == OPCODE_LOAD) begin
                        alu_addr <= mem_addr_i_full[11:0];
                        state    <= ST_MEM_RD;
                    end
                    else begin
                        alu_addr <= mem_addr_s_full[11:0];
                        state    <= ST_MEM_WR_SETUP;
                    end
                end

                // ====================================================
                // lw
                // Keep BRAM latency behavior unchanged
                // ====================================================
                ST_MEM_RD: begin
                    dmem_addr <= alu_addr[11:2];
                    state     <= ST_MEM_WAIT1;
                end

                ST_MEM_WAIT1: begin
                    state <= ST_MEM_CAPTURE;
                end

                ST_MEM_CAPTURE: begin
                    mdr   <= dmem_rdata;
                    state <= ST_WB_LW;
                end

                ST_WB_LW: begin
                    if (rd != 5'd0)
                        regs[rd] <= mdr;
                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // sw
                // Keep write setup / write enable alignment unchanged
                // ====================================================
                ST_MEM_WR_SETUP: begin
                    dmem_addr  <= alu_addr[11:2];
                    dmem_wdata <= reg_b;
                    dmem_we    <= 1'b0;
                    state      <= ST_MEM_WR_DO;
                end

                ST_MEM_WR_DO: begin
                    dmem_we <= 1'b1;
                    state   <= ST_FETCH_ADDR;
                end

                // ====================================================
                // Branch: beq / blt
                // Target is relative to pc_old (the branch's own address);
                // pc already holds pc_old + 4 for the not-taken case.
                // ====================================================
                ST_BRANCH: begin
                    if (is_beq) begin
                        if (reg_a == reg_b)
                            pc <= pc_old + imm_b;
                    end
                    else if (is_blt) begin
                        if ($signed(reg_a) < $signed(reg_b))
                            pc <= pc_old + imm_b;
                    end
                    state <= ST_FETCH_ADDR;
                end

                // ====================================================
                // Program done
                // ====================================================
                ST_DONE: begin
                    dmem_we <= 1'b0;
                    state   <= ST_HALT;
                end

                ST_HALT: begin
                    dmem_we <= 1'b0;
                    state   <= ST_HALT;
                end

                default: begin
                    state <= ST_FETCH_ADDR;
                end
            endcase
        end
    end

endmodule
// ============================================================
// CNN
//
// Two-layer 3x3 convolution engine driven by one FSM. It talks to a
// word-addressed memory through a single port (addr / dinb / doutb /
// web / enb):
//   1. Polls bit 0 of word ADDR_CNN_START until it reads 1.
//   2. Loads the feature-map size, two biases and two sets of nine
//      8-bit weights.
//   3. Layer 0: reads bytes from BASE_IN, writes results to BASE_MID.
//      Layer 1: reads bytes from BASE_MID, writes results to BASE_OUT.
//      Each layer produces an output map 2 smaller per side than its
//      input (mid_size = size - 2, out_size = size - 4).
//   4. Writes {BASE_OUT, 1'b1} to ADDR_STATUS and raises done.
//
// Data layout: four 8-bit values per 32-bit word, byte select 0 is
// bits [31:24] and byte select 3 is bits [7:0] (see get_byte/put_byte).
//
// Ports
//   clk   : clock, state updates on the rising edge
//   rstn  : asynchronous reset, active low
//   doutb : read data from the memory
//   web   : write enable to the memory, high for one cycle per write
//   enb   : memory enable, always driven high
//   dinb  : write data to the memory
//   addr  : 10-bit word address
//   done  : set once the FSM reaches ST_DONE; cleared in ST_IDLE/reset
//
// NOTE(review): every read in this FSM is "set addr, one idle state,
// then sample doutb", which presumably matches a memory with one-cycle
// synchronous read latency - confirm against the BRAM configuration.
// ============================================================
module CNN (
input wire clk,
input wire rstn,
input wire [31:0] doutb,
output reg web,
output reg enb,
output reg [31:0] dinb,
output reg [9:0] addr,
output reg done
);
// ============================================================
// Fixed memory map (word addresses)
//   BASE_W0 / BASE_W1 : first of three words holding nine weights each
//   ADDR_SIZE         : bits [6:1] are captured as the input map size
//   ADDR_STATUS       : written with {BASE_OUT, 1'b1} when finished
//   ADDR_BIAS0/1      : bits [7:0] are the per-layer bias
//   BASE_IN/MID/OUT   : packed input, intermediate and output maps
//   ADDR_CNN_START    : bit 0 is polled as the start flag
// ============================================================
localparam BASE_W0 = 10'd0;
localparam BASE_W1 = 10'd3;
localparam ADDR_SIZE = 10'd12;
localparam ADDR_STATUS = 10'd13;
localparam ADDR_BIAS0 = 10'd14;
localparam ADDR_BIAS1 = 10'd15;
localparam BASE_IN = 10'd16;
localparam BASE_MID = 10'd272;
localparam BASE_OUT = 10'd512;
localparam ADDR_CNN_START = 10'd900;
// ============================================================
// Helper functions
// ============================================================
// get_byte: return one byte of a 32-bit word.
//   word : source word
//   sel  : 0 selects bits [31:24], 3 selects bits [7:0]
function [7:0] get_byte;
input [31:0] word;
input [1:0] sel;
begin
case (sel)
2'd0: get_byte = word[31:24];
2'd1: get_byte = word[23:16];
2'd2: get_byte = word[15:8];
2'd3: get_byte = word[7:0];
default: get_byte = 8'd0;
endcase
end
endfunction
// put_byte: return old_word with one byte lane replaced by new_byte.
//   old_word : word to modify
//   sel      : lane to replace, same numbering as get_byte
//   new_byte : value written into the selected lane
function [31:0] put_byte;
input [31:0] old_word;
input [1:0] sel;
input [7:0] new_byte;
begin
put_byte = old_word;
case (sel)
2'd0: put_byte[31:24] = new_byte;
2'd1: put_byte[23:16] = new_byte;
2'd2: put_byte[15:8] = new_byte;
2'd3: put_byte[7:0] = new_byte;
default: put_byte = old_word;
endcase
end
endfunction
// ============================================================
// Q12 -> Q6 round-to-nearest ties-to-even + saturation
// Timing optimized version:
// Use signed floor rounding directly.
// This removes abs / negate / negate-back logic from the critical path.
//
//   in_q12  : signed 20-bit value with 12 fractional bits
//   returns : signed 8-bit value with 6 fractional bits, clamped to
//             [-128, 127]
// ============================================================
function signed [7:0] round_sat_q12_to_q6;
input signed [19:0] in_q12;
reg signed [19:0] floor_q6;
reg [5:0] frac;
reg round_up;
reg signed [19:0] rounded_q6;
begin
// Arithmetic shift right by 6.
// This is floor division for signed two's-complement values.
floor_q6 = in_q12 >>> 6;
// The six discarded bits, read as an unsigned fraction in [0, 63]/64
// above the floor value.
frac = in_q12[5:0];
// round-to-nearest ties-to-even
//
// For both positive and negative numbers:
// frac > 32 -> round toward +1 from floor
// frac = 32 -> choose even result
//
// Example:
// +1.5 -> floor=1, frac=32, floor odd -> +2
// +2.5 -> floor=2, frac=32, floor even -> +2
// -1.5 -> floor=-2, frac=32, floor even -> -2
// -0.5 -> floor=-1, frac=32, floor odd -> 0
if (frac > 6'd32)
round_up = 1'b1;
else if (frac == 6'd32)
round_up = floor_q6[0];
else
round_up = 1'b0;
rounded_q6 = floor_q6 + {{19{1'b0}}, round_up};
// Saturate to signed 8-bit range
if (rounded_q6 > 20'sd127)
round_sat_q12_to_q6 = 8'sh7F;
else if (rounded_q6 < -20'sd128)
round_sat_q12_to_q6 = 8'sh80;
else
round_sat_q12_to_q6 = rounded_q6[7:0];
end
endfunction
// ============================================================
// FSM states
// after_wait_state holds the state that ST_WAIT1 jumps to, so ST_WAIT1
// can serve as a shared one-cycle read-latency stage.
// Encodings 2 and 3 are unused and fall into the default branch.
// ============================================================
reg [5:0] state;
reg [5:0] after_wait_state;
localparam ST_IDLE = 6'd0;
localparam ST_WAIT1 = 6'd1;
localparam ST_CAP_SIZE = 6'd4;
localparam ST_CAP_BIAS0 = 6'd5;
localparam ST_CAP_BIAS1 = 6'd6;
localparam ST_CAP_W0_0 = 6'd7;
localparam ST_CAP_W0_1 = 6'd8;
localparam ST_CAP_W0_2 = 6'd9;
localparam ST_CAP_W1_0 = 6'd10;
localparam ST_CAP_W1_1 = 6'd11;
localparam ST_CAP_W1_2 = 6'd12;
localparam ST_BEGIN_LAYER = 6'd13;
localparam ST_PRELOAD_INIT = 6'd14;
localparam ST_PRELOAD_REQ = 6'd15;
localparam ST_PRELOAD_WAIT = 6'd16;
localparam ST_PRELOAD_CAP = 6'd17;
localparam ST_MAC_ROW0 = 6'd18;
localparam ST_MAC_ROW1 = 6'd19;
localparam ST_MAC_ROW2 = 6'd20;
localparam ST_NEXT_CAP1 = 6'd21;
localparam ST_NEXT_CAP2 = 6'd22; // kept but bypassed
localparam ST_ADD_BIAS = 6'd23;
localparam ST_ROUND = 6'd24;
localparam ST_PREP_STORE = 6'd25;
localparam ST_STORE_PACK = 6'd26;
localparam ST_WRITE_DO = 6'd27;
localparam ST_ADV_PIXEL = 6'd28;
localparam ST_WRITE_STATUS_SETUP = 6'd29;
localparam ST_WRITE_STATUS_DO = 6'd30;
localparam ST_DONE = 6'd31;
localparam ST_CHECK_START = 6'd32;
// ============================================================
// Basic registers
// fmap_size / mid_size / out_size : side length of the input map, the
//   layer-0 output and the layer-1 output.
// run_* : the size and base address pair selected for the layer that
//   is currently running (loaded in ST_BEGIN_LAYER).
// w0 / w1 : nine weights per layer, indexed row-major (row*3 + col) by
//   the MAC weight muxes below.
// ox / oy : coordinates of the output pixel being computed.
// ============================================================
reg [5:0] fmap_size;
reg [5:0] mid_size;
reg [5:0] out_size;
reg [5:0] run_in_size;
reg [5:0] run_out_size;
reg [9:0] run_in_base;
reg [9:0] run_out_base;
reg signed [7:0] bias0;
reg signed [7:0] bias1;
reg signed [7:0] w0 [0:8];
reg signed [7:0] w1 [0:8];
reg layer_id;
reg [5:0] ox;
reg [5:0] oy;
// ============================================================
// Optimized arithmetic widths
// 8b x 8b product = 16b signed
// 9 products + bias safely fit in 20b signed
// ============================================================
reg signed [19:0] acc_q12;
reg signed [7:0] conv_out_q6;
// Index of the current output pixel within the layer's output map;
// bits [1:0] pick the byte lane, bits [11:2] the word offset.
reg [11:0] out_linear_idx;
// Output word being assembled before it is written to memory.
reg [31:0] pack_word;
// ============================================================
// 3x3 sliding window registers
// winRC holds the input pixel at row R, column C of the current window.
// next_col0..2 hold the prefetched column that enters on the right
// when the window slides by one pixel.
// ============================================================
reg signed [7:0] win00, win01, win02;
reg signed [7:0] win10, win11, win12;
reg signed [7:0] win20, win21, win22;
reg signed [7:0] next_col0;
reg signed [7:0] next_col1;
reg signed [7:0] next_col2;
// Preload bookkeeping: preload_cnt counts the nine window reads (0..8);
// the *_q registers remember which window slot and byte lane the
// outstanding read belongs to.
reg [3:0] preload_cnt;
reg [1:0] preload_row_q;
reg [1:0] preload_col_q;
reg [1:0] preload_byte_sel_q;
// Timing fix for preload address
reg [11:0] preload_base_idx_q;
reg [11:0] preload_current_idx_q;
reg [9:0] preload_addr_q;
reg [1:0] preload_bsel_q;
// Timing fix for next column address
reg [11:0] pixel_base_idx_q;
reg [1:0] next_byte_sel0;
reg [1:0] next_byte_sel1;
reg [1:0] next_byte_sel2;
// Selects which window/weight row feeds the three multipliers.
reg [1:0] mac_row_idx;
// Registered multiplier outputs
reg signed [15:0] mul0_q;
reg signed [15:0] mul1_q;
reg signed [15:0] mul2_q;
integer i;
// ============================================================
// Preload row/col decode only for choosing winXX destination
// cnt 0..8 maps row-major onto the 3x3 window: row = cnt / 3,
// col = cnt % 3.
// ============================================================
function [1:0] preload_row_func;
input [3:0] cnt;
begin
case (cnt)
4'd0, 4'd1, 4'd2: preload_row_func = 2'd0;
4'd3, 4'd4, 4'd5: preload_row_func = 2'd1;
default: preload_row_func = 2'd2;
endcase
end
endfunction
function [1:0] preload_col_func;
input [3:0] cnt;
begin
case (cnt)
4'd0, 4'd3, 4'd6: preload_col_func = 2'd0;
4'd1, 4'd4, 4'd7: preload_col_func = 2'd1;
default: preload_col_func = 2'd2;
endcase
end
endfunction
wire [1:0] preload_row_now = preload_row_func(preload_cnt);
wire [1:0] preload_col_now = preload_col_func(preload_cnt);
// Base index is calculated only in ST_PRELOAD_INIT, not directly to addr.
// It is the linear pixel index of the window's top-left corner:
// oy * run_in_size + ox.
wire [11:0] preload_base_idx_now =
({6'd0, oy} * {6'd0, run_in_size}) + {6'd0, ox};
wire [9:0] preload_base_addr_now =
run_in_base + preload_base_idx_now[11:2];
wire [1:0] preload_base_bsel_now =
preload_base_idx_now[1:0];
// Index of the next pixel to preload: step one pixel to the right,
// except after the third pixel of a row (cnt 2 or 5), where the index
// jumps to the start of the next window row.
wire [11:0] preload_idx_plus1 =
preload_current_idx_q + 12'd1;
wire [11:0] preload_idx_row1 =
preload_base_idx_q + {6'd0, run_in_size};
wire [11:0] preload_idx_row2 =
preload_base_idx_q + ({6'd0, run_in_size} << 1);
wire [11:0] preload_next_idx_calc =
(preload_cnt == 4'd2) ? preload_idx_row1 :
(preload_cnt == 4'd5) ? preload_idx_row2 :
preload_idx_plus1;
wire [9:0] preload_next_addr_calc =
run_in_base + preload_next_idx_calc[11:2];
wire [1:0] preload_next_bsel_calc =
preload_next_idx_calc[1:0];
// ============================================================
// Next column address calculation
// Use pixel_base_idx_q to avoid oy * run_in_size directly to addr.
// next_linear0..2 address the pixels three columns right of the window
// origin in window rows 0, 1 and 2, i.e. the column that enters the
// window after a one-pixel slide.
// ============================================================
// True while another output pixel remains to the right in this row.
wire has_next_x =
(ox < (run_out_size - 6'd1));
wire [11:0] next_linear0 =
pixel_base_idx_q + 12'd3;
wire [11:0] next_linear1 =
pixel_base_idx_q + {6'd0, run_in_size} + 12'd3;
wire [11:0] next_linear2 =
pixel_base_idx_q + ({6'd0, run_in_size} << 1) + 12'd3;
wire [9:0] next_addr0 =
run_in_base + next_linear0[11:2];
wire [9:0] next_addr1 =
run_in_base + next_linear1[11:2];
wire [9:0] next_addr2 =
run_in_base + next_linear2[11:2];
wire [1:0] next_bsel0 = next_linear0[1:0];
wire [1:0] next_bsel1 = next_linear1[1:0];
wire [1:0] next_bsel2 = next_linear2[1:0];
// ============================================================
// 3 DSP row MAC
// mac_row_idx selects one window row and the matching weight row; the
// three products of that row are formed in parallel.
// ============================================================
wire signed [7:0] mac_pix0 =
(mac_row_idx == 2'd0) ? win00 :
(mac_row_idx == 2'd1) ? win10 :
win20;
wire signed [7:0] mac_pix1 =
(mac_row_idx == 2'd0) ? win01 :
(mac_row_idx == 2'd1) ? win11 :
win21;
wire signed [7:0] mac_pix2 =
(mac_row_idx == 2'd0) ? win02 :
(mac_row_idx == 2'd1) ? win12 :
win22;
// Weight muxes: layer_id picks w0 or w1, mac_row_idx picks the row.
wire signed [7:0] mac_w0 =
(layer_id == 1'b0) ?
((mac_row_idx == 2'd0) ? w0[0] :
(mac_row_idx == 2'd1) ? w0[3] : w0[6]) :
((mac_row_idx == 2'd0) ? w1[0] :
(mac_row_idx == 2'd1) ? w1[3] : w1[6]);
wire signed [7:0] mac_w1 =
(layer_id == 1'b0) ?
((mac_row_idx == 2'd0) ? w0[1] :
(mac_row_idx == 2'd1) ? w0[4] : w0[7]) :
((mac_row_idx == 2'd0) ? w1[1] :
(mac_row_idx == 2'd1) ? w1[4] : w1[7]);
wire signed [7:0] mac_w2 =
(layer_id == 1'b0) ?
((mac_row_idx == 2'd0) ? w0[2] :
(mac_row_idx == 2'd1) ? w0[5] : w0[8]) :
((mac_row_idx == 2'd0) ? w1[2] :
(mac_row_idx == 2'd1) ? w1[5] : w1[8]);
// 8-bit signed x 8-bit signed = 16-bit signed
(* use_dsp = "yes" *) wire signed [15:0] mul0 =
mac_pix0 * mac_w0;
(* use_dsp = "yes" *) wire signed [15:0] mul1 =
mac_pix1 * mac_w1;
(* use_dsp = "yes" *) wire signed [15:0] mul2 =
mac_pix2 * mac_w2;
// Sign-extend the registered products to the accumulator width and sum
// the three products of one row.
wire signed [19:0] mul0_ext = {{4{mul0_q[15]}}, mul0_q};
wire signed [19:0] mul1_ext = {{4{mul1_q[15]}}, mul1_q};
wire signed [19:0] mul2_ext = {{4{mul2_q[15]}}, mul2_q};
wire signed [19:0] row_sum_q12 =
mul0_ext + mul1_ext + mul2_ext;
// Biases are shifted left by 6 bits (and sign-extended) so they line
// up with the accumulator's scale before being added.
wire signed [19:0] bias0_q12 =
{{6{bias0[7]}}, bias0, 6'b000000};
wire signed [19:0] bias1_q12 =
{{6{bias1[7]}}, bias1, 6'b000000};
// ============================================================
// Output packing
// Results are collected four per word in pack_word. The word is
// written when its last lane (byte select 3) is filled, or on the last
// pixel of the layer so a partially filled word is still flushed.
// ============================================================
wire [1:0] out_byte_sel =
out_linear_idx[1:0];
wire [9:0] out_word_addr =
run_out_base + out_linear_idx[11:2];
wire is_last_pixel =
(ox == (run_out_size - 6'd1)) &&
(oy == (run_out_size - 6'd1));
wire should_write_word =
(out_byte_sel == 2'd3) || is_last_pixel;
wire [31:0] pack_word_next =
put_byte(pack_word, out_byte_sel, conv_out_q6);
// ============================================================
// Main FSM
// ============================================================
always @(posedge clk or negedge rstn) begin
if (!rstn) begin
state <= ST_IDLE;
after_wait_state <= ST_IDLE;
web <= 1'b0;
enb <= 1'b1;
dinb <= 32'd0;
addr <= 10'd0;
done <= 1'b0;
fmap_size <= 6'd0;
mid_size <= 6'd0;
out_size <= 6'd0;
run_in_size <= 6'd0;
run_out_size <= 6'd0;
run_in_base <= 10'd0;
run_out_base <= 10'd0;
bias0 <= 8'sd0;
bias1 <= 8'sd0;
for (i = 0; i < 9; i = i + 1) begin
w0[i] <= 8'sd0;
w1[i] <= 8'sd0;
end
layer_id <= 1'b0;
ox <= 6'd0;
oy <= 6'd0;
acc_q12 <= 20'sd0;
conv_out_q6 <= 8'sd0;
out_linear_idx <= 12'd0;
pack_word <= 32'd0;
win00 <= 8'sd0; win01 <= 8'sd0; win02 <= 8'sd0;
win10 <= 8'sd0; win11 <= 8'sd0; win12 <= 8'sd0;
win20 <= 8'sd0; win21 <= 8'sd0; win22 <= 8'sd0;
next_col0 <= 8'sd0;
next_col1 <= 8'sd0;
next_col2 <= 8'sd0;
preload_cnt <= 4'd0;
preload_row_q <= 2'd0;
preload_col_q <= 2'd0;
preload_byte_sel_q <= 2'd0;
preload_base_idx_q <= 12'd0;
preload_current_idx_q <= 12'd0;
preload_addr_q <= 10'd0;
preload_bsel_q <= 2'd0;
pixel_base_idx_q <= 12'd0;
next_byte_sel0 <= 2'd0;
next_byte_sel1 <= 2'd0;
next_byte_sel2 <= 2'd0;
mac_row_idx <= 2'd0;
mul0_q <= 16'sd0;
mul1_q <= 16'sd0;
mul2_q <= 16'sd0;
end
else begin
// Defaults for every cycle. A later non-blocking assignment to web in
// the case statement below overrides this default for that cycle.
web <= 1'b0;
enb <= 1'b1;
case (state)
// ----------------------------------------------------
// Start handshake: request the start word, then keep
// re-reading it until bit 0 is set.
// ----------------------------------------------------
ST_IDLE: begin
done <= 1'b0;
addr <= ADDR_CNN_START;
after_wait_state <= ST_CHECK_START;
state <= ST_WAIT1;
end
ST_CHECK_START: begin
if (doutb[0] == 1'b1) begin
addr <= ADDR_SIZE;
after_wait_state <= ST_CAP_SIZE;
state <= ST_WAIT1;
end
else begin
addr <= ADDR_CNN_START;
after_wait_state <= ST_CHECK_START;
state <= ST_WAIT1;
end
end
// Shared one-cycle wait between issuing addr and sampling doutb.
ST_WAIT1: begin
state <= after_wait_state;
end
// ----------------------------------------------------
// Parameter load: each ST_CAP_* state samples the word
// requested by the previous state and requests the next.
// ----------------------------------------------------
// Size comes from bits [6:1] of the size word; the two layer output
// sizes are derived here (each 3x3 layer removes 2 per side length).
ST_CAP_SIZE: begin
fmap_size <= doutb[6:1];
mid_size <= doutb[6:1] - 6'd2;
out_size <= doutb[6:1] - 6'd4;
addr <= ADDR_BIAS0;
after_wait_state <= ST_CAP_BIAS0;
state <= ST_WAIT1;
end
ST_CAP_BIAS0: begin
bias0 <= doutb[7:0];
addr <= ADDR_BIAS1;
after_wait_state <= ST_CAP_BIAS1;
state <= ST_WAIT1;
end
ST_CAP_BIAS1: begin
bias1 <= doutb[7:0];
addr <= BASE_W0;
after_wait_state <= ST_CAP_W0_0;
state <= ST_WAIT1;
end
// Layer-0 weights: 4 + 4 + 1 bytes from three consecutive words.
ST_CAP_W0_0: begin
w0[0] <= get_byte(doutb, 2'd0);
w0[1] <= get_byte(doutb, 2'd1);
w0[2] <= get_byte(doutb, 2'd2);
w0[3] <= get_byte(doutb, 2'd3);
addr <= BASE_W0 + 10'd1;
after_wait_state <= ST_CAP_W0_1;
state <= ST_WAIT1;
end
ST_CAP_W0_1: begin
w0[4] <= get_byte(doutb, 2'd0);
w0[5] <= get_byte(doutb, 2'd1);
w0[6] <= get_byte(doutb, 2'd2);
w0[7] <= get_byte(doutb, 2'd3);
addr <= BASE_W0 + 10'd2;
after_wait_state <= ST_CAP_W0_2;
state <= ST_WAIT1;
end
ST_CAP_W0_2: begin
w0[8] <= get_byte(doutb, 2'd0);
addr <= BASE_W1;
after_wait_state <= ST_CAP_W1_0;
state <= ST_WAIT1;
end
// Layer-1 weights, same layout starting at BASE_W1.
ST_CAP_W1_0: begin
w1[0] <= get_byte(doutb, 2'd0);
w1[1] <= get_byte(doutb, 2'd1);
w1[2] <= get_byte(doutb, 2'd2);
w1[3] <= get_byte(doutb, 2'd3);
addr <= BASE_W1 + 10'd1;
after_wait_state <= ST_CAP_W1_1;
state <= ST_WAIT1;
end
ST_CAP_W1_1: begin
w1[4] <= get_byte(doutb, 2'd0);
w1[5] <= get_byte(doutb, 2'd1);
w1[6] <= get_byte(doutb, 2'd2);
w1[7] <= get_byte(doutb, 2'd3);
addr <= BASE_W1 + 10'd2;
after_wait_state <= ST_CAP_W1_2;
state <= ST_WAIT1;
end
ST_CAP_W1_2: begin
w1[8] <= get_byte(doutb, 2'd0);
layer_id <= 1'b0;
state <= ST_BEGIN_LAYER;
end
// ----------------------------------------------------
// Layer setup: reset the pixel counters and packing state,
// then select the size/base set for the current layer.
// ----------------------------------------------------
ST_BEGIN_LAYER: begin
ox <= 6'd0;
oy <= 6'd0;
pack_word <= 32'd0;
out_linear_idx <= 12'd0;
if (layer_id == 1'b0) begin
run_in_size <= fmap_size;
run_out_size <= mid_size;
run_in_base <= BASE_IN;
run_out_base <= BASE_MID;
end
else begin
run_in_size <= mid_size;
run_out_size <= out_size;
run_in_base <= BASE_MID;
run_out_base <= BASE_OUT;
end
state <= ST_PRELOAD_INIT;
end
// ----------------------------------------------------
// Full window preload (used at the start of every output
// row): nine REQ -> WAIT -> CAP read sequences, one per
// window pixel.
// ----------------------------------------------------
ST_PRELOAD_INIT: begin
preload_cnt <= 4'd0;
preload_base_idx_q <= preload_base_idx_now;
preload_current_idx_q <= preload_base_idx_now;
pixel_base_idx_q <= preload_base_idx_now;
preload_addr_q <= preload_base_addr_now;
preload_bsel_q <= preload_base_bsel_now;
state <= ST_PRELOAD_REQ;
end
ST_PRELOAD_REQ: begin
addr <= preload_addr_q;
preload_row_q <= preload_row_now;
preload_col_q <= preload_col_now;
preload_byte_sel_q <= preload_bsel_q;
state <= ST_PRELOAD_WAIT;
end
ST_PRELOAD_WAIT: begin
state <= ST_PRELOAD_CAP;
end
ST_PRELOAD_CAP: begin
// Store the returned byte in the window slot recorded at request time.
case ({preload_row_q, preload_col_q})
4'b00_00: win00 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b00_01: win01 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b00_10: win02 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b01_00: win10 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b01_01: win11 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b01_10: win12 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b10_00: win20 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b10_01: win21 <= $signed(get_byte(doutb, preload_byte_sel_q));
4'b10_10: win22 <= $signed(get_byte(doutb, preload_byte_sel_q));
default: begin end
endcase
if (preload_cnt == 4'd8) begin
acc_q12 <= 20'sd0;
mac_row_idx <= 2'd0;
state <= ST_MAC_ROW0;
end
else begin
preload_cnt <= preload_cnt + 4'd1;
preload_current_idx_q <= preload_next_idx_calc;
preload_addr_q <= preload_next_addr_calc;
preload_bsel_q <= preload_next_bsel_calc;
state <= ST_PRELOAD_REQ;
end
end
// ----------------------------------------------------
// MAC pipeline. Products are registered one state before
// they are accumulated:
//   ST_MAC_ROW0  : register row-0 products
//   ST_MAC_ROW1  : acc  = row-0 sum, register row-1 products
//   ST_MAC_ROW2  : acc += row-1 sum, register row-2 products
//   ST_NEXT_CAP1 : acc += row-2 sum
// In parallel, when another pixel follows in this row, the
// three bytes of the next window column are requested in
// ROW0/ROW1/ROW2 and each is captured two states after its
// request.
// ----------------------------------------------------
ST_MAC_ROW0: begin
mul0_q <= mul0;
mul1_q <= mul1;
mul2_q <= mul2;
if (has_next_x) begin
addr <= next_addr0;
next_byte_sel0 <= next_bsel0;
end
mac_row_idx <= 2'd1;
state <= ST_MAC_ROW1;
end
ST_MAC_ROW1: begin
acc_q12 <= row_sum_q12;
mul0_q <= mul0;
mul1_q <= mul1;
mul2_q <= mul2;
if (has_next_x) begin
addr <= next_addr1;
next_byte_sel1 <= next_bsel1;
end
mac_row_idx <= 2'd2;
state <= ST_MAC_ROW2;
end
ST_MAC_ROW2: begin
acc_q12 <= acc_q12 + row_sum_q12;
mul0_q <= mul0;
mul1_q <= mul1;
mul2_q <= mul2;
if (has_next_x) begin
next_col0 <= $signed(get_byte(doutb, next_byte_sel0));
addr <= next_addr2;
next_byte_sel2 <= next_bsel2;
end
state <= ST_NEXT_CAP1;
end
// ====================================================
// Removed real ST_NEXT_CAP2 usage.
// ST_MAC_ROW2 issued next_addr2.
// ST_NEXT_CAP1 is the wait/capture stage for next_col1.
// ST_ADD_BIAS captures next_col2 while adding bias.
// This keeps 2-cycle BRAM latency for next_col2.
// ====================================================
ST_NEXT_CAP1: begin
acc_q12 <= acc_q12 + row_sum_q12;
if (has_next_x) begin
next_col1 <= $signed(get_byte(doutb, next_byte_sel1));
end
state <= ST_ADD_BIAS;
end
// Kept only as safety fallback; normally not entered.
ST_NEXT_CAP2: begin
state <= ST_ADD_BIAS;
end
ST_ADD_BIAS: begin
if (has_next_x) begin
next_col2 <= $signed(get_byte(doutb, next_byte_sel2));
end
if (layer_id == 1'b0)
acc_q12 <= acc_q12 + bias0_q12;
else
acc_q12 <= acc_q12 + bias1_q12;
state <= ST_ROUND;
end
// Convert the accumulated sum to the 8-bit output format.
ST_ROUND: begin
conv_out_q6 <= round_sat_q12_to_q6(acc_q12);
state <= ST_STORE_PACK;
end
// No state in this module transitions to ST_PREP_STORE.
ST_PREP_STORE: begin
state <= ST_STORE_PACK;
end
// ----------------------------------------------------
// Store: merge the result into pack_word; when the word is
// complete (or this is the layer's last pixel) set up addr
// and dinb here and pulse web in ST_WRITE_DO.
// ----------------------------------------------------
ST_STORE_PACK: begin
pack_word <= pack_word_next;
if (should_write_word) begin
addr <= out_word_addr;
dinb <= pack_word_next;
web <= 1'b0;
state <= ST_WRITE_DO;
end
else begin
state <= ST_ADV_PIXEL;
end
end
ST_WRITE_DO: begin
web <= 1'b1;
pack_word <= 32'd0;
state <= ST_ADV_PIXEL;
end
// ----------------------------------------------------
// Advance to the next output pixel:
//   last pixel of layer 0 -> start layer 1
//   last pixel of layer 1 -> write status
//   end of an output row  -> next row, full window preload
//   otherwise             -> slide window right by one
// ----------------------------------------------------
ST_ADV_PIXEL: begin
if (!is_last_pixel) begin
out_linear_idx <= out_linear_idx + 12'd1;
end
if (is_last_pixel) begin
if (layer_id == 1'b0) begin
layer_id <= 1'b1;
state <= ST_BEGIN_LAYER;
end
else begin
state <= ST_WRITE_STATUS_SETUP;
end
end
else begin
if (ox == (run_out_size - 6'd1)) begin
ox <= 6'd0;
oy <= oy + 6'd1;
state <= ST_PRELOAD_INIT;
end
else begin
// Slide 3x3 window left and insert prefetched right column
win00 <= win01;
win01 <= win02;
win02 <= next_col0;
win10 <= win11;
win11 <= win12;
win12 <= next_col1;
win20 <= win21;
win21 <= win22;
win22 <= next_col2;
ox <= ox + 6'd1;
pixel_base_idx_q <= pixel_base_idx_q + 12'd1;
acc_q12 <= 20'sd0;
mac_row_idx <= 2'd0;
state <= ST_MAC_ROW0;
end
end
end
// ----------------------------------------------------
// Completion: write {21'd0, BASE_OUT, 1'b1} to ADDR_STATUS,
// then stay in ST_DONE with done high until reset.
// ----------------------------------------------------
ST_WRITE_STATUS_SETUP: begin
addr <= ADDR_STATUS;
dinb <= {21'd0, BASE_OUT, 1'b1};
web <= 1'b0;
state <= ST_WRITE_STATUS_DO;
end
ST_WRITE_STATUS_DO: begin
web <= 1'b1;
state <= ST_DONE;
end
ST_DONE: begin
web <= 1'b0;
done <= 1'b1;
state <= ST_DONE;
end
default: begin
state <= ST_IDLE;
end
endcase
end
end
endmodule
// NOTE: paste-site footer residue ("Editor is loading...", "Leave a Comment") - not part of the Verilog source.