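// (Paste-site metadata captured with this file; not part of the design)
// Untitled
//
//  avatar
// unknown
// plain_text
// 3 months ago
// 55 kB
// 4
// Indexable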
`timescale 1ns/1ps

// =============================================================
// Top for Layer2: 18x18 feature map
//   -> Flatten (72b x 256) + Mask (9b x 256)  (unchanged: still BRAM-backed)
//   -> CRP pipeline: Conv + ReLU + MaxPool     (no BRAM between the stages)
//   -> Pooling output written to blk_mem_gen_pooling2 (8-bit x 64)
//   -> Padding (8x8 -> 10x10)                  (unchanged: reads pooling2)
//
// Stage sequencing (`stage` FSM): ST_FLAT -> ST_CRP -> ST_PAD -> ST_DONE.
//
// Every BRAM write path (flat72, mask, pooling2) goes through the same
// two-register "commit" pipeline:
//   producer we/addr/data -> *_req (+ captured addr/data)
//                         -> *_commit_p (+ *_a_addr / *_a_din)
// The *_a_addr / *_a_din registers are loaded on the SAME edge that raises
// *_commit_p, so wea/addra/dina presented to the BRAM always belong to one
// request. Loading them one edge later would pair each write pulse with the
// previous request's address/data and drop the last request of a burst.
// =============================================================
module top_layer2_all(
    input  wire clk,
    input  wire rst,
    output wire done_all,

    // === debug: pooling write strobes, exported for the testbench ===
    output wire        dbg_pool_we,
    output wire [7:0]  dbg_pool_addr,
    output wire [7:0]  dbg_pool_din,

    // === debug: conv output BRAM write (unused in this version, tied to 0) ===
    output wire        dbg_conv_we,
    output wire [9:0]  dbg_conv_addr,
    output wire [7:0]  dbg_conv_din,

    // === debug: ReLU (unused in this version, tied to 0) ===
    output wire        dbg_relu_we,
    output wire [9:0]  dbg_relu_addr,
    output wire [7:0]  dbg_relu_din,
    output wire [9:0]  dbg_relu_in_addr,
    output wire [9:0]  dbg_bram_conv_relu_addr,

    // === debug: Padding ===
    output wire        dbg_pad_we,
    output wire [9:0]  dbg_pad_addr,
    output wire [7:0]  dbg_pad_din,

    // === debug: real BRAM write (commit) ===
    output wire        dbg_cr_bram_we,
    output wire [9:0]  dbg_cr_bram_addr,
    output wire [7:0]  dbg_cr_bram_din,

    output wire        dbg_rp_bram_we,
    output wire [9:0]  dbg_rp_bram_addr,
    output wire [7:0]  dbg_rp_bram_din,

    output wire        dbg_pp_bram_we,
    output wire [7:0]  dbg_pp_bram_addr,
    output wire [7:0]  dbg_pp_bram_din,

    // === debug: CRP stage-5 taps (hierarchical probes into crp_u_c0_f0) ===
    output wire        dbg_s5_v,
    output wire [5:0]  dbg_s5_tile,
    output wire [1:0]  dbg_s5_sub,
    output wire [7:0]  dbg_s5_pix_relu,

    // === debug: CRP stage-3 taps ===
    output wire        dbg_s3_v,
    output wire [5:0]  dbg_s3_tile,
    output wire [1:0]  dbg_s3_sub,
    output wire signed [31:0] dbg_s3_sum32,

    // === debug: pooling accumulator taps ===
    output wire [7:0] dbg_pool_acc_max,
    output wire [5:0] dbg_pool_acc_tile,
    output wire [3:0] dbg_pool_acc_seen,

    // === debug: CRP stg1 (for checking mask/flat reads) ===
    output wire        dbg_stg1_v,
    output wire [5:0]  dbg_stg1_tile,
    output wire [1:0]  dbg_stg1_sub,
    output wire [8:0]  dbg_stg1_mask9,
    output wire [71:0] dbg_stg1_flat72
);

    // =============================================================
    // (A) Flatten: 18x18 -> flat72 / mask (datapath unchanged)
    // =============================================================
    wire        flat_we;
    wire [7:0]  flat_waddr;
    wire [71:0] flat_wdata;

    wire        mask_we;
    wire [7:0]  mask_waddr;
    wire [8:0]  mask_wdata;

    // ---- flat72 BRAM write: capture request, then commit ----
    reg        flat_req;          // flat_we delayed by one cycle
    reg [7:0]  flat_req_addr;     // address/data captured while flat_we is high
    reg [71:0] flat_req_data;

    reg        flat_commit_p;     // 1-cycle write pulse to the BRAM
    reg [7:0]  flat_a_addr;       // port-A address/data, valid while flat_commit_p is high
    reg [71:0] flat_a_din;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            flat_req      <= 1'b0;
            flat_req_addr <= 8'd0;
            flat_req_data <= 72'd0;
            flat_commit_p <= 1'b0;
        end else begin
            // capture request
            flat_req <= flat_we;
            if (flat_we) begin
                flat_req_addr <= flat_waddr;
                flat_req_data <= flat_wdata;
            end

            // commit pulse = previous request
            flat_commit_p <= flat_req;
        end
    end

    // Load addr/data on the edge that raises flat_commit_p (flat_req is high
    // then), so the BRAM sees we/addr/din from the same request.
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            flat_a_addr <= 8'd0;
            flat_a_din  <= 72'd0;
        end else if (flat_req) begin
            flat_a_addr <= flat_req_addr;
            flat_a_din  <= flat_req_data;
        end
    end

    // drive BRAM inputs straight from registers (no extra mux/CE)
    wire        flat_we_bram    = flat_commit_p;
    wire [7:0]  flat_waddr_bram = flat_a_addr;
    wire [71:0] flat_wdata_bram = flat_a_din;

    // ---- mask BRAM write: same capture/commit scheme ----
    reg        mask_req;
    reg [7:0]  mask_req_addr;
    reg [8:0]  mask_req_data;

    reg        mask_commit_p;
    reg [7:0]  mask_a_addr;
    reg [8:0]  mask_a_din;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            mask_req      <= 1'b0;
            mask_req_addr <= 8'd0;
            mask_req_data <= 9'd0;
            mask_commit_p <= 1'b0;
        end else begin
            mask_req <= mask_we;
            if (mask_we) begin
                mask_req_addr <= mask_waddr;
                mask_req_data <= mask_wdata;
            end
            mask_commit_p <= mask_req;
        end
    end

    // Same alignment rule as flat72: load together with mask_commit_p.
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            mask_a_addr <= 8'd0;
            mask_a_din  <= 9'd0;
        end else if (mask_req) begin
            mask_a_addr <= mask_req_addr;
            mask_a_din  <= mask_req_data;
        end
    end

    wire       mask_we_bram    = mask_commit_p;
    wire [7:0] mask_waddr_bram = mask_a_addr;
    wire [8:0] mask_wdata_bram = mask_a_din;

    wire        flatten_all_done;

    top_module2 flatten_u (
        .clk             (clk),
        .rst             (rst),
        .data_out        (),

        .flat_we         (flat_we),
        .flat_waddr      (flat_waddr),
        .flat_wdata      (flat_wdata),

        .mask_we         (mask_we),
        .mask_waddr      (mask_waddr),
        .mask_wdata      (mask_wdata),

        .flatten_all_done(flatten_all_done),

        .dbg_vld         (),
        .dbg_patch       (),
        .dbg_elem        (),
        .dbg_data        ()
    );

    // =============================================================
    // (B) Flat72 / Mask BRAMs: written by Flatten, read by CRP
    // =============================================================
    // ====== CRP read ports (channel0, filter0) ======
    wire [7:0]  conv_c0_f0_flat_addr;
    wire [71:0] conv_c0_f0_flat_dout;
    wire [7:0]  conv_c0_f0_mask_addr;
    wire [8:0]  conv_c0_f0_mask_dout;

    // ---- stage FSM encoding: FLAT -> CRP -> PAD -> DONE ----
    localparam ST_FLAT = 2'd0,
               ST_CRP  = 2'd1,
               ST_PAD  = 2'd2,
               ST_DONE = 2'd3;

    reg [1:0] stage;
    reg [1:0] stage_d;
    always @(posedge clk or posedge rst) begin
        if (rst) stage_d <= ST_FLAT;
        else     stage_d <= stage;
    end
    // High for exactly the first cycle spent in a new stage.
    wire stage_sw = (stage != stage_d);
    reg  stage_sw_d1;
    always @(posedge clk or posedge rst) begin
        if (rst) stage_sw_d1 <= 1'b0;
        else     stage_sw_d1 <= stage_sw;
    end
    // Two-cycle window after a stage change. Not read inside this module;
    // kept in case testbenches probe it hierarchically.
    wire stage_guard = stage_sw | stage_sw_d1;

    // ---- after Flatten finishes, wait 7 cycles so the last flat/mask
    //      commit has reached the BRAMs before CRP starts reading ----
    reg [2:0] flat_drain;
    reg       flat_done_lat;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            flat_done_lat <= 1'b0;
            flat_drain    <= 3'd0;
        end else begin
            if (stage == ST_FLAT && flatten_all_done) begin
                flat_done_lat <= 1'b1;
                flat_drain    <= 3'd7;
            end
            if (flat_done_lat && flat_drain != 0)
                flat_drain <= flat_drain - 1'b1;
        end
    end

    wire flat_ready_for_crp = flat_done_lat && (flat_drain == 0);

    // =============================================================
    // (B-1) Flat72 BRAM : Simple Dual Port
    //   Port A = write by Flatten (via commit stage)
    //   Port B = read  by CRP
    // =============================================================
    wire [71:0] flat72_doutb;

    blk_mem_gen_flat72_256 bram_flat72 (
        .clka  (clk),
        .ena   (1'b1),
        .wea   (flat_we_bram),
        .addra (flat_waddr_bram),
        .dina  (flat_wdata_bram),

        .clkb  (clk),
        .enb   (1'b1),
        .addrb (conv_c0_f0_flat_addr),
        .doutb (flat72_doutb)
    );

    assign conv_c0_f0_flat_dout = flat72_doutb;

    // =============================================================
    // (B-2) Mask BRAM : Simple Dual Port
    //   Port A = write by Flatten (via commit stage)
    //   Port B = read  by CRP
    // =============================================================
    wire [8:0] mask_doutb;

    blk_mem_gen_mask2 bram_mask (
        .clka  (clk),
        .ena   (1'b1),
        .wea   (mask_we_bram),
        .addra (mask_waddr_bram),
        .dina  (mask_wdata_bram),

        .clkb  (clk),
        .enb   (1'b1),
        .addrb (conv_c0_f0_mask_addr),
        .doutb (mask_doutb)
    );

    assign conv_c0_f0_mask_dout = mask_doutb;

    // =============================================================
    // (C) Weight BRAM (filter0) - 9 weights
    // =============================================================
    wire [3:0]        conv_f0_w_addr;
    wire signed [7:0] conv_f0_w_dout;

    // If your weight BRAM IP has a different name, change this instance only.
    blk_mem_gen_weight_f0 bram_w_f0 (
        .clka  (clk),
        .ena   (1'b1),
        .addra (conv_f0_w_addr),
        .douta (conv_f0_w_dout)
    );

    // =============================================================
    // (D) CRP pipeline: Conv + ReLU + Pool (BRAM written only after pooling)
    // =============================================================
    wire        pool_f0_we;
    wire [5:0]  pool_f0_addr;   // 0..63
    wire [7:0]  pool_f0_din;
    wire        crp_f0_done;

    // ---- pooling2 BRAM write: same capture/commit scheme ----
    reg       p2_req;
    reg [5:0] p2_req_addr;
    reg [7:0] p2_req_data;

    reg       p2_commit_p;
    reg [5:0] p2_a_addr;
    reg [7:0] p2_a_din;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            p2_req      <= 1'b0;
            p2_req_addr <= 6'd0;
            p2_req_data <= 8'd0;
            p2_commit_p <= 1'b0;
        end else begin
            p2_req <= pool_f0_we;
            if (pool_f0_we) begin
                p2_req_addr <= pool_f0_addr;
                p2_req_data <= pool_f0_din;
            end
            p2_commit_p <= p2_req;
        end
    end

    // Load together with p2_commit_p so each pooled tile lands at its own address.
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            p2_a_addr <= 6'd0;
            p2_a_din  <= 8'd0;
        end else if (p2_req) begin
            p2_a_addr <= p2_req_addr;
            p2_a_din  <= p2_req_data;
        end
    end

    wire       p2_we_bram   = p2_commit_p;
    wire [5:0] p2_addr_bram = p2_a_addr;
    wire [7:0] p2_din_bram  = p2_a_din;

    // one-cycle start pulse on the first cycle of ST_CRP
    wire crp_start = (stage == ST_CRP) && stage_sw;

    crp_tilepipe_c0_f0 #(
        .READ_LAT_D  (2),
        .READ_LAT_W  (2),
        .SHIFT_CONST (31)
    ) crp_u_c0_f0 (
        .clk              (clk),
        .rst              (rst),
        .start            (crp_start),
        .done             (crp_f0_done),

        .conv_c0_flat_addr(conv_c0_f0_flat_addr),
        .conv_c0_flat_dout(conv_c0_f0_flat_dout),
        .conv_c0_mask_addr(conv_c0_f0_mask_addr),
        .conv_c0_mask_dout(conv_c0_f0_mask_dout),

        .conv_f0_w_addr    (conv_f0_w_addr),
        .conv_f0_w_dout    (conv_f0_w_dout),

        .pool_f0_we        (pool_f0_we),
        .pool_f0_addr      (pool_f0_addr),
        .pool_f0_din       (pool_f0_din),
        .dbg_stg1_v        (dbg_stg1_v),
        .dbg_stg1_tile     (dbg_stg1_tile),
        .dbg_stg1_sub      (dbg_stg1_sub),
        .dbg_stg1_mask9    (dbg_stg1_mask9),
        .dbg_stg1_flat72   (dbg_stg1_flat72)
    );

    // =============================================================
    // (E) Pooling2 BRAM: written during ST_CRP, read during ST_PAD
    // =============================================================
    // Not read inside this module; kept in case testbenches probe it hierarchically.
    wire [7:0] bram_pooling2_dout;

    // padding read address (0..63 used)
    wire [9:0] padding_addr_in;
    // Not read inside this module (port B uses padding_addr_in[5:0] directly).
    wire [7:0] pad_pool_addr = padding_addr_in[7:0];

    // ===== after CRP done, wait a few cycles before allowing PAD reads =====
    localparam integer P2_RD_LAT = 2;   // <<< set to blk_mem_gen_pooling2's read latency
    reg [3:0] crp_drain_cnt;
    reg       crp_done_latched;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            crp_drain_cnt    <= 4'd0;
            crp_done_latched <= 1'b0;
        end else begin
            // clear on the first cycle of ST_CRP
            if ((stage == ST_CRP) && stage_sw) begin
                crp_done_latched <= 1'b0;
                crp_drain_cnt    <= 4'd0;
            end

            // latch CRP done and start the countdown
            if ((stage == ST_CRP) && crp_f0_done) begin
                crp_done_latched <= 1'b1;
                crp_drain_cnt    <= (P2_RD_LAT + 2); // margin for the commit stage + read latency
            end

            // count down
            if (crp_done_latched && (crp_drain_cnt != 0))
                crp_drain_cnt <= crp_drain_cnt - 1'b1;
        end
    end

    wire crp_drain_done = crp_done_latched && (crp_drain_cnt == 0);

    // =============================================================
    // (E) Pooling2 BRAM : Simple Dual Port
    //   Port A = write by CRP (via commit stage)
    //   Port B = read  by Padding (padding_addr_in)
    // =============================================================
    wire [7:0] bram_pooling2_doutb;
    // Set on entry to ST_PAD, cleared on padding done. Not read inside this
    // module; kept in case testbenches probe it hierarchically.
    reg pad_run;
    wire pad_active = (stage == ST_PAD);   // port-B enable: whole ST_PAD stage

    reg pad_started;
    always @(posedge clk or posedge rst) begin
        if (rst) pad_started <= 1'b0;
        else begin
            if (stage != ST_PAD) pad_started <= 1'b0;
            else if (!pad_started) pad_started <= 1'b1;
        end
    end

    // one-cycle pulse on the first cycle of ST_PAD
    wire padding_start = (stage == ST_PAD) && !pad_started;

    blk_mem_gen_pooling2 bram_pooling2 (
        .clka  (clk),
        .ena   (1'b1),
        .wea   (p2_we_bram),
        .addra (p2_addr_bram),
        .dina  (p2_din_bram),

        .clkb  (clk),
        .enb   (pad_active),
        .addrb (padding_addr_in[5:0]),
        .doutb (bram_pooling2_doutb)
    );

    // padding data_in comes from port B
    wire [7:0] padding_data_in = bram_pooling2_doutb;

    // =============================================================
    // (F) Padding: 8x8 -> 10x10 (unchanged, data_in now reads pooling2)
    // =============================================================
    wire [9:0] padding_addr_out;
    wire [7:0] padding_data_out;
    wire       padding_we_out;
    wire       padding_done;

    reg padding_done_d;
    always @(posedge clk or posedge rst) begin
        if (rst) padding_done_d <= 1'b0;
        else     padding_done_d <= padding_done;
    end

    // Rising edge of padding_done. Used so a level-held done advances the
    // stage FSM only once.
    wire padding_done_p = padding_done & ~padding_done_d;

    // =============================================================
    // (G) stage transitions
    // =============================================================
    always @(posedge clk or posedge rst) begin
        if (rst) stage <= ST_FLAT;
        else begin
            case (stage)
                ST_FLAT: if (flat_ready_for_crp) stage <= ST_CRP;
                ST_CRP : if (crp_drain_done)     stage <= ST_PAD;
                ST_PAD : if (padding_done_p)     stage <= ST_DONE;
                ST_DONE: stage <= ST_DONE;
                default: stage <= ST_FLAT;
            endcase
        end
    end

    always @(posedge clk or posedge rst) begin
        if (rst) pad_run <= 1'b0;
        else begin
            // set on the first cycle of ST_PAD
            if ((stage == ST_PAD) && stage_sw) pad_run <= 1'b1;
            // clear when padding finishes
            else if (padding_done_p)           pad_run <= 1'b0;
        end
    end

    padding_8x8_top #(
        .RD_TOTAL_LAT(3)
    ) padding_u (
        .clk      (clk),
        .rst      (rst),
        .start    (padding_start),
        .done     (padding_done),

        .addr_in  (padding_addr_in),
        .data_in  (padding_data_in),

        .addr_out (padding_addr_out),
        .data_out (padding_data_out),
        .we_out   (padding_we_out)
    );

    blk_mem_gen_out_padding_10x10 bram_padding_out (
        .clka  (clk),
        .ena   (1'b1),
        .wea   (padding_we_out),
        .addra (padding_addr_out[6:0]),
        .dina  (padding_data_out),
        .douta ()
    );

    // =============================================================
    // (H) done & debug
    // =============================================================
    assign done_all = (stage == ST_DONE);

    // pool debug = logical pool writes from the pipeline (pre-commit)
    assign dbg_pool_we   = pool_f0_we;
    assign dbg_pool_addr = {2'b00, pool_f0_addr};
    assign dbg_pool_din  = pool_f0_din;

    // conv/relu debug: unused in this version (tied to 0 to avoid confusion)
    assign dbg_conv_we   = 1'b0;
    assign dbg_conv_addr = 10'd0;
    assign dbg_conv_din  = 8'd0;

    assign dbg_relu_we   = 1'b0;
    assign dbg_relu_addr = 10'd0;
    assign dbg_relu_din  = 8'd0;
    assign dbg_relu_in_addr = 10'd0;
    assign dbg_bram_conv_relu_addr = 10'd0;

    // padding debug
    assign dbg_pad_we   = padding_we_out;
    assign dbg_pad_addr = padding_addr_out;
    assign dbg_pad_din  = padding_data_out;

    // commit debug: signals actually written into BRAM.
    // cr/rp paths do not exist in this version (tied to 0); pp = pooling2 commit.
    assign dbg_cr_bram_we   = 1'b0;
    assign dbg_cr_bram_addr = 10'd0;
    assign dbg_cr_bram_din  = 8'd0;

    assign dbg_rp_bram_we   = 1'b0;
    assign dbg_rp_bram_addr = 10'd0;
    assign dbg_rp_bram_din  = 8'd0;

    assign dbg_pp_bram_we   = p2_we_bram;
    assign dbg_pp_bram_addr = {2'b00, p2_addr_bram};
    assign dbg_pp_bram_din  = p2_din_bram;

    // hierarchical taps into the CRP pipeline
    assign dbg_s5_v        = crp_u_c0_f0.stg5_v;
    assign dbg_s5_tile     = crp_u_c0_f0.stg5_tile_id;
    assign dbg_s5_sub      = crp_u_c0_f0.stg5_sub_id;
    assign dbg_s5_pix_relu = crp_u_c0_f0.stg5_pix_relu;

    assign dbg_s3_v     = crp_u_c0_f0.stg3_v;
    assign dbg_s3_tile  = crp_u_c0_f0.stg3_tile_id;
    assign dbg_s3_sub   = crp_u_c0_f0.stg3_sub_id;
    assign dbg_s3_sum32 = crp_u_c0_f0.stg3_sum32;

    assign dbg_pool_acc_max  = crp_u_c0_f0.pool_acc_max;
    assign dbg_pool_acc_tile = crp_u_c0_f0.pool_acc_tile_id;
    assign dbg_pool_acc_seen = crp_u_c0_f0.pool_acc_seen;

endmodule



// =============================================================
// top_module2: 18x18 input BRAM -> 256 flattened 3x3 patches
//   - controller2 walks the patch origins and pulses `start` per patch.
//   - input_transformation2 reads the 9 pixels of a patch and emits them
//     one per `valid` pulse (elem_idx 0..8).
//   - This module packs the 9 bytes into one 72-bit word (byte i at
//     [i*8 +: 8]) plus a 9-bit nonzero mask, and issues one flat write and
//     one mask write per patch at address patch_index[7:0].
// =============================================================
module top_module2 (
    input  wire        clk,
    input  wire        rst,
    output wire [10:0] data_out,          // debug only

    // flat BRAM: one write per patch
    output wire        flat_we,
    output wire [7:0]  flat_waddr,        // 0..255
    output wire [71:0] flat_wdata,

    // mask BRAM: one write per patch (9-bit)
    output wire        mask_we,
    output wire [7:0]  mask_waddr,        // 0..255
    output wire [8:0]  mask_wdata,

    output wire        flatten_all_done,

    output wire        dbg_vld,
    output wire [15:0] dbg_patch,
    output wire [3:0]  dbg_elem,
    output wire [7:0]  dbg_data
);

    // === Wires and regs ===
    wire [8:0]  bram_addr;
    wire [7:0]  bram_data_out;
    wire [7:0]  data_out_internal;
    wire        done;
    wire        start;
    wire [8:0]  base_addr;
    wire [15:0] patch_index;
    wire        valid;

    // ---- mask ----
    reg  [3:0]  elem_idx;          // index of the current element within the patch, 0..8
    reg  [8:0]  mask_temp;         // nonzero bits collected so far for this patch
    wire        is_nz = (data_out_internal != 8'd0);
    reg  [8:0]  mask_next;         // mask_temp with the current element's bit merged in

    wire        last_elem       = valid && (elem_idx == 4'd8);
    wire [7:0]  write_addr_mask = patch_index[7:0];

    // Rising-edge detect so each patch is captured exactly once.
    reg  last_elem_d;
    wire last_elem_p = last_elem & ~last_elem_d;

    always @(posedge clk or posedge rst) begin
        if (rst) last_elem_d <= 1'b0;
        else     last_elem_d <= last_elem;
    end

    // Asserted when the reader finishes the patch with index 255.
    assign flatten_all_done = (patch_index == 16'd255) && done;

    // --- input BRAM (18x18) ---
    blk_mem_gen_out_padding bram_transform (
        .clka (clk),
        .ena  (1'b1),
        .addra(bram_addr),
        .douta(bram_data_out)
    );

    input_transformation2 #(
        .BRAM_READ_LAT(4)
    ) reader (
        .clk      (clk),
        .rst      (rst),
        .start    (start),
        .base_addr(base_addr),
        .bram_data(bram_data_out),
        .bram_addr(bram_addr),
        .data_out (data_out_internal),
        .done     (done),
        .valid    (valid)
    );

    controller2 patch_ctr (
        .clk        (clk),
        .rst        (rst),
        .done       (done),
        .base_addr  (base_addr),
        .start      (start),
        .patch_index(patch_index)
    );

    assign data_out = {3'b000, data_out_internal};

    // ===============================
    // Patch buffer for the 72-bit write
    // ===============================
    reg [7:0] patch_buf [0:8];   // holds the 9 elements of the current patch
    reg       clear_buf;         // request to clear patch_buf on the next cycle

    integer k;

    // On every valid, store the element at its position in the patch.
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            for (k = 0; k < 9; k = k + 1)
                patch_buf[k] <= 8'd0;
            clear_buf <= 1'b0;
        end else begin
            // A clear requested last cycle happens now (and the flag drops).
            if (clear_buf) begin
                for (k = 0; k < 9; k = k + 1)
                    patch_buf[k] <= 8'd0;
                clear_buf <= 1'b0;
            end

            // Store data. Written after the clear, so a same-cycle element wins.
            if (valid) begin
                patch_buf[elem_idx] <= data_out_internal;

                // element 8 received -> clear on the next cycle
                if (elem_idx == 4'd8)
                    clear_buf <= 1'b1;
            end
        end
    end

    assign dbg_vld   = valid;
    assign dbg_patch = patch_index;
    assign dbg_elem  = elem_idx;
    assign dbg_data  = data_out_internal;

    // 72-bit packing: byte i lives at [i*8 +: 8]
    wire [71:0] flat_pack_curr;
    assign flat_pack_curr = {
        patch_buf[8], patch_buf[7], patch_buf[6], patch_buf[5], patch_buf[4],
        patch_buf[3], patch_buf[2], patch_buf[1], patch_buf[0]
    };

    // On the last_elem cycle, patch_buf[elem_idx] has not yet received its
    // nonblocking update, so the current element is spliced in here.
    reg [71:0] flat_pack_next;
    always @(*) begin
        flat_pack_next = flat_pack_curr;
        case (elem_idx)
            4'd0: flat_pack_next[ 7: 0] = data_out_internal;
            4'd1: flat_pack_next[15: 8] = data_out_internal;
            4'd2: flat_pack_next[23:16] = data_out_internal;
            4'd3: flat_pack_next[31:24] = data_out_internal;
            4'd4: flat_pack_next[39:32] = data_out_internal;
            4'd5: flat_pack_next[47:40] = data_out_internal;
            4'd6: flat_pack_next[55:48] = data_out_internal;
            4'd7: flat_pack_next[63:56] = data_out_internal;
            4'd8: flat_pack_next[71:64] = data_out_internal;
            default: flat_pack_next = flat_pack_curr;
        endcase
    end

    // -------------------------------
    // BRAM write commit stage (flat + mask)
    //   capture on last_elem_p, present we/addr/data for one cycle after
    // -------------------------------
    reg        flat_pending;
    reg [7:0]  flat_waddr_hold;
    reg [71:0] flat_wdata_hold;

    reg        mask_pending;
    reg [7:0]  mask_waddr_hold;   // matches the 8-bit mask_waddr port
    reg [8:0]  mask_wdata_hold;

    assign flat_we    = flat_pending;
    assign flat_waddr = flat_waddr_hold;
    assign flat_wdata = flat_wdata_hold;

    assign mask_we    = mask_pending;
    assign mask_waddr = mask_waddr_hold;
    assign mask_wdata = mask_wdata_hold;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            flat_pending    <= 1'b0;
            flat_waddr_hold <= 8'd0;
            flat_wdata_hold <= 72'd0;

            mask_pending    <= 1'b0;
            mask_waddr_hold <= 8'd0;
            mask_wdata_hold <= 9'd0;
        end else begin
            // pending lasts exactly one cycle
            if (flat_pending) flat_pending <= 1'b0;
            if (mask_pending) mask_pending <= 1'b0;

            // capture on last_elem_p; pending=1 on the next cycle
            if (last_elem_p) begin
                flat_waddr_hold <= patch_index[7:0];
                flat_wdata_hold <= flat_pack_next;   // includes the last element
                mask_waddr_hold <= write_addr_mask;
                mask_wdata_hold <= mask_next;

                flat_pending <= 1'b1;
                mask_pending <= 1'b1;
            end
        end
    end

    always @(*) begin
        mask_next = mask_temp;
        mask_next[elem_idx] = is_nz;
    end

    // mask_temp & elem_idx: advance on each valid, wrap after element 8
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            elem_idx  <= 4'd0;
            mask_temp <= 9'd0;
        end else if (valid) begin
            mask_temp[elem_idx] <= is_nz;

            if (elem_idx == 4'd8) begin
                elem_idx  <= 4'd0;
                mask_temp <= 9'd0;
            end else begin
                elem_idx <= elem_idx + 1'b1;
            end
        end
    end

endmodule









// ==================== INPUT TRANSFORMATION (latency-safe, with warm-up) ====================
// Reads one 3x3 window from a row-major image BRAM and streams its 9 pixels.
//   - `start` is sampled in S_IDLE. base_addr is re-read for every element
//     (not latched), so the caller must hold it stable until `done`.
//   - Each element: issue address, wait BRAM_READ_LAT cycles plus one
//     state-transition cycle, then sample bram_data.
//   - The first capture of every patch is a discarded warm-up read.
//   - `valid` pulses once per emitted element; `done` pulses together with
//     the 9th `valid`.
module input_transformation2 #(
    parameter integer BRAM_READ_LAT = 4,  // set equal to the read latency of the attached BRAM
    parameter integer IMG_W         = 18  // pixels per image row in the BRAM (row pitch)
)(
    input  wire        clk,
    input  wire        rst,
    input  wire        start,
    input  wire [8:0]  base_addr,
    input  wire [7:0]  bram_data,

    output reg  [8:0]  bram_addr,
    output reg  [7:0]  data_out,
    output reg         done,      // 1-cycle pulse when the 9th element of a patch is output
    output reg         valid      // 1-cycle pulse for each valid data_out
);
    // -------------------------------
    // 3x3 window offsets relative to the top-left pixel.
    // A combinational ROM (instead of an `initial`-loaded array), so it is
    // synthesizable in flows that ignore `initial` blocks.
    // -------------------------------
    function [10:0] window_offset;
        input [3:0] k;
        begin
            case (k)
                4'd0:    window_offset = 0;
                4'd1:    window_offset = 1;
                4'd2:    window_offset = 2;
                4'd3:    window_offset = IMG_W;
                4'd4:    window_offset = IMG_W + 1;
                4'd5:    window_offset = IMG_W + 2;
                4'd6:    window_offset = 2 * IMG_W;
                4'd7:    window_offset = 2 * IMG_W + 1;
                4'd8:    window_offset = 2 * IMG_W + 2;
                default: window_offset = 0;   // unreachable: idx never exceeds 8
            endcase
        end
    endfunction

    localparam integer LAT = (BRAM_READ_LAT < 1) ? 1 : BRAM_READ_LAT;

    // FSM states
    localparam [1:0] S_IDLE    = 2'd0;
    localparam [1:0] S_WAIT    = 2'd1;
    localparam [1:0] S_CAPTURE = 2'd2;

    reg [1:0] state;

    reg [3:0] idx;          // which offset is being read (0..8)
    reg [3:0] elem_cnt;     // number of real elements already output (0..8)
    reg [3:0] wait_cnt;     // countdown for the BRAM read latency
    reg       warmup;       // 0 = warm-up read not yet discarded; 1 = counting real elements

    // Main logic: each start processes one patch.
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            state     <= S_IDLE;
            bram_addr <= 9'd0;
            data_out  <= 8'd0;
            done      <= 1'b0;
            valid     <= 1'b0;
            idx       <= 4'd0;
            elem_cnt  <= 4'd0;
            wait_cnt  <= 4'd0;
            warmup    <= 1'b0;
        end else begin
            // valid / done are single-cycle pulses by default
            valid <= 1'b0;
            done  <= 1'b0;

            case (state)
                // wait for the controller's start
                S_IDLE: begin
                    if (start) begin
                        idx       <= 4'd0;
                        elem_cnt  <= 4'd0;
                        warmup    <= 1'b0;

                        // issue the first address (offset 0)
                        bram_addr <= base_addr + window_offset(4'd0);
                        wait_cnt  <= LAT;
                        state     <= S_WAIT;
                    end
                end

                // wait out the BRAM read latency
                S_WAIT: begin
                    if (wait_cnt != 0)
                        wait_cnt <= wait_cnt - 1'b1;
                    else
                        state <= S_CAPTURE;
                end

                // one BRAM output word is available
                S_CAPTURE: begin
                    if (!warmup) begin
                        // First capture is a warm-up: no valid and no element count.
                        // Raise warmup and re-issue the real element-0 address.
                        warmup    <= 1'b1;
                        idx       <= 4'd0;
                        bram_addr <= base_addr + window_offset(4'd0);
                        wait_cnt  <= LAT;
                        state     <= S_WAIT;
                    end else begin
                        // all 9 real elements take this path
                        data_out <= bram_data;
                        valid    <= 1'b1;

                        if (elem_cnt == 4'd8) begin
                            // 9th element output -> patch complete
                            done  <= 1'b1;
                            state <= S_IDLE;
                        end else begin
                            // advance to the next element
                            elem_cnt  <= elem_cnt + 1'b1;
                            idx       <= idx + 1'b1;
                            bram_addr <= base_addr + window_offset(idx + 1'b1);
                            wait_cnt  <= LAT;
                            state     <= S_WAIT;
                        end
                    end
                end

                default: state <= S_IDLE;
            endcase
        end
    end

endmodule


// ==================== CONTROLLER (run once, then halt) ====================
// Steps through a 16x16 grid of 3x3 patch origins on an 18-pixel-wide image,
// one patch per reader handshake:
//   IDLE : drive base_addr = row*18 + col and raise `start` for one cycle
//   READ : drop `start`, wait for `done` from the reader
//   WAIT : advance (row, col) in raster order and increment patch_index
//   HALT : terminal state after the patch at row = col = 15; start stays low
module controller2 (
    input  wire        clk,
    input  wire        rst,
    input  wire        done,
    output reg  [8:0]  base_addr,
    output reg         start,
    output reg  [15:0] patch_index
);

    localparam [1:0] IDLE = 2'b00,
                     READ = 2'b01,
                     WAIT = 2'b10,
                     HALT = 2'b11;

    reg [1:0] state;
    reg [5:0] row;
    reg [5:0] col;

    wire at_row_end = (col == 6'd15);                  // last column of the current row
    wire at_final   = at_row_end && (row == 6'd15);    // last patch of the whole map

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            state       <= IDLE;
            row         <= 6'd0;
            col         <= 6'd0;
            base_addr   <= 9'd0;
            start       <= 1'b0;
            patch_index <= 16'd0;
        end else if (state == IDLE) begin
            // present the patch origin and pulse start for one cycle
            base_addr <= row * 9'd18 + col;
            start     <= 1'b1;
            state     <= READ;
        end else if (state == READ) begin
            start <= 1'b0;
            if (done && at_final) begin
                // last patch finished: stop for good
                patch_index <= 16'd255;
                state       <= HALT;
            end else if (done) begin
                state <= WAIT;
            end
        end else if (state == WAIT) begin
            // raster-order step to the next patch origin
            col         <= at_row_end ? 6'd0 : col + 1'b1;
            row         <= at_row_end ? row + 1'b1 : row;
            patch_index <= patch_index + 1'b1;
            state       <= IDLE;
        end else begin
            // HALT: never issue start again
            start <= 1'b0;
            state <= HALT;
        end
    end
endmodule





// =============================================================
// crp_tilepipe_c0_f0: 3x3 Conv (channel 0, filter 0) + requant + ReLU +
// 2x2 MaxPool, fully pipelined (no intermediate BRAM).
//
// Flow per `start` pulse:
//   0) Preload the 9 filter weights from the weight BRAM (addr 0..8).
//   1) Issue one patch read per cycle in 2x2-tile order. For tile t
//      (0..63, tile_r = t[5:3], tile_c = t[2:0]) the four patches are
//      base + {0, 1, 16, 17} with base = tile_r*32 + tile_c*2
//      (assumes patches are stored row-major with row stride 16).
//   2..7) Align BRAM latency -> latch -> 9 masked products -> sum + bias
//      -> * SCALE_CONST -> >>> SHIFT_CONST, saturate to int8, ReLU.
//   8) Pool reducer: once all 4 sub-pixels of a tile are merged, write
//      their max to pool[tile_id].
//
// `done` pulses for one cycle, in the same cycle as the write strobe of tile 63.
// =============================================================
module crp_tilepipe_c0_f0 #(
    parameter  READ_LAT_D   = 2,   // flat/mask BRAM read latency (cycles addr -> dout)
    parameter  READ_LAT_W   = 2,   // weight BRAM read latency
    parameter  SHIFT_CONST  = 31,  // arithmetic right shift after the scale multiply

    parameter signed [31:0] BIAS_CONST  = 32'sd3598,
    parameter signed [31:0] SCALE_CONST = 32'sd10093451
)(
    input  wire clk,
    input  wire rst,
    input  wire start,               // single-cycle pulse recommended (entry to the CONV stage)
    output reg  done,                // 1-cycle pulse, concurrent with the last pool write

    // ===== Flatten outputs (read-only) =====
    output reg  [7:0]  conv_c0_flat_addr,
    input  wire [71:0] conv_c0_flat_dout,   // dout = addr from READ_LAT_D cycles ago
    output reg  [7:0]  conv_c0_mask_addr,
    input  wire [8:0]  conv_c0_mask_dout,

    // ===== Weight (for filter0) =====
    output reg  [3:0]        conv_f0_w_addr,
    input  wire signed [7:0] conv_f0_w_dout, // dout = addr from READ_LAT_W cycles ago

    // ===== Pool output (filter0) =====
    output wire        pool_f0_we,
    output wire  [5:0] pool_f0_addr,   // 0..63
    output wire  [7:0] pool_f0_din,

    // ===== DEBUG taps =====
    output wire        dbg_stg1_v,
    output wire [5:0]  dbg_stg1_tile,
    output wire [1:0]  dbg_stg1_sub,
    output wire [8:0]  dbg_stg1_mask9,
    output wire [71:0] dbg_stg1_flat72

);

    // =========================================================
    // 0) Weight preload (9 weights)
    //    W_ISS drives the address, W_WT waits out the BRAM latency,
    //    W_CAP stores dout into conv_f0_w<widx>; after widx==8 -> ready.
    // =========================================================
    reg signed [7:0] conv_f0_w0,conv_f0_w1,conv_f0_w2,conv_f0_w3,conv_f0_w4,
                     conv_f0_w5,conv_f0_w6,conv_f0_w7,conv_f0_w8;
    reg              conv_f0_weights_ready;

    localparam W_IDLE = 3'd0,
               W_ISS  = 3'd1,
               W_WT   = 3'd2,
               W_CAP  = 3'd3,
               W_DONE = 3'd4;

    // Wait count loaded in W_ISS (at least 1). wwait is 8 bits wide so a
    // READ_LAT_W of 4 or more is no longer truncated by a 2-bit counter.
    localparam integer W_WAIT_INIT = (READ_LAT_W < 1) ? 1 : READ_LAT_W;

    reg [2:0]  wst;
    reg [3:0]  widx;
    reg [7:0]  wwait;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            wst   <= W_IDLE;
            widx  <= 4'd0;
            wwait <= 8'd0;

            conv_f0_w_addr <= 4'd0;
            conv_f0_weights_ready <= 1'b0;

            conv_f0_w0 <= 8'sd0;
            conv_f0_w1 <= 8'sd0;
            conv_f0_w2 <= 8'sd0;
            conv_f0_w3 <= 8'sd0;
            conv_f0_w4 <= 8'sd0;
            conv_f0_w5 <= 8'sd0;
            conv_f0_w6 <= 8'sd0;
            conv_f0_w7 <= 8'sd0;
            conv_f0_w8 <= 8'sd0;
        end else begin
            // start (re)arms the preload; the case below may override wst
            // in the same cycle when the FSM is mid-load.
            if (start) begin
                wst <= W_ISS;
                widx <= 0;
                conv_f0_weights_ready <= 1'b0;
            end

            case (wst)
                W_IDLE: begin end
                W_ISS: begin
                    conv_f0_w_addr <= widx;
                    wwait <= W_WAIT_INIT;
                    wst <= W_WT;
                end
                W_WT: begin
                    if (wwait != 8'd0) wwait <= wwait - 8'd1;
                    else wst <= W_CAP;
                end
                W_CAP: begin
                    case (widx)
                        0: conv_f0_w0 <= conv_f0_w_dout;
                        1: conv_f0_w1 <= conv_f0_w_dout;
                        2: conv_f0_w2 <= conv_f0_w_dout;
                        3: conv_f0_w3 <= conv_f0_w_dout;
                        4: conv_f0_w4 <= conv_f0_w_dout;
                        5: conv_f0_w5 <= conv_f0_w_dout;
                        6: conv_f0_w6 <= conv_f0_w_dout;
                        7: conv_f0_w7 <= conv_f0_w_dout;
                        8: conv_f0_w8 <= conv_f0_w_dout;
                    endcase

                    if (widx == 8) begin
                        conv_f0_weights_ready <= 1'b1;
                        wst <= W_DONE;
                    end else begin
                        widx <= widx + 1'b1;
                        wst <= W_ISS;
                    end
                end
                W_DONE: begin end
            endcase
        end
    end

    // =========================================================
    // 1) ISSUE: tile-based patch order
    //    tile_id 0..63, sub_id 0..3 => patch = base + {0,1,16,17}
    // =========================================================
    reg        issue_run;
    reg [5:0]  issue_tile_id;   // 0..63
    reg [1:0]  issue_sub_id;    // 0..3

    wire [2:0] issue_tile_r = issue_tile_id[5:3]; // 0..7
    wire [2:0] issue_tile_c = issue_tile_id[2:0]; // 0..7

    // base = (tile_r*32) + (tile_c*2)
    wire [9:0] issue_base_patch = ({issue_tile_r,5'b0}) + ({issue_tile_c,1'b0});

    reg  [9:0] issue_patch_addr;
    always @(*) begin
        case (issue_sub_id)
            2'd0: issue_patch_addr = issue_base_patch + 10'd0;
            2'd1: issue_patch_addr = issue_base_patch + 10'd1;
            2'd2: issue_patch_addr = issue_base_patch + 10'd16;
            2'd3: issue_patch_addr = issue_base_patch + 10'd17;
        endcase
    end

    // One sub-patch issued per cycle once the weights are loaded.
    wire issue_vld = issue_run && conv_f0_weights_ready;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            issue_run <= 1'b0;
            issue_tile_id <= 0;
            issue_sub_id  <= 0;
            conv_c0_flat_addr <= 0;
            conv_c0_mask_addr <= 0;
        end else begin
            if (start) begin
                issue_run <= 1'b1;
                issue_tile_id <= 0;
                issue_sub_id  <= 0;
            end

            if (issue_vld) begin
                conv_c0_flat_addr <= issue_patch_addr[7:0];
                conv_c0_mask_addr <= issue_patch_addr[7:0];

                // next sub / tile
                if (issue_sub_id == 2'd3) begin
                    issue_sub_id <= 2'd0;
                    if (issue_tile_id == 6'd63) begin
                        issue_run <= 1'b0; // all issued; pipeline drains on its own
                    end else begin
                        issue_tile_id <= issue_tile_id + 1'b1;
                    end
                end else begin
                    issue_sub_id <= issue_sub_id + 1'b1;
                end
            end
        end
    end

    // =========================================================
    // 2) ALIGN: delay {valid, tile_id, sub_id} so that index READ_LAT_D
    //    lines up with the BRAM dout of the address issued with it.
    //    Written as a per-element loop so READ_LAT_D = 0 is also legal.
    // =========================================================
    reg [READ_LAT_D:0] stg0_vld_sh;
    reg [5:0]          stg0_tile_sh [0:READ_LAT_D];
    reg [1:0]          stg0_sub_sh  [0:READ_LAT_D];

    integer k;
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            stg0_vld_sh <= {(READ_LAT_D+1){1'b0}};
            for (k = 0; k <= READ_LAT_D; k = k + 1) begin
                stg0_tile_sh[k] <= 6'd0;
                stg0_sub_sh[k]  <= 2'd0;
            end
        end else begin
            stg0_vld_sh[0]  <= issue_vld;
            stg0_tile_sh[0] <= issue_tile_id;
            stg0_sub_sh[0]  <= issue_sub_id;
            for (k = 1; k <= READ_LAT_D; k = k + 1) begin
                stg0_vld_sh[k]  <= stg0_vld_sh[k-1];
                stg0_tile_sh[k] <= stg0_tile_sh[k-1];
                stg0_sub_sh[k]  <= stg0_sub_sh[k-1];
            end
        end
    end

    // =========================================================
    // 3) STAGE1: latch flat/mask (aligned)
    // =========================================================
    reg        stg1_v;
    reg [5:0]  stg1_tile_id;
    reg [1:0]  stg1_sub_id;
    reg [71:0] stg1_flat72;
    reg [8:0]  stg1_mask9;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            stg1_v       <= 1'b0;
            stg1_tile_id <= 6'd0;
            stg1_sub_id  <= 2'd0;
            stg1_flat72  <= 72'd0;
            stg1_mask9   <= 9'd0;
        end else begin
            stg1_v       <= stg0_vld_sh[READ_LAT_D];
            stg1_tile_id <= stg0_tile_sh[READ_LAT_D];
            stg1_sub_id  <= stg0_sub_sh[READ_LAT_D];

            if (stg0_vld_sh[READ_LAT_D]) begin
                stg1_flat72 <= conv_c0_flat_dout;
                stg1_mask9  <= conv_c0_mask_dout;
            end
        end
    end

    // unpack the nine signed 8-bit inputs (x0 in the low byte)
    wire signed [7:0] stg1_x0 = stg1_flat72[ 7: 0];
    wire signed [7:0] stg1_x1 = stg1_flat72[15: 8];
    wire signed [7:0] stg1_x2 = stg1_flat72[23:16];
    wire signed [7:0] stg1_x3 = stg1_flat72[31:24];
    wire signed [7:0] stg1_x4 = stg1_flat72[39:32];
    wire signed [7:0] stg1_x5 = stg1_flat72[47:40];
    wire signed [7:0] stg1_x6 = stg1_flat72[55:48];
    wire signed [7:0] stg1_x7 = stg1_flat72[63:56];
    wire signed [7:0] stg1_x8 = stg1_flat72[71:64];

    assign dbg_stg1_v      = stg1_v;
    assign dbg_stg1_tile   = stg1_tile_id;
    assign dbg_stg1_sub    = stg1_sub_id;
    assign dbg_stg1_mask9  = stg1_mask9;
    assign dbg_stg1_flat72 = stg1_flat72;

    // =========================================================
    // 4) STAGE2: MUL (masked) -- a cleared mask bit forces the product to 0
    // =========================================================
    reg        stg2_v;
    reg [5:0]  stg2_tile_id;
    reg [1:0]  stg2_sub_id;
    reg signed [31:0] stg2_p[0:8];

    integer m;
    always @(posedge clk or posedge rst) begin
        if (rst) begin
            stg2_v <= 1'b0;
            stg2_tile_id <= 6'd0;
            stg2_sub_id  <= 2'd0;
            for (m = 0; m < 9; m = m + 1)
                stg2_p[m] <= 32'sd0;
        end else begin
            stg2_v       <= stg1_v;
            stg2_tile_id <= stg1_tile_id;
            stg2_sub_id  <= stg1_sub_id;

            if (stg1_v) begin
                stg2_p[0] <= stg1_mask9[0] ? ($signed(stg1_x0) * $signed(conv_f0_w0)) : 32'sd0;
                stg2_p[1] <= stg1_mask9[1] ? ($signed(stg1_x1) * $signed(conv_f0_w1)) : 32'sd0;
                stg2_p[2] <= stg1_mask9[2] ? ($signed(stg1_x2) * $signed(conv_f0_w2)) : 32'sd0;
                stg2_p[3] <= stg1_mask9[3] ? ($signed(stg1_x3) * $signed(conv_f0_w3)) : 32'sd0;
                stg2_p[4] <= stg1_mask9[4] ? ($signed(stg1_x4) * $signed(conv_f0_w4)) : 32'sd0;
                stg2_p[5] <= stg1_mask9[5] ? ($signed(stg1_x5) * $signed(conv_f0_w5)) : 32'sd0;
                stg2_p[6] <= stg1_mask9[6] ? ($signed(stg1_x6) * $signed(conv_f0_w6)) : 32'sd0;
                stg2_p[7] <= stg1_mask9[7] ? ($signed(stg1_x7) * $signed(conv_f0_w7)) : 32'sd0;
                stg2_p[8] <= stg1_mask9[8] ? ($signed(stg1_x8) * $signed(conv_f0_w8)) : 32'sd0;
            end
        end
    end

    // =========================================================
    // 5) STAGE3: ACC + BIAS
    // =========================================================
    reg        stg3_v;
    reg [5:0]  stg3_tile_id;
    reg [1:0]  stg3_sub_id;
    reg signed [31:0] stg3_sum32;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            stg3_v <= 1'b0;
            stg3_tile_id <= 6'd0;
            stg3_sub_id  <= 2'd0;
            stg3_sum32   <= 32'sd0;
        end else begin
            stg3_v       <= stg2_v;
            stg3_tile_id <= stg2_tile_id;
            stg3_sub_id  <= stg2_sub_id;

            if (stg2_v) begin
                stg3_sum32 <= (stg2_p[0]+stg2_p[1]+stg2_p[2]+stg2_p[3]+stg2_p[4]+
                              stg2_p[5]+stg2_p[6]+stg2_p[7]+stg2_p[8]) + BIAS_CONST;
            end
        end
    end

    // =========================================================
    // 6) STAGE4: SCALE MUL (32x32 signed -> 64-bit signed)
    // =========================================================
    reg        stg4_v;
    reg [5:0]  stg4_tile_id;
    reg [1:0]  stg4_sub_id;
    reg signed [63:0] stg4_prod64;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            stg4_v <= 1'b0;
            stg4_tile_id <= 6'd0;
            stg4_sub_id  <= 2'd0;
            stg4_prod64  <= 64'sd0;
        end else begin
            stg4_v       <= stg3_v;
            stg4_tile_id <= stg3_tile_id;
            stg4_sub_id  <= stg3_sub_id;

            if (stg3_v) begin
                stg4_prod64 <= $signed(stg3_sum32) * $signed(SCALE_CONST);
            end
        end
    end

    // =========================================================
    // 7) STAGE5: SHIFT + saturate + ReLU => pixel
    // =========================================================
    reg        stg5_v;
    reg [5:0]  stg5_tile_id;
    reg [1:0]  stg5_sub_id;
    reg [7:0]  stg5_pix_relu;

    localparam signed [63:0] PIX_MAX = 64'sd127;

    wire signed [63:0] stg5_shifted64 = $signed(stg4_prod64) >>> SHIFT_CONST;

    // int8 saturation followed by ReLU == clamp(x, 0, 127). Truncating to
    // [7:0] instead would wrap: e.g. -200 -> 8'h38 (=56) would pass the
    // ReLU as a positive value, and 200 -> 8'hC8 would be zeroed.
    wire [7:0] stg5_relu_sat = (stg5_shifted64 < 64'sd0)  ? 8'd0   :
                               (stg5_shifted64 > PIX_MAX) ? 8'd127 :
                                                            stg5_shifted64[7:0];

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            stg5_v <= 1'b0;
            stg5_tile_id <= 6'd0;
            stg5_sub_id  <= 2'd0;
            stg5_pix_relu<= 8'd0;
        end else begin
            stg5_v       <= stg4_v;
            stg5_tile_id <= stg4_tile_id;
            stg5_sub_id  <= stg4_sub_id;

            if (stg4_v) begin
                stg5_pix_relu <= stg5_relu_sat;
            end
        end
    end

    // =========================================================
    // 8) POOL reducer (collect the 4 sub-pixels of each tile)
    // =========================================================
    function [7:0] max2;
        input [7:0] a,b;
        begin max2 = (a>b)?a:b; end
    endfunction

    // Registered write request; the strobe lasts exactly one cycle.
    reg       pool_pending;
    reg [7:0] pool_addr_hold;
    reg [7:0] pool_din_hold;

    // BRAM write outputs are driven straight from registers.
    assign pool_f0_we   = pool_pending;
    assign pool_f0_addr = pool_addr_hold[5:0]; // top-level port is 6-bit
    assign pool_f0_din  = pool_din_hold;

    reg [7:0] pool_acc_max;
    reg [5:0] pool_acc_tile_id;
    reg [3:0] pool_acc_seen; // bitmask of sub_id 0..3 already merged

    // Accumulator state once the incoming stg5 pixel is merged. A pixel of a
    // different tile (or into an empty accumulator) starts a fresh window,
    // so leftover bits of a previous tile can never "complete" a new tile.
    wire       pool_same_tile = (pool_acc_seen != 4'b0000) &&
                                (stg5_tile_id == pool_acc_tile_id);
    wire [3:0] pool_sub_bit   = 4'b0001 << stg5_sub_id;
    wire [3:0] pool_seen_nxt  = pool_same_tile ? (pool_acc_seen | pool_sub_bit)
                                               : pool_sub_bit;
    wire [7:0] pool_max_nxt   = pool_same_tile ? max2(pool_acc_max, stg5_pix_relu)
                                               : stg5_pix_relu;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            done             <= 1'b0;
            pool_pending     <= 1'b0;
            pool_addr_hold   <= 8'd0;
            pool_din_hold    <= 8'd0;

            pool_acc_max     <= 8'd0;
            pool_acc_tile_id <= 6'd0;
            pool_acc_seen    <= 4'b0000;
        end else begin
            done         <= 1'b0;
            pool_pending <= 1'b0;   // one-cycle write strobe

            if (start) begin
                pool_acc_max     <= 8'd0;
                pool_acc_tile_id <= 6'd0;
                pool_acc_seen    <= 4'b0000;
            end

            if (stg5_v) begin
                pool_acc_tile_id <= stg5_tile_id;
                pool_acc_max     <= pool_max_nxt;

                if (pool_seen_nxt == 4'b1111) begin
                    // All 4 sub-pixels seen: emit the max and clear the window.
                    pool_addr_hold <= {2'b00, stg5_tile_id};
                    pool_din_hold  <= pool_max_nxt;
                    pool_pending   <= 1'b1;
                    pool_acc_seen  <= 4'b0000;

                    if (stg5_tile_id == 6'd63)
                        done <= 1'b1;   // same cycle as the final write strobe
                end else begin
                    pool_acc_seen  <= pool_seen_nxt;
                end
            end
        end
    end

endmodule



// =============================================================
// padding_8x8_top: zero-pad an IN_W x IN_H map into OUT_W x OUT_H.
//
// Visits every cell of the padded grid in raster order (column fastest)
// and issues exactly one write per cell:
//   * border cell   -> write 0, no input read;
//   * interior cell -> drive addr_in with the next sequential input index
//     (0, 1, 2, ...), wait RD_TOTAL_LAT cycles, latch data_in, then write it.
// `start` arms the scan. `done` rises after the last cell and stays high;
// only rst clears it.
// =============================================================
module padding_8x8_top #(
    parameter IN_W  = 8,
    parameter IN_H  = 8,
    parameter OUT_W = IN_W + 2,   // 10
    parameter OUT_H = IN_H + 2,   // 10

    // Total read latency: cycles from addr_in being consumed by the outside
    // (top-level arbiter / BRAM) until data_in reflects that address.
    // Typical here: 1 cycle address register (p2_a_r) + 2 cycles BRAM = 3.
    parameter integer RD_TOTAL_LAT = 3
)(
    input  wire clk,
    input  wire rst,
    input  wire start,
    output reg  done,

    output reg  [9:0] addr_in,
    input  wire [7:0] data_in,

    output reg  [9:0] addr_out,
    output reg  [7:0] data_out,
    output reg        we_out
);

    // FSM encoding
    localparam ST_ISSUE = 2'd0,   // decide border/interior, issue read if interior
               ST_WAIT  = 2'd1,   // wait for the read latency to elapse
               ST_LATCH = 2'd2,   // capture data_in
               ST_WRITE = 2'd3;   // emit the output write, advance position

    // Initial wait count (never below 1).
    localparam integer WAIT_INIT = (RD_TOTAL_LAT < 1) ? 1 : RD_TOTAL_LAT;

    reg [5:0] col_q, row_q;   // current cell on the padded grid
    reg       active;         // armed by start, released by done
    reg [9:0] src_idx;        // next interior (input) element index
    reg [7:0] pix_q;          // latched input pixel
    reg       inner_q;        // current cell is interior
    reg [7:0] lat_cnt;        // remaining read-latency cycles
    reg [1:0] fsm;

    wire at_right  = (col_q == OUT_W-1);
    wire at_bottom = (row_q == OUT_H-1);
    wire on_border = (col_q == 0) || at_right || (row_q == 0) || at_bottom;

    always @(posedge clk or posedge rst) begin
        if (rst) begin
            col_q    <= 0;
            row_q    <= 0;
            src_idx  <= 0;
            active   <= 1'b0;

            addr_in  <= 0;
            addr_out <= 0;
            data_out <= 0;
            we_out   <= 0;
            done     <= 0;

            pix_q    <= 0;
            inner_q  <= 0;
            lat_cnt  <= 0;
            fsm      <= ST_ISSUE;
        end else begin
            // write strobe defaults low; only ST_WRITE raises it for one cycle
            we_out <= 1'b0;

            // done has priority over start
            if (done)
                active <= 1'b0;
            else if (start)
                active <= 1'b1;

            if (active && !done) begin
                case (fsm)
                    ST_ISSUE: begin
                        inner_q <= !on_border;
                        if (on_border) begin
                            // border: nothing to read, go straight to the write
                            fsm <= ST_WRITE;
                        end else begin
                            // hold this index on addr_in until the data is latched
                            addr_in <= src_idx;
                            lat_cnt <= WAIT_INIT;
                            fsm     <= ST_WAIT;
                        end
                    end

                    ST_WAIT: begin
                        if (lat_cnt <= 1)
                            fsm <= ST_LATCH;
                        else
                            lat_cnt <= lat_cnt - 1'b1;
                    end

                    ST_LATCH: begin
                        // register data_in first, write it on the next cycle
                        pix_q <= data_in;
                        fsm   <= ST_WRITE;
                    end

                    ST_WRITE: begin
                        we_out   <= 1'b1;
                        addr_out <= row_q * OUT_W + col_q;
                        data_out <= inner_q ? pix_q : 8'd0;
                        if (inner_q)
                            src_idx <= src_idx + 1'b1;

                        // raster advance; finishing the last row raises done
                        if (!at_right) begin
                            col_q <= col_q + 1'b1;
                        end else begin
                            col_q <= 0;
                            if (at_bottom)
                                done <= 1'b1;
                            else
                                row_q <= row_q + 1'b1;
                        end

                        fsm <= ST_ISSUE;
                    end
                endcase
            end
        end
    end

endmodule


Editor is loading...
Leave a Comment