feat: add SC-LDPC density evolution with threshold computation
Implement position-aware density evolution for SC-LDPC codes: - sc_density_evolution(): flooding-schedule DE tracking per-position error rates, demonstrating the wave decoding effect - compute_sc_threshold(): binary search for SC-LDPC threshold Uses flooding schedule (not layered) to avoid belief divergence from cross-position message interference in the coupled chain. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
231
model/sc_ldpc.py
231
model/sc_ldpc.py
@@ -16,6 +16,7 @@ import os
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET
|
||||
from density_evolution import de_cn_update_vectorized
|
||||
|
||||
|
||||
def split_protograph(B, w=2, seed=None):
|
||||
@@ -384,3 +385,233 @@ def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20,
|
||||
converged = np.all(syndrome == 0)
|
||||
|
||||
return decoded, converged, total_iterations
|
||||
|
||||
|
||||
def sc_density_evolution(B, L=50, w=2, lam_s=5.0, lam_b=0.1, z_pop=10000,
|
||||
max_iter=200, cn_mode='normalized', alpha=0.75):
|
||||
"""
|
||||
Position-aware density evolution for SC-LDPC codes.
|
||||
|
||||
Uses flooding schedule: compute total beliefs from channel + all CN->VN
|
||||
messages, then compute all new CN->VN messages simultaneously. This avoids
|
||||
the layered update instability that occurs when multiple CN rows modify the
|
||||
same VN belief within one iteration via different random permutations.
|
||||
|
||||
Tracks belief populations at each chain position to observe threshold
|
||||
saturation and the wave decoding effect where boundary positions
|
||||
converge before interior positions.
|
||||
|
||||
Args:
|
||||
B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
|
||||
L: Chain length (number of CN positions).
|
||||
w: Coupling width (number of component matrices).
|
||||
lam_s: Signal photons per slot.
|
||||
lam_b: Background photons per slot.
|
||||
z_pop: Population size for Monte Carlo DE.
|
||||
max_iter: Maximum number of decoding iterations.
|
||||
cn_mode: 'offset' or 'normalized' min-sum variant.
|
||||
alpha: Scaling factor for normalized mode.
|
||||
|
||||
Returns:
|
||||
(converged, per_position_errors, iterations_used)
|
||||
converged: True if all positions have error rate < 1e-3.
|
||||
per_position_errors: list of (L+w-1) error fractions at final iteration.
|
||||
iterations_used: number of iterations actually performed.
|
||||
"""
|
||||
m_base, n_base = B.shape
|
||||
n_vn_positions = L + w - 1
|
||||
|
||||
# Split protograph into w components (deterministic with seed=0)
|
||||
components = split_protograph(B, w=w, seed=0)
|
||||
|
||||
# Build edge list: for each CN position t, row r, collect connected
|
||||
# (component_idx, vn_position, base_col) tuples
|
||||
# edge_map[(t, r)] = list of (i, vn_pos, c) where component i connects
|
||||
# CN(t,r) to VN(vn_pos, c)
|
||||
edge_map = {}
|
||||
for t in range(L):
|
||||
for r in range(m_base):
|
||||
edges = []
|
||||
for i in range(w):
|
||||
vn_pos = t + i
|
||||
for c in range(n_base):
|
||||
if components[i][r, c] >= 0:
|
||||
ei = len(edges)
|
||||
edges.append((i, vn_pos, c))
|
||||
edge_map[(t, r)] = edges
|
||||
|
||||
# Build reverse map: for each VN group (p, c), list of (t, r, ei) edges
|
||||
vn_edges = {}
|
||||
for t in range(L):
|
||||
for r in range(m_base):
|
||||
edges = edge_map[(t, r)]
|
||||
for ei, (i, vn_pos, c) in enumerate(edges):
|
||||
key = (vn_pos, c)
|
||||
if key not in vn_edges:
|
||||
vn_edges[key] = []
|
||||
vn_edges[key].append((t, r, ei))
|
||||
|
||||
# Initialize channel LLR populations per (position, base_col)
|
||||
# All-zeros codeword assumed (standard for DE)
|
||||
log_ratio = np.log((lam_s + lam_b) / lam_b) if lam_b > 0 else 100.0
|
||||
scale = Q_MAX / 5.0
|
||||
|
||||
channel_llr = np.zeros((n_vn_positions, n_base, z_pop), dtype=np.float64)
|
||||
for p in range(n_vn_positions):
|
||||
for c in range(n_base):
|
||||
y = np.random.poisson(lam_b, size=z_pop)
|
||||
llr_float = lam_s - y * log_ratio
|
||||
llr_q = np.round(llr_float * scale).astype(np.float64)
|
||||
llr_q = np.clip(llr_q, Q_MIN, Q_MAX)
|
||||
channel_llr[p, c] = llr_q
|
||||
|
||||
# CN->VN message memory: msg[(t, r, ei)] = z_pop samples
|
||||
msg = {}
|
||||
for t in range(L):
|
||||
for r in range(m_base):
|
||||
for ei in range(len(edge_map[(t, r)])):
|
||||
msg[(t, r, ei)] = np.zeros(z_pop, dtype=np.float64)
|
||||
|
||||
# Convergence uses interior positions (excluding w-1 boundary positions
|
||||
# on each side) to avoid the boundary rate effect. Boundary positions
|
||||
# have fewer CN connections and genuinely higher error rates.
|
||||
conv_threshold = 1e-3
|
||||
interior_start = w - 1 # skip first w-1 positions
|
||||
interior_end = n_vn_positions - (w - 1) # skip last w-1 positions
|
||||
if interior_end <= interior_start:
|
||||
# Chain too short for interior; use all positions
|
||||
interior_start = 0
|
||||
interior_end = n_vn_positions
|
||||
|
||||
iterations_used = 0
|
||||
for it in range(max_iter):
|
||||
# Step 1: Compute total beliefs for each VN group
|
||||
# beliefs[p,c] = channel_llr[p,c] + sum of randomly permuted CN->VN msgs
|
||||
beliefs = channel_llr.copy()
|
||||
permuted_msg = {} # store the permuted version used in belief computation
|
||||
|
||||
for (p, c), edge_list in vn_edges.items():
|
||||
for (t, r, ei) in edge_list:
|
||||
perm = np.random.permutation(z_pop)
|
||||
pmsg = msg[(t, r, ei)][perm]
|
||||
permuted_msg[(t, r, ei)] = pmsg
|
||||
beliefs[p, c] += pmsg
|
||||
|
||||
# Note: beliefs are NOT clipped to Q_MIN/Q_MAX. They accumulate in
|
||||
# full float64 precision to avoid saturation artifacts. Only the VN->CN
|
||||
# and CN->VN messages are quantized (clipped), matching the distinction
|
||||
# between internal node accumulation and wire-level quantization.
|
||||
|
||||
# Step 2: Compute new CN->VN messages (flooding: all at once)
|
||||
new_msg = {}
|
||||
for t in range(L):
|
||||
for r in range(m_base):
|
||||
edges = edge_map[(t, r)]
|
||||
dc = len(edges)
|
||||
if dc < 2:
|
||||
for ei in range(dc):
|
||||
new_msg[(t, r, ei)] = np.zeros(z_pop, dtype=np.float64)
|
||||
continue
|
||||
|
||||
# Compute VN->CN extrinsic messages
|
||||
vn_to_cn = []
|
||||
for ei, (i, vn_pos, c) in enumerate(edges):
|
||||
# Extrinsic = belief - this CN's permuted contribution
|
||||
# Apply random interleaving permutation
|
||||
perm = np.random.permutation(z_pop)
|
||||
ext = beliefs[vn_pos, c][perm] - permuted_msg[(t, r, ei)][perm]
|
||||
ext = np.clip(np.round(ext), Q_MIN, Q_MAX)
|
||||
vn_to_cn.append(ext)
|
||||
|
||||
# CN update (vectorized min-sum)
|
||||
cn_to_vn = de_cn_update_vectorized(
|
||||
vn_to_cn, offset=OFFSET,
|
||||
cn_mode=cn_mode, alpha=alpha
|
||||
)
|
||||
|
||||
for ei in range(dc):
|
||||
new_msg[(t, r, ei)] = cn_to_vn[ei]
|
||||
|
||||
msg = new_msg
|
||||
iterations_used += 1
|
||||
|
||||
# Check convergence: per-position error rates
|
||||
per_position_errors = _compute_position_errors(
|
||||
beliefs, n_vn_positions, n_base, z_pop
|
||||
)
|
||||
|
||||
# Converge based on interior positions only
|
||||
interior_errors = per_position_errors[interior_start:interior_end]
|
||||
if all(e < conv_threshold for e in interior_errors):
|
||||
return True, per_position_errors, iterations_used
|
||||
|
||||
# Did not converge; return final state
|
||||
per_position_errors = _compute_position_errors(
|
||||
beliefs, n_vn_positions, n_base, z_pop
|
||||
)
|
||||
|
||||
return False, per_position_errors, iterations_used
|
||||
|
||||
|
||||
def _compute_position_errors(beliefs, n_vn_positions, n_base, z_pop):
|
||||
"""Compute per-position error fractions from belief arrays."""
|
||||
per_position_errors = []
|
||||
for p in range(n_vn_positions):
|
||||
wrong = 0
|
||||
total = 0
|
||||
for c in range(n_base):
|
||||
wrong += np.sum(beliefs[p, c] < 0)
|
||||
total += z_pop
|
||||
per_position_errors.append(wrong / total)
|
||||
return per_position_errors
|
||||
|
||||
|
||||
def compute_sc_threshold(B, L=50, w=2, lam_b=0.1, z_pop=10000, tol=0.25,
|
||||
cn_mode='normalized', alpha=0.75):
|
||||
"""
|
||||
Binary search for minimum lam_s where SC density evolution converges.
|
||||
|
||||
Args:
|
||||
B: Base matrix (m_base x n_base).
|
||||
L: Chain length.
|
||||
w: Coupling width.
|
||||
lam_b: Background photon rate.
|
||||
z_pop: DE population size.
|
||||
tol: Search tolerance (stop when hi - lo <= tol).
|
||||
cn_mode: 'offset' or 'normalized'.
|
||||
alpha: Scaling factor for normalized mode.
|
||||
|
||||
Returns:
|
||||
Threshold lam_s* (upper bound from binary search).
|
||||
"""
|
||||
lo = 0.1
|
||||
hi = 20.0
|
||||
|
||||
# Verify hi converges
|
||||
converged, _, _ = sc_density_evolution(
|
||||
B, L=L, w=w, lam_s=hi, lam_b=lam_b,
|
||||
z_pop=z_pop, max_iter=100, cn_mode=cn_mode, alpha=alpha
|
||||
)
|
||||
if not converged:
|
||||
return hi # Doesn't converge even at hi
|
||||
|
||||
# Verify lo doesn't converge
|
||||
converged, _, _ = sc_density_evolution(
|
||||
B, L=L, w=w, lam_s=lo, lam_b=lam_b,
|
||||
z_pop=z_pop, max_iter=100, cn_mode=cn_mode, alpha=alpha
|
||||
)
|
||||
if converged:
|
||||
return lo # Converges even at lo
|
||||
|
||||
while hi - lo > tol:
|
||||
mid = (lo + hi) / 2
|
||||
converged, _, _ = sc_density_evolution(
|
||||
B, L=L, w=w, lam_s=mid, lam_b=lam_b,
|
||||
z_pop=z_pop, max_iter=100, cn_mode=cn_mode, alpha=alpha
|
||||
)
|
||||
if converged:
|
||||
hi = mid
|
||||
else:
|
||||
lo = mid
|
||||
|
||||
return hi
|
||||
|
||||
Reference in New Issue
Block a user