// Paste-site metadata (not part of the design):
// Untitled
// unknown
// plain_text
// 20 days ago
// 40 kB
// 3
// Indexable
// ============================================================
// Simple_CPU
// Multi-cycle, FSM-driven CPU. The decoder in this module
// recognises add, sub, addi, lw, sw, beq and blt; any other
// instruction ends the program and starts a fixed post-program
// sequence that writes two bias words and a CNN start flag into
// data memory before halting.
//
// Ports:
//   CLK        - clock; all state updates on the rising edge
//   RSTN       - asynchronous reset, active low
//   dmem_en    - data-memory enable, tied high
//   dmem_we    - data-memory write enable (pulsed by the FSM)
//   dmem_addr  - data-memory word address (10 bits)
//   dmem_wdata - data-memory write data
//   dmem_rdata - data-memory read data
// ============================================================
module Simple_CPU (
input wire CLK,
input wire RSTN,
output wire dmem_en,
output reg dmem_we,
output reg [9:0] dmem_addr,
output reg [31:0] dmem_wdata,
input wire [31:0] dmem_rdata
);
// The data memory is permanently enabled.
assign dmem_en = 1'b1;
// ============================================================
// Instruction Memory
// ============================================================
// imem_addr is a word address registered by the FSM (pc[11:2]).
reg [9:0] imem_addr;
wire [31:0] imem_rdata;
// Instruction_Memory is defined outside this file.
// NOTE(review): the FSM inserts one wait state between presenting
// imem_addr and capturing imem_rdata, which presumes a one-cycle
// synchronous read - confirm against the memory configuration.
Instruction_Memory u_Instruction_Memory (
.clka (CLK),
.ena (1'b1),
.addra (imem_addr),
.douta (imem_rdata)
);
// ------------------------------------------------------------
// Major opcodes (ir[6:0]) handled by the decoder.
// The values match the RISC-V base opcode map.
// ------------------------------------------------------------
localparam [6:0]
    OPCODE_RTYPE  = 7'b0110011, // add, sub
    OPCODE_ITYPE  = 7'b0010011, // addi
    OPCODE_LOAD   = 7'b0000011, // lw
    OPCODE_STORE  = 7'b0100011, // sw
    OPCODE_BRANCH = 7'b1100011; // beq, blt

// ------------------------------------------------------------
// FSM state encoding (6 bits).
//
// The *_WAIT2 / *_WAIT3 codes have case arms in the main FSM but
// no transition ever targets a *_WAIT2 state, so the FSM never
// enters them after reset; every read uses a single wait state.
// ------------------------------------------------------------
localparam [5:0]
    // Instruction fetch
    ST_FETCH_ADDR     = 6'd0,
    ST_FETCH_WAIT1    = 6'd1,
    ST_FETCH_WAIT2    = 6'd2,
    ST_FETCH_WAIT3    = 6'd3,
    ST_FETCH_CAPTURE  = 6'd4,
    // Decode
    ST_DECODE         = 6'd5,
    // Register-register and register-immediate ALU operations
    ST_EXE_R          = 6'd6,
    ST_WB_R           = 6'd7,
    ST_EXE_I          = 6'd8,
    ST_WB_I           = 6'd9,
    // Load / store
    ST_ADDR           = 6'd10,
    ST_MEM_RD         = 6'd11,
    ST_MEM_WAIT1      = 6'd12,
    ST_MEM_WAIT2      = 6'd13,
    ST_MEM_WAIT3      = 6'd14,
    ST_MEM_CAPTURE    = 6'd15,
    ST_WB_LW          = 6'd16,
    ST_MEM_WR_SETUP   = 6'd17,
    ST_MEM_WR_DO      = 6'd18,
    // Branch and end of program
    ST_BRANCH         = 6'd19,
    ST_DONE           = 6'd20,
    // Bias generation: read Data Memory[16], [17], [18]
    ST_BIAS_RD16      = 6'd21,
    ST_BIAS_WAIT16_1  = 6'd22,
    ST_BIAS_WAIT16_2  = 6'd23,
    ST_BIAS_WAIT16_3  = 6'd24,
    ST_BIAS_CAP16     = 6'd25,
    ST_BIAS_RD17      = 6'd26,
    ST_BIAS_WAIT17_1  = 6'd27,
    ST_BIAS_WAIT17_2  = 6'd28,
    ST_BIAS_WAIT17_3  = 6'd29,
    ST_BIAS_CAP17     = 6'd30,
    ST_BIAS_RD18      = 6'd31,
    ST_BIAS_WAIT18_1  = 6'd32,
    ST_BIAS_WAIT18_2  = 6'd33,
    ST_BIAS_WAIT18_3  = 6'd34,
    ST_BIAS_CAP18     = 6'd35,
    // Bias generation: write Data Memory[14], [15]
    ST_BIAS0_SETUP    = 6'd36,
    ST_BIAS0_WRITE    = 6'd37,
    ST_BIAS1_SETUP    = 6'd38,
    ST_BIAS1_WRITE    = 6'd39,
    // CNN start flag, then halt
    ST_CNN_FLAG_SETUP = 6'd40,
    ST_CNN_FLAG_WRITE = 6'd41,
    ST_HALT           = 6'd42;

// Current FSM state.
reg [5:0] state;
// ------------------------------------------------------------
// Datapath state. Every register here is written only by the
// main FSM process.
// ------------------------------------------------------------
reg [31:0] pc, pc_old;      // pc: next fetch address (bytes); pc_old: pc saved at fetch
reg [31:0] ir;              // instruction captured from imem_rdata
reg [31:0] reg_a, reg_b;    // operands read from regs[rs1] / regs[rs2] in decode
reg [31:0] alu_out;         // ALU result or load/store byte address
reg [31:0] mdr;             // data captured from dmem_rdata for lw

// Operands and results of the post-program bias computation.
reg [31:0] bias_src16, bias_src17, bias_src18;
reg [31:0] bias0_word, bias1_word;

// ------------------------------------------------------------
// Register file: 32 x 32 bits. Entry 0 is forced to zero by the FSM.
// ------------------------------------------------------------
reg [31:0] regs [0:31];
integer i;                  // loop index for the reset of regs[]

// ------------------------------------------------------------
// Instruction fields, split from ir in one concatenation:
//   [31:25] funct7 | [24:20] rs2 | [19:15] rs1 |
//   [14:12] funct3 | [11:7]  rd  | [6:0]   opcode
// ------------------------------------------------------------
wire [6:0] opcode, funct7;
wire [4:0] rd, rs1, rs2;
wire [2:0] funct3;
assign {funct7, rs2, rs1, funct3, rd, opcode} = ir;

// ------------------------------------------------------------
// Sign-extended immediates for the I, S and B instruction formats.
// imm_b is always even (bit 0 is a constant zero).
// ------------------------------------------------------------
wire signed [31:0] imm_i = {{20{ir[31]}}, ir[31:20]};
wire signed [31:0] imm_s = {{20{ir[31]}}, ir[31:25], ir[11:7]};
wire signed [31:0] imm_b = {{20{ir[31]}}, ir[7], ir[30:25], ir[11:8], 1'b0};

// ------------------------------------------------------------
// One-hot style decode flags for the supported instructions.
// ------------------------------------------------------------
wire op_is_rtype  = (opcode == OPCODE_RTYPE);
wire op_is_branch = (opcode == OPCODE_BRANCH);
wire f3_is_000    = (funct3 == 3'b000);
wire f3_is_010    = (funct3 == 3'b010);

wire is_add  = op_is_rtype && f3_is_000 && (funct7 == 7'b0000000);
wire is_sub  = op_is_rtype && f3_is_000 && (funct7 == 7'b0100000);
wire is_addi = (opcode == OPCODE_ITYPE) && f3_is_000;
wire is_lw   = (opcode == OPCODE_LOAD)  && f3_is_010;
wire is_sw   = (opcode == OPCODE_STORE) && f3_is_010;
wire is_beq  = op_is_branch && f3_is_000;
wire is_blt  = op_is_branch && (funct3 == 3'b100);
// ============================================================
// Main FSM
// ============================================================
// Single clocked process implementing the whole CPU as a multi-cycle FSM:
//   fetch -> decode -> execute / memory / branch -> write-back -> fetch.
// An instruction the decoder does not recognise sends the FSM to ST_DONE,
// which runs the bias-generation sequence, writes the CNN start flag and
// then parks in ST_HALT.
always @(posedge CLK or negedge RSTN) begin
if (!RSTN) begin
// Asynchronous reset: clear all architectural and datapath state and
// restart fetching from address 0.
state <= ST_FETCH_ADDR;
pc <= 32'd0;
pc_old <= 32'd0;
ir <= 32'd0;
reg_a <= 32'd0;
reg_b <= 32'd0;
alu_out <= 32'd0;
mdr <= 32'd0;
bias_src16 <= 32'd0;
bias_src17 <= 32'd0;
bias_src18 <= 32'd0;
bias0_word <= 32'd0;
bias1_word <= 32'd0;
imem_addr <= 10'd0;
dmem_addr <= 10'd0;
dmem_we <= 1'b0;
dmem_wdata <= 32'd0;
for (i = 0; i < 32; i = i + 1) begin
regs[i] <= 32'd0;
end
end
else begin
// Per-cycle defaults. A later nonblocking assignment in the selected
// case arm overrides these, so dmem_we is high only in the cycle after
// a *_WR_DO / *_WRITE state.
dmem_we <= 1'b0;
// Register 0 is re-zeroed every cycle; write-back states also skip rd == 0.
regs[0] <= 32'd0;
case (state)
// ====================================================
// Instruction fetch
// ====================================================
ST_FETCH_ADDR: begin
// Remember the address of the instruction being fetched; branches
// are computed relative to it.
pc_old <= pc;
// pc is a byte address; bits [11:2] select the 32-bit word.
imem_addr <= pc[11:2];
state <= ST_FETCH_WAIT1;
end
ST_FETCH_WAIT1: begin
// Single wait cycle for the instruction-memory read.
state <= ST_FETCH_CAPTURE;
end
// ST_FETCH_WAIT2 / ST_FETCH_WAIT3 are never entered: no transition
// targets ST_FETCH_WAIT2.
ST_FETCH_WAIT2: begin
state <= ST_FETCH_WAIT3;
end
ST_FETCH_WAIT3: begin
state <= ST_FETCH_CAPTURE;
end
ST_FETCH_CAPTURE: begin
ir <= imem_rdata;
// pc now points at the sequential successor; a taken branch
// overwrites it in ST_BRANCH.
pc <= pc + 32'd4;
state <= ST_DECODE;
end
// ====================================================
// Decode
// ====================================================
ST_DECODE: begin
// Read both source operands; register 0 always reads as zero.
reg_a <= (rs1 == 5'd0) ? 32'd0 : regs[rs1];
reg_b <= (rs2 == 5'd0) ? 32'd0 : regs[rs2];
// Dispatch on the major opcode. Any opcode/funct combination that is
// not one of the supported instructions ends the program (ST_DONE).
case (opcode)
OPCODE_RTYPE: begin
if (is_add || is_sub)
state <= ST_EXE_R;
else
state <= ST_DONE;
end
OPCODE_ITYPE: begin
if (is_addi)
state <= ST_EXE_I;
else
state <= ST_DONE;
end
OPCODE_LOAD: begin
if (is_lw)
state <= ST_ADDR;
else
state <= ST_DONE;
end
OPCODE_STORE: begin
if (is_sw)
state <= ST_ADDR;
else
state <= ST_DONE;
end
OPCODE_BRANCH: begin
if (is_beq || is_blt)
state <= ST_BRANCH;
else
state <= ST_DONE;
end
default: begin
state <= ST_DONE;
end
endcase
end
// ====================================================
// R-type
// ====================================================
ST_EXE_R: begin
// Only add and sub reach this state, so "not sub" means add.
if (is_sub)
alu_out <= reg_a - reg_b;
else
alu_out <= reg_a + reg_b;
state <= ST_WB_R;
end
ST_WB_R: begin
if (rd != 5'd0)
regs[rd] <= alu_out;
state <= ST_FETCH_ADDR;
end
// ====================================================
// I-type addi
// ====================================================
ST_EXE_I: begin
alu_out <= reg_a + imm_i;
state <= ST_WB_I;
end
ST_WB_I: begin
if (rd != 5'd0)
regs[rd] <= alu_out;
state <= ST_FETCH_ADDR;
end
// ====================================================
// lw / sw address calculation
// ====================================================
ST_ADDR: begin
// Only lw and sw reach this state; the store path uses the S-type
// immediate.
if (opcode == OPCODE_LOAD) begin
alu_out <= reg_a + imm_i;
state <= ST_MEM_RD;
end
else begin
alu_out <= reg_a + imm_s;
state <= ST_MEM_WR_SETUP;
end
end
ST_MEM_RD: begin
// alu_out is a byte address; bits [11:2] form the word address.
dmem_addr <= alu_out[11:2];
state <= ST_MEM_WAIT1;
end
ST_MEM_WAIT1: begin
// Single wait cycle for the data-memory read.
state <= ST_MEM_CAPTURE;
end
// ST_MEM_WAIT2 / ST_MEM_WAIT3 are never entered: no transition
// targets them.
ST_MEM_WAIT2: begin
state <= ST_MEM_CAPTURE;
end
ST_MEM_WAIT3: begin
state <= ST_MEM_CAPTURE;
end
ST_MEM_CAPTURE: begin
mdr <= dmem_rdata;
state <= ST_WB_LW;
end
ST_WB_LW: begin
if (rd != 5'd0)
regs[rd] <= mdr;
state <= ST_FETCH_ADDR;
end
ST_MEM_WR_SETUP: begin
// Register address and data first with write enable low ...
dmem_addr <= alu_out[11:2];
dmem_wdata <= reg_b;
dmem_we <= 1'b0;
state <= ST_MEM_WR_DO;
end
ST_MEM_WR_DO: begin
// ... then raise write enable for one cycle; the default assignment
// at the top of this branch clears it in the following cycle.
dmem_we <= 1'b1;
state <= ST_FETCH_ADDR;
end
// ====================================================
// Branch
// ====================================================
ST_BRANCH: begin
// Taken branches load pc with pc_old + imm_b (relative to the branch
// instruction itself); otherwise pc keeps the pc_old + 4 value set
// in ST_FETCH_CAPTURE.
if (is_beq) begin
if (reg_a == reg_b)
pc <= pc_old + imm_b;
end
else if (is_blt) begin
// blt compares the operands as signed values.
if ($signed(reg_a) < $signed(reg_b))
pc <= pc_old + imm_b;
end
state <= ST_FETCH_ADDR;
end
// ====================================================
// Program done, then generate bias before CNN
// ====================================================
ST_DONE: begin
dmem_addr <= 10'd16;
dmem_wdata <= 32'd0;
dmem_we <= 1'b0;
state <= ST_BIAS_RD16;
end
// ====================================================
// Bias generation
// Bias0 = Data Memory[16] + [17] + [18]      -> Data Memory[14]
// Bias1 = if Data Memory[17] < Data Memory[16] (signed compare)
//         then Data Memory[16] - Data Memory[17] + 1
//         else -15                            -> Data Memory[15]
// NOTE(review): this comment previously specified
//   "Data Memory[16] - Data Memory[17]" and "else 1"; the code in
//   ST_BIAS_CAP18 adds 1 and uses -15. Confirm which is intended.
// ====================================================
ST_BIAS_RD16: begin
dmem_addr <= 10'd16;
dmem_we <= 1'b0;
state <= ST_BIAS_WAIT16_1;
end
ST_BIAS_WAIT16_1: begin
state <= ST_BIAS_CAP16;
end
// ST_BIAS_WAIT16_2 / _3 are never entered.
ST_BIAS_WAIT16_2: begin
state <= ST_BIAS_CAP16;
end
ST_BIAS_WAIT16_3: begin
state <= ST_BIAS_CAP16;
end
ST_BIAS_CAP16: begin
bias_src16 <= dmem_rdata;
dmem_addr <= 10'd17;
dmem_we <= 1'b0;
state <= ST_BIAS_RD17;
end
ST_BIAS_RD17: begin
dmem_addr <= 10'd17;
dmem_we <= 1'b0;
state <= ST_BIAS_WAIT17_1;
end
ST_BIAS_WAIT17_1: begin
state <= ST_BIAS_CAP17;
end
// ST_BIAS_WAIT17_2 / _3 are never entered.
ST_BIAS_WAIT17_2: begin
state <= ST_BIAS_CAP17;
end
ST_BIAS_WAIT17_3: begin
state <= ST_BIAS_CAP17;
end
ST_BIAS_CAP17: begin
bias_src17 <= dmem_rdata;
dmem_addr <= 10'd18;
dmem_we <= 1'b0;
state <= ST_BIAS_RD18;
end
ST_BIAS_RD18: begin
dmem_addr <= 10'd18;
dmem_we <= 1'b0;
state <= ST_BIAS_WAIT18_1;
end
ST_BIAS_WAIT18_1: begin
state <= ST_BIAS_CAP18;
end
// ST_BIAS_WAIT18_2 / _3 are never entered.
ST_BIAS_WAIT18_2: begin
state <= ST_BIAS_CAP18;
end
ST_BIAS_WAIT18_3: begin
state <= ST_BIAS_CAP18;
end
ST_BIAS_CAP18: begin
bias_src18 <= dmem_rdata;
// The third operand is taken straight from dmem_rdata because
// bias_src18 is only updated at the end of this cycle.
bias0_word <= bias_src16 + bias_src17 + dmem_rdata;
if ($signed(bias_src17) < $signed(bias_src16))
bias1_word <= bias_src16 - bias_src17+ 32'd1;
else
bias1_word <= -32'sd15;
state <= ST_BIAS0_SETUP;
end
// Each write below uses a SETUP state (address/data, we low) followed
// by a WRITE state that raises dmem_we for one cycle.
ST_BIAS0_SETUP: begin
dmem_addr <= 10'd14;
dmem_wdata <= bias0_word;
dmem_we <= 1'b0;
state <= ST_BIAS0_WRITE;
end
ST_BIAS0_WRITE: begin
dmem_addr <= 10'd14;
dmem_wdata <= bias0_word;
dmem_we <= 1'b1;
state <= ST_BIAS1_SETUP;
end
ST_BIAS1_SETUP: begin
dmem_addr <= 10'd15;
dmem_wdata <= bias1_word;
dmem_we <= 1'b0;
state <= ST_BIAS1_WRITE;
end
ST_BIAS1_WRITE: begin
dmem_addr <= 10'd15;
dmem_wdata <= bias1_word;
dmem_we <= 1'b1;
state <= ST_CNN_FLAG_SETUP;
end
// ====================================================
// Start CNN after bias is ready
// Writes the value 1 to Data Memory[900].
// ====================================================
ST_CNN_FLAG_SETUP: begin
dmem_addr <= 10'd900;
dmem_wdata <= 32'd1;
dmem_we <= 1'b0;
state <= ST_CNN_FLAG_WRITE;
end
ST_CNN_FLAG_WRITE: begin
dmem_addr <= 10'd900;
dmem_wdata <= 32'd1;
dmem_we <= 1'b1;
state <= ST_HALT;
end
ST_HALT: begin
// Terminal state: stay here with write enable low until reset.
dmem_we <= 1'b0;
state <= ST_HALT;
end
default: begin
// Unused state encodings fall back to instruction fetch.
state <= ST_FETCH_ADDR;
end
endcase
end
end
endmodule
// ============================================================
// CNN
// FSM-driven accelerator that runs two 3x3 convolution layers
// over a byte-packed feature map held in a shared data memory.
// It polls a start flag, loads size / bias / weights, computes
// layer 1 (BASE_IN -> BASE_MID) and layer 2 (BASE_MID -> BASE_OUT),
// writes a status word and then asserts done.
//
// Ports:
//   clk   - clock; all state updates on the rising edge
//   rstn  - asynchronous reset, active low
//   doutb - read data from the memory port
//   web   - write enable to the memory port
//   enb   - memory port enable (driven high by this module)
//   dinb  - write data to the memory port
//   addr  - word address to the memory port (10 bits)
//   done  - set once the FSM reaches its final state
// ============================================================
module CNN (
input wire clk,
input wire rstn,
input wire [31:0] doutb,
output reg web,
output reg enb,
output reg [31:0] dinb,
output reg [9:0] addr,
output reg done
);
// ============================================================
// Fixed memory map
// All values are 32-bit word addresses.
// ============================================================
localparam BASE_W0 = 10'd0;         // layer-1 weights, words 0..2
localparam BASE_W1 = 10'd3;         // layer-2 weights, words 3..5
localparam ADDR_SIZE = 10'd12;      // feature-map size word (bits [6:1] are used)
localparam ADDR_STATUS = 10'd13;    // status word written when both layers finish
localparam ADDR_BIAS0 = 10'd14;     // layer-1 bias (low byte is used)
localparam ADDR_BIAS1 = 10'd15;     // layer-2 bias (low byte is used)
localparam BASE_IN = 10'd16;        // input feature map, 4 pixels per word
localparam BASE_MID = 10'd272;      // layer-1 output / layer-2 input
localparam BASE_OUT = 10'd512;      // layer-2 output
localparam ADDR_CNN_START = 10'd900; // start flag polled in idle (bit 0)
// ============================================================
// Helper functions
// ============================================================
// Extract one byte from a 32-bit word.
//   word : packed source word
//   sel  : byte lane, 0 = word[31:24] (most significant) ... 3 = word[7:0]
// Returns the selected byte; 0 when sel matches none of 0..3
// (only possible when sel contains x/z in simulation).
function [7:0] get_byte;
input [31:0] word;
input [1:0] sel;
begin
    get_byte = 8'd0;
    if (sel == 2'd0)
        get_byte = word[31:24];
    else if (sel == 2'd1)
        get_byte = word[23:16];
    else if (sel == 2'd2)
        get_byte = word[15:8];
    else if (sel == 2'd3)
        get_byte = word[7:0];
end
endfunction
// Return old_word with one byte lane replaced by new_byte.
//   old_word : word to modify
//   sel      : byte lane, 0 = bits [31:24] ... 3 = bits [7:0]
//   new_byte : value stored into the selected lane
// The other three lanes are passed through unchanged. If sel matches
// none of 0..3 (x/z in simulation) old_word is returned as is.
function [31:0] put_byte;
input [31:0] old_word;
input [1:0] sel;
input [7:0] new_byte;
begin
    put_byte = old_word;
    if (sel == 2'd0)
        put_byte = {new_byte, old_word[23:0]};
    else if (sel == 2'd1)
        put_byte = {old_word[31:24], new_byte, old_word[15:0]};
    else if (sel == 2'd2)
        put_byte = {old_word[31:16], new_byte, old_word[7:0]};
    else if (sel == 2'd3)
        put_byte = {old_word[31:8], new_byte};
end
endfunction
// Convert a signed 32-bit accumulator to a signed 8-bit value by
// dropping 6 fractional bits.
//   in_q12 : signed input
// Steps:
//   1. take the magnitude of the input,
//   2. divide by 64 with round-to-nearest; an exact half rounds to
//      the even quotient,
//   3. restore the sign,
//   4. saturate to the range -128 .. +127.
function signed [7:0] round_sat_q12_to_q6;
input signed [31:0] in_q12;
reg negative;
reg [31:0] magnitude;
reg [31:0] quotient;
reg [5:0] remainder;
reg bump;
reg [31:0] quotient_rounded;
reg signed [31:0] result;
begin
    // Rounding is done on the magnitude so it is symmetric about zero.
    negative = 1'b0;
    magnitude = in_q12;
    if (in_q12 < 0) begin
        negative = 1'b1;
        magnitude = -in_q12;
    end

    quotient = magnitude >> 6;
    remainder = magnitude[5:0];

    // 32/64 is exactly one half: round up only if that makes the
    // quotient even. Above one half always rounds up.
    bump = 1'b0;
    if (remainder > 6'd32)
        bump = 1'b1;
    else if (remainder == 6'd32)
        bump = quotient[0];
    quotient_rounded = quotient + {31'd0, bump};

    result = $signed(quotient_rounded);
    if (negative)
        result = -$signed(quotient_rounded);

    // Saturate to the signed 8-bit range.
    if (result > 32'sd127)
        round_sat_q12_to_q6 = 8'sh7F;
    else if (result < -32'sd128)
        round_sat_q12_to_q6 = 8'sh80;
    else
        round_sat_q12_to_q6 = result[7:0];
end
endfunction
// ============================================================
// Kernel coordinate functions
// k order:
// 0 1 2
// 3 4 5
// 6 7 8
// ============================================================
// Column offset (0, 1 or 2) of kernel tap k inside the 3x3 window,
// with taps numbered row by row from 0 to 8.
// Any k outside 0..8 yields 0.
function [5:0] k_dx_func;
input [3:0] k;
begin
    if ((k == 4'd1) || (k == 4'd4) || (k == 4'd7))
        k_dx_func = 6'd1;   // middle column
    else if ((k == 4'd2) || (k == 4'd5) || (k == 4'd8))
        k_dx_func = 6'd2;   // right column
    else
        k_dx_func = 6'd0;   // left column, or k out of range
end
endfunction
// Row offset (0, 1 or 2) of kernel tap k inside the 3x3 window,
// with taps numbered row by row from 0 to 8.
// Any k outside 0..8 yields 0.
function [5:0] k_dy_func;
input [3:0] k;
begin
    if ((k == 4'd3) || (k == 4'd4) || (k == 4'd5))
        k_dy_func = 6'd1;   // middle row
    else if ((k == 4'd6) || (k == 4'd7) || (k == 4'd8))
        k_dy_func = 6'd2;   // bottom row
    else
        k_dy_func = 6'd0;   // top row, or k out of range
end
endfunction
// ------------------------------------------------------------
// FSM state encoding (6 bits).
//
// ST_WAIT2 and ST_WAIT3 have case arms in the main FSM but are
// never selected as a next state; every configuration read goes
// through ST_WAIT1 only. Code 6'd16 is unassigned.
// ------------------------------------------------------------
localparam [5:0]
    // Start-flag polling and shared read-wait states
    ST_IDLE               = 6'd0,
    ST_WAIT1              = 6'd1,
    ST_WAIT2              = 6'd2,
    ST_WAIT3              = 6'd3,
    // Configuration capture: size, biases, two sets of 9 weights
    ST_CAP_SIZE           = 6'd4,
    ST_CAP_BIAS0          = 6'd5,
    ST_CAP_BIAS1          = 6'd6,
    ST_CAP_W0_0           = 6'd7,
    ST_CAP_W0_1           = 6'd8,
    ST_CAP_W0_2           = 6'd9,
    ST_CAP_W1_0           = 6'd10,
    ST_CAP_W1_1           = 6'd11,
    ST_CAP_W1_2           = 6'd12,
    // Per-layer and per-pixel compute
    ST_BEGIN_LAYER        = 6'd13,
    ST_START_PIXEL        = 6'd14,
    ST_PIPE_MAC           = 6'd15,
    ST_ADD_BIAS           = 6'd17,
    ST_ROUND              = 6'd18,
    // Output packing and write-back
    ST_PREP_STORE         = 6'd19,
    ST_STORE_PACK         = 6'd20,
    ST_WRITE_DO           = 6'd21,
    ST_ADV_PIXEL          = 6'd22,
    // Completion
    ST_WRITE_STATUS_SETUP = 6'd23,
    ST_WRITE_STATUS_DO    = 6'd24,
    ST_DONE               = 6'd25,
    // Start-flag evaluation (entered from the wait state)
    ST_CHECK_START        = 6'd26;

reg [5:0] state;            // current FSM state
reg [5:0] after_wait_state; // state entered when a wait state completes
// ============================================================
// Registers
// ============================================================
// Feature-map edge lengths in pixels. mid_size and out_size are
// derived from fmap_size when the size word is captured
// (fmap_size - 2 and fmap_size - 4).
reg [5:0] fmap_size;
reg [5:0] mid_size;
reg [5:0] out_size;
// Geometry and base addresses of the layer currently being computed,
// loaded in ST_BEGIN_LAYER according to layer_id.
reg [5:0] run_in_size;
reg [5:0] run_out_size;
reg [9:0] run_in_base;
reg [9:0] run_out_base;
// Per-layer bias and 3x3 weights, one signed byte each.
reg signed [7:0] bias0;
reg signed [7:0] bias1;
reg signed [7:0] w0 [0:8];
reg signed [7:0] w1 [0:8];
reg layer_id; // 0 = layer1, 1 = layer2
// Coordinates of the output pixel being computed.
reg [5:0] ox;
reg [5:0] oy;
// k_idx is kept for debug visibility.
// Actual pipeline uses issue_k / pipe_k0 / pipe_k1.
reg [3:0] k_idx;
// issue_k: next kernel tap whose pixel address will be issued (0..9;
// 9 means all taps have been issued).
reg [3:0] issue_k;
// Two-stage metadata pipeline that follows each issued read:
// stage 0 = issued last cycle, stage 1 = data available on doutb now.
reg [3:0] pipe_k0;
reg [3:0] pipe_k1;
reg [1:0] pipe_byte0;
reg [1:0] pipe_byte1;
reg pipe_v0;
reg pipe_v1;
// Alias of pipe_byte1; not referenced elsewhere in this module.
wire [1:0] read_byte_sel = pipe_byte1;
// Keep these for debug testbench
reg signed [7:0] pix_q6;
// Accumulator for the current output pixel (products plus bias << 6).
reg signed [31:0] acc_q12;
// Rounded and saturated result for the current output pixel.
reg signed [7:0] conv_out_q6;
// Linear (row-major) index of the current output pixel, latched in
// ST_PREP_STORE.
reg [11:0] out_linear_idx;
// Output word being assembled from up to four result bytes.
reg [31:0] pack_word;
// Loop index for the reset of w0[] / w1[].
integer i;
// ============================================================
// Issue address for current issue_k
// ============================================================
// Input pixel coordinates for tap issue_k: output position plus the
// tap's column / row offset.
wire [5:0] issue_x = ox + k_dx_func(issue_k);
wire [5:0] issue_y = oy + k_dy_func(issue_k);
// Row-major pixel index into the input map. Operands are widened to
// 12 bits so the product is not truncated to 6 bits.
wire [11:0] issue_linear_idx =
({6'd0, issue_y} * {6'd0, run_in_size}) + {6'd0, issue_x};
// Four pixels per word: upper bits select the word, the low two bits
// select the byte lane inside it.
wire [9:0] issue_word_addr =
run_in_base + issue_linear_idx[11:2];
wire [1:0] issue_byte_now =
issue_linear_idx[1:0];
// ============================================================
// Output packing
// ============================================================
// Row-major index of the current output pixel.
wire [11:0] out_linear_now =
({6'd0, oy} * {6'd0, run_out_size}) + {6'd0, ox};
wire [1:0] out_byte_sel =
out_linear_idx[1:0];
wire [9:0] out_word_addr =
run_out_base + out_linear_idx[11:2];
// True for the bottom-right pixel of the current output map.
wire is_last_pixel =
(ox == (run_out_size - 6'd1)) &&
(oy == (run_out_size - 6'd1));
// A word is written when its last byte lane has just been filled, or
// when the map ends part-way through a word.
wire should_write_word =
(out_byte_sel == 2'd3) || is_last_pixel;
wire [31:0] pack_word_next =
put_byte(pack_word, out_byte_sel, conv_out_q6);
// ============================================================
// MAC data aligned with pipe_k1 / pipe_byte1
// doutb corresponds to address issued two CNN cycles earlier.
// ============================================================
wire signed [7:0] mac_pixel =
$signed(get_byte(doutb, pipe_byte1));
wire signed [7:0] mac_weight =
(layer_id == 1'b0) ? w0[pipe_k1] : w1[pipe_k1];
// Both operands are explicitly sign-extended to 32 bits before the
// multiply, giving a signed 32-bit product.
wire signed [31:0] mac_mul =
$signed({{24{mac_pixel[7]}}, mac_pixel}) *
$signed({{24{mac_weight[7]}}, mac_weight});
// ============================================================
// Main FSM
// ============================================================
// Single clocked process implementing the accelerator:
//   poll start flag -> load size, biases, weights ->
//   for each of two layers, for each output pixel:
//     pipelined 9-tap MAC -> add bias -> round/saturate -> pack -> store
//   -> write status word -> done.
always @(posedge clk or negedge rstn) begin
if (!rstn) begin
// Asynchronous reset: clear all state and return to idle.
state <= ST_IDLE;
after_wait_state <= ST_IDLE;
web <= 1'b0;
enb <= 1'b1;
dinb <= 32'd0;
addr <= 10'd0;
done <= 1'b0;
fmap_size <= 6'd0;
mid_size <= 6'd0;
out_size <= 6'd0;
run_in_size <= 6'd0;
run_out_size <= 6'd0;
run_in_base <= 10'd0;
run_out_base <= 10'd0;
bias0 <= 8'sd0;
bias1 <= 8'sd0;
for (i = 0; i < 9; i = i + 1) begin
w0[i] <= 8'sd0;
w1[i] <= 8'sd0;
end
layer_id <= 1'b0;
ox <= 6'd0;
oy <= 6'd0;
k_idx <= 4'd0;
issue_k <= 4'd0;
pipe_k0 <= 4'd0;
pipe_k1 <= 4'd0;
pipe_byte0 <= 2'd0;
pipe_byte1 <= 2'd0;
pipe_v0 <= 1'b0;
pipe_v1 <= 1'b0;
pix_q6 <= 8'sd0;
acc_q12 <= 32'sd0;
conv_out_q6 <= 8'sd0;
out_linear_idx <= 12'd0;
pack_word <= 32'd0;
end
else begin
// Per-cycle defaults; a later nonblocking assignment in the selected
// case arm overrides them, so web is high for a single cycle only.
web <= 1'b0;
enb <= 1'b1;
case (state)
// ====================================================
// Wait for CPU flag Data Memory[900] = 1
// ====================================================
ST_IDLE: begin
done <= 1'b0;
addr <= ADDR_CNN_START;
after_wait_state <= ST_CHECK_START;
state <= ST_WAIT1;
end
ST_CHECK_START: begin
// Bit 0 of the flag word starts the run; otherwise re-read the flag.
if (doutb[0] == 1'b1) begin
addr <= ADDR_SIZE;
after_wait_state <= ST_CAP_SIZE;
state <= ST_WAIT1;
end
else begin
addr <= ADDR_CNN_START;
after_wait_state <= ST_CHECK_START;
state <= ST_WAIT1;
end
end
// ====================================================
// Config / weight / polling read wait.
// addr is registered in CNN, and BRAM latency is 1,
// so one WAIT state before capture is still needed.
// ====================================================
ST_WAIT1: begin
state <= after_wait_state;
end
// ST_WAIT2 / ST_WAIT3 are never selected as a next state.
ST_WAIT2: begin
state <= after_wait_state;
end
ST_WAIT3: begin
state <= after_wait_state;
end
// ====================================================
// Load size / bias / weights
// ====================================================
ST_CAP_SIZE: begin
// The size is taken from bits [6:1] of the word at ADDR_SIZE.
// Each 3x3 layer shrinks the map edge by 2 pixels.
fmap_size <= doutb[6:1];
mid_size <= doutb[6:1] - 6'd2;
out_size <= doutb[6:1] - 6'd4;
addr <= ADDR_BIAS0;
after_wait_state <= ST_CAP_BIAS0;
state <= ST_WAIT1;
end
ST_CAP_BIAS0: begin
// Only the low byte of the bias word is kept.
bias0 <= doutb[7:0];
addr <= ADDR_BIAS1;
after_wait_state <= ST_CAP_BIAS1;
state <= ST_WAIT1;
end
ST_CAP_BIAS1: begin
bias1 <= doutb[7:0];
addr <= BASE_W0;
after_wait_state <= ST_CAP_W0_0;
state <= ST_WAIT1;
end
// Weight packing:
// address0: w00 w01 w02 w10
// address1: w11 w12 w20 w21
// address2: w22 blank blank blank
// (byte lane 0 is the most-significant byte of the word)
ST_CAP_W0_0: begin
w0[0] <= get_byte(doutb, 2'd0);
w0[1] <= get_byte(doutb, 2'd1);
w0[2] <= get_byte(doutb, 2'd2);
w0[3] <= get_byte(doutb, 2'd3);
addr <= BASE_W0 + 10'd1;
after_wait_state <= ST_CAP_W0_1;
state <= ST_WAIT1;
end
ST_CAP_W0_1: begin
w0[4] <= get_byte(doutb, 2'd0);
w0[5] <= get_byte(doutb, 2'd1);
w0[6] <= get_byte(doutb, 2'd2);
w0[7] <= get_byte(doutb, 2'd3);
addr <= BASE_W0 + 10'd2;
after_wait_state <= ST_CAP_W0_2;
state <= ST_WAIT1;
end
ST_CAP_W0_2: begin
w0[8] <= get_byte(doutb, 2'd0);
addr <= BASE_W1;
after_wait_state <= ST_CAP_W1_0;
state <= ST_WAIT1;
end
ST_CAP_W1_0: begin
w1[0] <= get_byte(doutb, 2'd0);
w1[1] <= get_byte(doutb, 2'd1);
w1[2] <= get_byte(doutb, 2'd2);
w1[3] <= get_byte(doutb, 2'd3);
addr <= BASE_W1 + 10'd1;
after_wait_state <= ST_CAP_W1_1;
state <= ST_WAIT1;
end
ST_CAP_W1_1: begin
w1[4] <= get_byte(doutb, 2'd0);
w1[5] <= get_byte(doutb, 2'd1);
w1[6] <= get_byte(doutb, 2'd2);
w1[7] <= get_byte(doutb, 2'd3);
addr <= BASE_W1 + 10'd2;
after_wait_state <= ST_CAP_W1_2;
state <= ST_WAIT1;
end
ST_CAP_W1_2: begin
// Last configuration word: start with layer 1.
w1[8] <= get_byte(doutb, 2'd0);
layer_id <= 1'b0;
state <= ST_BEGIN_LAYER;
end
// ====================================================
// Begin layer
// ====================================================
ST_BEGIN_LAYER: begin
ox <= 6'd0;
oy <= 6'd0;
pack_word <= 32'd0;
out_linear_idx <= 12'd0;
// Layer 1 reads BASE_IN and writes BASE_MID;
// layer 2 reads BASE_MID and writes BASE_OUT.
if (layer_id == 1'b0) begin
run_in_size <= fmap_size;
run_out_size <= mid_size;
run_in_base <= BASE_IN;
run_out_base <= BASE_MID;
end
else begin
run_in_size <= mid_size;
run_out_size <= out_size;
run_in_base <= BASE_MID;
run_out_base <= BASE_OUT;
end
state <= ST_START_PIXEL;
end
// ====================================================
// Start one output pixel.
// Initialize pipeline.
// ====================================================
ST_START_PIXEL: begin
acc_q12 <= 32'sd0;
issue_k <= 4'd0;
k_idx <= 4'd0;
pipe_k0 <= 4'd0;
pipe_k1 <= 4'd0;
pipe_byte0 <= 2'd0;
pipe_byte1 <= 2'd0;
pipe_v0 <= 1'b0;
pipe_v1 <= 1'b0;
state <= ST_PIPE_MAC;
end
// ====================================================
// Pipelined MAC loop.
//
// Effective timing with registered addr + BRAM latency=1:
// cycle 0: issue k0
// cycle 1: issue k1
// cycle 2: MAC k0, issue k2
// cycle 3: MAC k1, issue k3
// ...
// cycle 10: MAC k8, then go add bias
// ====================================================
ST_PIPE_MAC: begin
// 1. MAC data issued two CNN cycles ago
if (pipe_v1) begin
pix_q6 <= mac_pixel;
acc_q12 <= acc_q12 + mac_mul;
k_idx <= pipe_k1;
end
// 2. Issue new address every cycle while issue_k <= 8
if (issue_k <= 4'd8) begin
addr <= issue_word_addr;
pipe_k0 <= issue_k;
pipe_byte0 <= issue_byte_now;
pipe_v0 <= 1'b1;
issue_k <= issue_k + 4'd1;
end
else begin
// All nine taps issued: feed bubbles into the pipeline.
pipe_k0 <= 4'd0;
pipe_byte0 <= 2'd0;
pipe_v0 <= 1'b0;
end
// 3. Shift metadata pipeline
pipe_k1 <= pipe_k0;
pipe_byte1 <= pipe_byte0;
pipe_v1 <= pipe_v0;
// 4. If k8 has been MACed, finish this pixel
// (the accumulate for tap 8 happens in this same cycle, in step 1)
if (pipe_v1 && pipe_k1 == 4'd8) begin
state <= ST_ADD_BIAS;
end
else begin
state <= ST_PIPE_MAC;
end
end
ST_ADD_BIAS: begin
// The 8-bit bias is sign-extended and shifted left by 6 before
// being added to the accumulator.
if (layer_id == 1'b0)
acc_q12 <= acc_q12 + ($signed({{24{bias0[7]}}, bias0}) <<< 6);
else
acc_q12 <= acc_q12 + ($signed({{24{bias1[7]}}, bias1}) <<< 6);
state <= ST_ROUND;
end
ST_ROUND: begin
conv_out_q6 <= round_sat_q12_to_q6(acc_q12);
state <= ST_PREP_STORE;
end
ST_PREP_STORE: begin
// Latch the output index; byte lane and word address derive from it.
out_linear_idx <= out_linear_now;
state <= ST_STORE_PACK;
end
ST_STORE_PACK: begin
// Merge the new byte into the word being assembled. When the word is
// complete (or the map ends) set up address/data with web still low.
pack_word <= pack_word_next;
if (should_write_word) begin
addr <= out_word_addr;
dinb <= pack_word_next;
web <= 1'b0;
state <= ST_WRITE_DO;
end
else begin
state <= ST_ADV_PIXEL;
end
end
ST_WRITE_DO: begin
// Raise web for one cycle and start the next word from zero.
web <= 1'b1;
pack_word <= 32'd0;
state <= ST_ADV_PIXEL;
end
ST_ADV_PIXEL: begin
if (is_last_pixel) begin
// End of the map: move on to layer 2, or finish after layer 2.
if (layer_id == 1'b0) begin
layer_id <= 1'b1;
state <= ST_BEGIN_LAYER;
end
else begin
state <= ST_WRITE_STATUS_SETUP;
end
end
else begin
// Row-major scan: advance x, wrapping to the next row.
if (ox == (run_out_size - 6'd1)) begin
ox <= 6'd0;
oy <= oy + 6'd1;
end
else begin
ox <= ox + 6'd1;
end
state <= ST_START_PIXEL;
end
end
// ====================================================
// Finish
// Data Memory[13] = 0x00000401
// [10:1] = BASE_OUT = 512
// [0] = done
// ====================================================
ST_WRITE_STATUS_SETUP: begin
addr <= ADDR_STATUS;
dinb <= {21'd0, BASE_OUT, 1'b1};
web <= 1'b0;
state <= ST_WRITE_STATUS_DO;
end
ST_WRITE_STATUS_DO: begin
web <= 1'b1;
state <= ST_DONE;
end
ST_DONE: begin
// Terminal state: hold done high until reset.
web <= 1'b0;
done <= 1'b1;
state <= ST_DONE;
end
default: begin
// Unused state encodings fall back to idle.
state <= ST_IDLE;
end
endcase
end
end
endmodule
// Paste-site footer (not part of the design):
// Editor is loading...
// Leave a Comment