Untitled

 avatar
unknown
plain_text
18 days ago
6.0 kB
2
Indexable
"""
1D ISM-like forward model based on Zunino et al. (2023).

The ISM forward model for detector element at position x_d is:
    h(x_s | x_d) = h_exc(-x_s) * h_det(x_s - x_d)   (pointwise product)
    i(x_s | x_d) = o(x_s) * h(x_s | x_d)       (convolution)
    observed     ~ Poisson{ i(x_s | x_d) }

In 1D we use N_d = 5 detector elements (the 1D analogue of a 5x5 SPAD array).

Physical parameters (from Zunino et al.):
  - lambda_exc = 640 nm, lambda_em = 680 nm, NA = 1.4
  - PSF FWHM ~ 197 nm  =>  sigma_exc ~ 84 nm
  - pixel size = 80 nm
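  - detector pitch projected onto sample ~ 167 nm ~ 2 pixels (75 um / 450x)
  - simplified Gaussian PSFs: the ISM PSF peaks at
    mu(x_d) = x_d * sigma_exc^2 / (sigma_exc^2 + sigma_det^2), which equals
    x_d / 2 only without Stokes shift (sigma_det == sigma_exc); with the
    detection PSF 5% wider, mu(x_d) ~ 0.48 x_d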
  - noise: Poisson only (SPAD detectors)
"""

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage import convolve1d
# NOTE(review): `trunc_` is never used below; this looks like an accidental
# IDE auto-import and makes the script require torch. Confirm and remove.
from torch import trunc_

# ---- physical parameters, all lengths converted to pixel units ----
pixel_size_nm = 80.0                      # sample-plane pixel size (nm)
fwhm_nm = 197.0                           # PSF full width at half maximum (nm)
# Gaussian sigma from FWHM (FWHM ≈ 2.355 sigma), then nm -> px; ≈ 1.04 px
sigma_exc_px = fwhm_nm / 2.355 / pixel_size_nm
# detection PSF taken 5% wider than excitation (Stokes + pinhole)
sigma_det_px = 1.05 * sigma_exc_px

# ---- detector array ----
# Real setup: physical pitch 75 µm / magnification 450× ≈ 167 nm ≈ 2 px
# once projected onto the sample.
N_d = 5                                   # number of detector elements
det_pitch_px = 2.0                        # projected detector pitch (px)
# offsets centred on the middle element: [-4, -2, 0, 2, 4] px
detector_positions = det_pitch_px * (np.arange(N_d) - N_d // 2)

# ---- scan grid ----
d = 128                                   # signal length (samples)
xs = np.arange(float(d))                  # scan positions 0 .. d-1 (float)
# observations are cropped to [trunc_start, trunc_end) to reduce dimensionality
trunc_start = 40
trunc_end = d - 20

#  build PSFs 
def gaussian_1d(x, mu, sigma):
    """Evaluate a unit-area Gaussian with mean ``mu`` and std ``sigma`` at ``x``.

    ``x`` may be a scalar or a NumPy array; the result has the same shape.
    """
    z = (x - mu) / sigma
    norm = sigma * np.sqrt(2 * np.pi)
    return np.exp(-0.5 * z ** 2) / norm

def ism_psf(xs, xd, sigma_exc, sigma_det):
    """
    Unnormalised 1D ISM PSF of the detector element at offset ``xd``.

    h(x_s | x_d) = h_exc(x_s) * h_det(x_s - x_d): the pointwise product of an
    excitation Gaussian centred at the scan position (0) and a detection
    Gaussian centred at the detector offset ``xd``.

    The product is deliberately NOT renormalised: its integral is the
    fingerprint f(x_d), i.e. how much signal that detector collects.
    """
    excitation = gaussian_1d(xs, 0, sigma_exc)
    detection = gaussian_1d(xs, xd, sigma_det)
    return excitation * detection

# Convolution-kernel support centred on 0: 2 * kernel_half + 1 = 31 samples.
kernel_half = 15
kernel_xs = np.arange(-kernel_half, kernel_half + 1, dtype=float)

# One sampled ISM PSF per detector element, keyed by its offset x_d.
psfs = {
    xd: ism_psf(kernel_xs, xd, sigma_exc_px, sigma_det_px)
    for xd in detector_positions
}

# generate a piecewise-constant object 
def random_piecewise_constant(d, n_seg_range=(4, 8), amp_range=(50.0, 500.0), rng=None):
    """Non-negative piecewise-constant signal on a photon-count scale.

    Parameters
    ----------
    d : int
        Signal length; must be >= 1.
    n_seg_range : (int, int)
        Inclusive range from which the number of segments is drawn. The draw
        is clamped to [1, d] because every segment needs at least one sample.
    amp_range : (float, float)
        Bounds of the uniform distribution each segment amplitude comes from.
    rng : numpy.random.Generator, optional
        Random generator; a fresh ``default_rng()`` is used when None.

    Returns
    -------
    numpy.ndarray
        Float array of length ``d`` with exactly the drawn number of
        constant segments.

    Raises
    ------
    ValueError
        If ``d < 1``.
    """
    if d < 1:
        raise ValueError(f"d must be >= 1, got {d}")
    if rng is None:
        rng = np.random.default_rng()
    n_seg = min(max(int(rng.integers(*n_seg_range, endpoint=True)), 1), d)
    # Interior change points are drawn WITHOUT replacement from 1..d-1:
    # duplicates would create empty segments and silently reduce n_seg.
    interior = rng.choice(np.arange(1, d), size=n_seg - 1, replace=False)
    cps = np.concatenate(([0], np.sort(interior), [d]))
    amps = rng.uniform(*amp_range, size=n_seg)
    # each amplitude is repeated over its segment length (all lengths >= 1)
    return np.repeat(amps, np.diff(cps))

# Ground-truth object, seeded so the whole run is reproducible.
rng = np.random.default_rng(42)
obj = random_piecewise_constant(d, rng=rng)

# ---- forward model: per-detector convolution, then Poisson noise ----
observations_clean = {}   # detector offset -> noise-free cropped image
observations_noisy = {}   # detector offset -> Poisson draw of that image
fingerprint = np.zeros(N_d)

for idx, offset in enumerate(detector_positions):
    # zero-padded convolution with this detector's PSF; the clip keeps the
    # Poisson rates non-negative
    full_image = np.clip(convolve1d(obj, psfs[offset], mode='constant'), 0, None)
    # total signal over the full (uncropped) scan
    fingerprint[idx] = full_image.sum()
    visible = full_image[trunc_start:trunc_end]
    observations_clean[offset] = visible
    observations_noisy[offset] = rng.poisson(visible).astype(float)

# scale so the brightest detector has fingerprint 1
fingerprint /= fingerprint.max()

# ---- plots ----
# plt.savefig does not create missing directories, so make sure the output
# folder exists before either figure is written.
results_dir = Path("results")
results_dir.mkdir(parents=True, exist_ok=True)

# x-axis (nm) of the cropped observation window, shared by panels (b) and (c)
window_nm = xs[trunc_start:trunc_end] * pixel_size_nm

# Peak of the Gaussian product h_exc(x) * h_det(x - x_d) lies at
# x_d * s_exc^2 / (s_exc^2 + s_det^2); this equals x_d / 2 only when
# sigma_det == sigma_exc.
psf_shift_factor = sigma_exc_px ** 2 / (sigma_exc_px ** 2 + sigma_det_px ** 2)

fig, axes = plt.subplots(3, 1, figsize=(10, 10))

# (a) PSF profiles
ax = axes[0]
for xd in detector_positions:
    shift_label = f"$x_d$ = {xd:+.0f} px (shift ≈ {xd * psf_shift_factor:+.1f} px)"
    ax.plot(kernel_xs * pixel_size_nm, psfs[xd], label=shift_label)
ax.set_xlabel("position (nm)")
# ism_psf intentionally does not normalise the PSFs (area = fingerprint)
ax.set_ylabel("PSF (unnormalised)")
ax.set_title(f"1D ISM PSFs  (FWHM ≈ {fwhm_nm:.0f} nm, pixel = {pixel_size_nm:.0f} nm)")
ax.legend(fontsize=8)

# (b) object + clean observations
ax = axes[1]
ax.plot(window_nm, obj[trunc_start:trunc_end], 'k-', lw=2, label='object $o(x_s)$',
        drawstyle='steps-post')
for xd in detector_positions:
    ax.plot(window_nm, observations_clean[xd], alpha=0.6,
            label=f'$i(x_s|x_d={xd:+.0f})$')
ax.set_xlabel("position (nm)")
ax.set_ylabel("photon counts")
ax.set_title("Object and clean observations per detector")
ax.legend(fontsize=8, ncol=2)

# (c) noisy observations (what you actually measure)
ax = axes[2]
for xd in detector_positions:
    ax.plot(window_nm, observations_noisy[xd], alpha=0.6,
            label=f'noisy $i(x_s|x_d={xd:+.0f})$')
ax.plot(window_nm, obj[trunc_start:trunc_end], 'k--', lw=1.5, alpha=0.4,
        label='true object', drawstyle='steps-post')
ax.set_xlabel("position (nm)")
ax.set_ylabel("photon counts")
ax.set_title("Noisy observations (Poisson)")
ax.legend(fontsize=8, ncol=2)

plt.tight_layout()
plt.savefig(results_dir / "ism_1d_forward.png", dpi=150)
plt.show()

# detector fingerprint (total signal per detector, max-normalised)
fig2, ax2 = plt.subplots(figsize=(4, 3))
ax2.bar(detector_positions, fingerprint, color='steelblue')
ax2.set_xlabel("detector position $x_d$ (px)")
ax2.set_ylabel("fingerprint $f(x_d)$ (normalised)")
ax2.set_title("Detector fingerprint")
plt.tight_layout()
plt.savefig(results_dir / "ism_1d_fingerprint.png", dpi=150)
plt.show()

# ---- summary of the simulation parameters ----
print(f"\nsigma_exc  = {sigma_exc_px:.2f} px  ({sigma_exc_px * pixel_size_nm:.1f} nm)")
print(f"sigma_det  = {sigma_det_px:.2f} px  ({sigma_det_px * pixel_size_nm:.1f} nm)")
print(f"PSF FWHM   ≈ {fwhm_nm:.0f} nm")
print(f"pixel size = {pixel_size_nm:.0f} nm")
# .tolist() yields plain Python floats; list() of a NumPy array would print
# np.float64(...) reprs under NumPy >= 2.
print(f"detectors  = {detector_positions.tolist()}")
print(f"fingerprint= {np.round(fingerprint, 3)}")
Editor is loading...
Leave a Comment