SC-LDPC threshold saturation results:
- SC original staircase: 2.28 photons/slot (vs 4.76 uncoupled)
- SC DE-optimized: 1.03 photons/slot (vs 3.21 uncoupled)
- Shannon limit: 0.47 photons/slot (remaining gap: 3.4 dB)

Add CLI to sc_ldpc.py (threshold, fer-compare, full subcommands).
Add threshold progression, SC threshold bars, and SC FER plots.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
811 lines
28 KiB
Python
811 lines
28 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Spatially-Coupled LDPC (SC-LDPC) Code Construction and Decoding
|
|
|
|
SC-LDPC codes achieve threshold saturation: the BP threshold approaches
|
|
the MAP threshold, closing a significant portion of the gap to Shannon limit.
|
|
|
|
Construction: replicate a protograph base matrix along a chain of L positions
|
|
with coupling width w, creating a convolutional-like structure.
|
|
"""
|
|
|
|
import numpy as np
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import os
|
|
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
|
|
from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET
|
|
from density_evolution import de_cn_update_vectorized
|
|
|
|
|
|
def split_protograph(B, w=2, seed=None):
    """
    Distribute the edges of a protograph over w component matrices.

    Every connected entry of B (value >= 0) is handed to exactly one of the
    w components, picked with a fresh random draw. The owning component
    stores 0 at that position (the actual circulant shift is chosen later,
    when the chain is built); every other component keeps -1 there, meaning
    "no connection".

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        w: Coupling width (number of component matrices).
        seed: Random seed for reproducibility.

    Returns:
        List of w int16 matrices, each of shape (m_base, n_base), holding
        0 where the edge was assigned to that component and -1 elsewhere.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape

    # Start with every component fully disconnected.
    components = [
        np.full((m_base, n_base), -1, dtype=np.int16) for _ in range(w)
    ]

    # np.argwhere lists the connected entries in row-major order, so the
    # generator is consumed in the same (row, column) sequence as a nested
    # row/column scan would consume it: one scalar draw per edge.
    for row, col in np.argwhere(B >= 0):
        owner = rng.integers(0, w)
        components[owner][row, col] = 0

    return components
|
|
|
|
|
|
def build_sc_chain(B, L=20, w=2, z=32, seed=None):
    """
    Assemble the dense binary parity-check matrix of an SC-LDPC chain.

    The chain has L check-node (CN) positions and L+w-1 variable-node (VN)
    positions. The CNs at position t are wired to VN positions
    t, t+1, ..., t+w-1 through the w component matrices returned by
    split_protograph; every connected entry becomes a z x z circulant
    permutation block whose shift is drawn at random.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        L: Chain length (number of CN positions).
        w: Coupling width.
        z: Lifting factor (circulant size).
        seed: Random seed for reproducibility.

    Returns:
        H_full: Binary int8 parity-check matrix of shape
            (L * m_base * z) x ((L + w - 1) * n_base * z).
        components: List of w component matrices from split_protograph.
        meta: Dictionary with construction metadata.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape

    # The edge split gets its own sub-seed drawn from the main generator;
    # an unseeded build leaves the split unseeded as well.
    split_seed = None if seed is None else int(rng.integers(0, 2**31))
    components = split_protograph(B, w=w, seed=split_seed)

    n_cn_positions = L
    n_vn_positions = L + w - 1
    total_rows = n_cn_positions * m_base * z
    total_cols = n_vn_positions * n_base * z

    H_full = np.zeros((total_rows, total_cols), dtype=np.int8)

    # Row offsets inside one circulant block, reused for every placement.
    within_block = np.arange(z)

    for t in range(L):
        for i, comp in enumerate(components):
            vn_pos = t + i  # component i couples CN group t to VN group t+i
            # Connected entries are visited row by row, left to right, with
            # one shift drawn per entry.
            for r, c in np.argwhere(comp >= 0):
                shift = int(rng.integers(0, z))
                row_base = t * m_base * z + r * z
                col_base = vn_pos * n_base * z + c * z
                # Circulant permutation: block row zz has its single 1 in
                # block column (zz + shift) mod z.
                H_full[row_base + within_block,
                       col_base + (within_block + shift) % z] = 1

    meta = {
        'L': L,
        'w': w,
        'z': z,
        'm_base': m_base,
        'n_base': n_base,
        'total_rows': total_rows,
        'total_cols': total_cols,
        'n_cn_positions': n_cn_positions,
        'n_vn_positions': n_vn_positions,
        'rate_design': 1.0 - (total_rows / total_cols),
    }

    return H_full, components, meta
|
|
|
|
|
|
def sc_encode(info_bits, H_full, k_total):
    """
    Encode using GF(2) Gaussian elimination on the SC-LDPC parity-check matrix.

    The information bits occupy the first k_total codeword positions; the
    remaining positions are parity bits obtained by solving
    H_parity * p = H_info * info_bits (mod 2), where
    H_full = [H_info | H_parity].

    Rank-deficient systems are tolerated: columns without a pivot are free
    variables and are left at zero. The resulting codeword is always checked
    against H_full before it is returned.

    Args:
        info_bits: Information bits, any array-like of length k_total
            (NumPy array, list or tuple).
        H_full: Binary parity-check matrix (m x n).
        k_total: Number of information bits.

    Returns:
        codeword: int8 array of length n (= H_full.shape[1]) laid out as
            [info_bits | parity].

    Raises:
        AssertionError: If len(info_bits) != k_total.
        ValueError: If encoding fails (syndrome is nonzero).
    """
    # Accept plain sequences too; a list would otherwise only fail at the
    # final .astype() call, after all the elimination work was done.
    info = np.asarray(info_bits)
    m_full, n_full = H_full.shape
    n_parity = n_full - k_total

    assert len(info) == k_total, (
        f"info_bits length {len(info)} != k_total {k_total}"
    )

    # Right-hand side of the parity system: H_info * info_bits mod 2.
    rhs = (H_full[:, :k_total] @ info) % 2

    # Augmented matrix [H_parity | rhs]; slice assignment copies the data,
    # so H_full itself is never modified.
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_full[:, k_total:]
    aug[:, n_parity] = rhs

    # Gauss-Jordan elimination over GF(2).
    pivot_row = 0
    pivot_cols = []
    for col in range(n_parity):
        if pivot_row >= m_full:
            break  # every row already hosts a pivot; the rest are free

        # First row at or below pivot_row with a 1 in this column.
        hits = np.flatnonzero(aug[pivot_row:, col] == 1)
        if hits.size == 0:
            continue  # rank deficient in this column: free variable

        found = pivot_row + int(hits[0])
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]

        # Clear this column in every other row in one vectorized step.
        targets = aug[:, col] == 1
        targets[pivot_row] = False
        aug[targets] = (aug[targets] + aug[pivot_row]) % 2

        pivot_cols.append(col)
        pivot_row += 1

    # In reduced form, pivot i reads its value straight from the RHS column;
    # free variables stay zero.
    parity = np.zeros(n_parity, dtype=np.int8)
    parity[np.asarray(pivot_cols, dtype=np.intp)] = aug[:len(pivot_cols), n_parity]

    codeword = np.concatenate([info.astype(np.int8), parity])

    # An inconsistent system (no valid parity assignment) shows up here.
    syndrome = (H_full @ codeword) % 2
    if not np.all(syndrome == 0):
        raise ValueError(
            f"SC-LDPC encoding failed: syndrome weight = {syndrome.sum()}, "
            f"H rank ~{len(pivot_cols)}, H rows = {m_full}"
        )

    return codeword
|
|
|
|
|
|
def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20,
                    cn_mode='normalized', alpha=0.75):
    """
    Sliding-window min-sum decoder for an SC-LDPC chain.

    For every VN position p, the CN rows of chain positions
    max(0, p-W+1) .. min(p, L-1) are iterated max_iter times with a flooding
    schedule (all beliefs are computed first, then all CN->VN messages are
    replaced together). Afterwards the bits of position p are hard-decided.
    CN->VN messages persist between window steps.

    Args:
        llr_q: quantized channel LLRs for entire SC codeword
        H_full: full SC-LDPC parity check matrix (binary)
        L: chain length (number of CN positions)
        w: coupling width
        z: lifting factor
        n_base: base matrix columns
        m_base: base matrix rows
        W: window size in positions (default 5)
        max_iter: iterations per window position
        cn_mode: 'normalized' scales magnitudes by alpha; any other value
            subtracts OFFSET (floored at 0)
        alpha: scaling factor for normalized mode

    Returns:
        (decoded_bits, converged, total_iterations) where converged tells
        whether decoded_bits has an all-zero syndrome under H_full.
    """
    total_rows, total_cols = H_full.shape
    n_vn_positions = L + w - 1
    rows_per_position = m_base * z
    cols_per_position = n_base * z
    scale_magnitudes = cn_mode == 'normalized'

    def saturate(value):
        """Clamp to the quantizer range [Q_MIN, Q_MAX] as a Python int."""
        return max(Q_MIN, min(Q_MAX, int(value)))

    def check_node_update(incoming):
        """Min-sum CN update for a list of incoming VN->CN messages."""
        if not incoming:
            return []
        negative = [1 if m < 0 else 0 for m in incoming]
        magnitudes = [abs(m) for m in incoming]
        overall_parity = sum(negative) % 2

        # Track the two smallest magnitudes; both start at Q_MAX, so larger
        # inputs never raise them above the quantizer ceiling.
        smallest = second = Q_MAX
        argmin = 0
        for idx, mag in enumerate(magnitudes):
            if mag < smallest:
                smallest, second, argmin = mag, smallest, idx
            elif mag < second:
                second = mag

        outgoing = []
        for idx, neg in enumerate(negative):
            # Extrinsic minimum: exclude the edge's own contribution.
            mag = second if idx == argmin else smallest
            if scale_magnitudes:
                mag = int(mag * alpha)
            else:
                mag = max(0, mag - OFFSET)
            outgoing.append(-mag if overall_parity ^ neg else mag)
        return outgoing

    # Adjacency lists in both directions.
    cn_neighbors = [
        np.where(H_full[row] == 1)[0].tolist() for row in range(total_rows)
    ]
    vn_neighbors = [
        np.where(H_full[:, col] == 1)[0].tolist() for col in range(total_cols)
    ]

    # Channel LLRs (fixed, never modified).
    channel_llr = np.array([int(x) for x in llr_q], dtype=np.int32)

    # Latest CN->VN message on every edge, all starting at zero.
    msg_mem = {
        (row, col): 0
        for row, cols in enumerate(cn_neighbors)
        for col in cols
    }

    decoded = np.zeros(total_cols, dtype=np.int8)
    total_iterations = 0

    def total_belief(col):
        """Channel LLR plus every stored CN->VN message arriving at col."""
        return int(channel_llr[col]) + sum(
            msg_mem[(row, col)] for row in vn_neighbors[col]
        )

    def decide_position(position):
        """Hard-decide the bits of one VN position from unclipped beliefs."""
        first = position * cols_per_position
        last = min((position + 1) * cols_per_position, total_cols)
        for col in range(first, last):
            decoded[col] = 1 if total_belief(col) < 0 else 0

    for p in range(n_vn_positions):
        # CN positions in the window form one contiguous run of rows.
        first_cn = max(0, p - W + 1)
        last_cn = min(p, L - 1)
        window_rows = list(range(first_cn * rows_per_position,
                                 (last_cn + 1) * rows_per_position))

        if not window_rows:
            # No CN rows in the window: decide from what is already stored.
            decide_position(p)
            continue

        # Every VN column touched by a CN row of the window.
        window_cols = sorted(
            {col for row in window_rows for col in cn_neighbors[row]}
        )

        for _ in range(max_iter):
            # Step 1: saturated beliefs from the messages of the previous
            # iteration.
            beliefs = {col: saturate(total_belief(col)) for col in window_cols}

            # Step 2: new CN->VN messages, buffered so all rows see the
            # same old messages (flooding).
            pending = {}
            for row in window_rows:
                cols = cn_neighbors[row]
                if not cols:
                    continue
                extrinsic = [
                    saturate(beliefs[col] - msg_mem[(row, col)])
                    for col in cols
                ]
                for col, updated in zip(cols, check_node_update(extrinsic)):
                    pending[(row, col)] = updated

            # Step 3: commit the buffered messages.
            msg_mem.update(pending)
            total_iterations += 1

        decide_position(p)

    # Check if all decoded bits form a valid codeword.
    syndrome = (H_full @ decoded) % 2
    converged = np.all(syndrome == 0)

    return decoded, converged, total_iterations
|
|
|
|
|
|
def sc_density_evolution(B, L=50, w=2, lam_s=5.0, lam_b=0.1, z_pop=10000,
                         max_iter=200, cn_mode='normalized', alpha=0.75):
    """
    Position-aware Monte Carlo density evolution for SC-LDPC codes.

    A population of z_pop message samples is tracked on every protograph
    edge of the coupled chain. Each iteration uses a flooding schedule:
    total beliefs are formed from the channel LLRs plus all (randomly
    permuted) CN->VN messages, then every CN->VN message is recomputed from
    those beliefs. Channel samples assume the all-zeros codeword and are
    drawn from NumPy's global random state, as are the permutations, so
    callers control reproducibility via np.random.seed.

    Convergence is judged on interior positions only: the first and last
    w-1 VN positions are excluded (unless that would leave no positions,
    in which case all positions are used).

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        L: Chain length (number of CN positions).
        w: Coupling width (number of component matrices).
        lam_s: Signal photons per slot.
        lam_b: Background photons per slot.
        z_pop: Population size for Monte Carlo DE.
        max_iter: Maximum number of decoding iterations.
        cn_mode: CN update variant, forwarded to de_cn_update_vectorized.
        alpha: Scaling factor, forwarded to de_cn_update_vectorized.

    Returns:
        (converged, per_position_errors, iterations_used)
        converged: True if every interior position has error rate < 1e-3.
        per_position_errors: list of (L+w-1) error fractions from the last
            iteration performed (channel-only error rates if max_iter <= 0).
        iterations_used: number of iterations actually performed.
    """
    m_base, n_base = B.shape
    n_vn_positions = L + w - 1

    # Split protograph into w components (deterministic with seed=0).
    components = split_protograph(B, w=w, seed=0)

    # edge_map[(t, r)]: edges of check row r at CN position t, each stored
    # as (component index i, VN position t+i, base column c).
    edge_map = {}
    for t in range(L):
        for r in range(m_base):
            edge_map[(t, r)] = [
                (i, t + i, c)
                for i in range(w)
                for c in range(n_base)
                if components[i][r, c] >= 0
            ]

    # Reverse map: for each VN group (position, column), the (t, r, ei)
    # keys of the edges that reach it.
    vn_edges = {}
    for (t, r), edges in edge_map.items():
        for ei, (_, vn_pos, c) in enumerate(edges):
            vn_edges.setdefault((vn_pos, c), []).append((t, r, ei))

    # Channel LLR populations per (position, base column).
    log_ratio = np.log((lam_s + lam_b) / lam_b) if lam_b > 0 else 100.0
    scale = Q_MAX / 5.0

    channel_llr = np.zeros((n_vn_positions, n_base, z_pop), dtype=np.float64)
    for p in range(n_vn_positions):
        for c in range(n_base):
            y = np.random.poisson(lam_b, size=z_pop)
            llr_float = lam_s - y * log_ratio
            llr_q = np.round(llr_float * scale).astype(np.float64)
            channel_llr[p, c] = np.clip(llr_q, Q_MIN, Q_MAX)

    # CN->VN message populations, one per edge, all starting at zero.
    msg = {
        (t, r, ei): np.zeros(z_pop, dtype=np.float64)
        for (t, r), edges in edge_map.items()
        for ei in range(len(edges))
    }

    # Interior window used for the convergence test.
    conv_threshold = 1e-3
    interior_start = w - 1
    interior_end = n_vn_positions - (w - 1)
    if interior_end <= interior_start:
        # Chain too short for an interior; use all positions.
        interior_start = 0
        interior_end = n_vn_positions

    iterations_used = 0
    per_position_errors = None

    for _ in range(max_iter):
        # Step 1: total beliefs = channel + randomly permuted CN->VN msgs.
        # Beliefs are deliberately not clipped; only the messages are.
        beliefs = channel_llr.copy()
        permuted_msg = {}
        for (p, c), edge_list in vn_edges.items():
            for key in edge_list:
                perm = np.random.permutation(z_pop)
                pmsg = msg[key][perm]
                permuted_msg[key] = pmsg
                beliefs[p, c] += pmsg

        # Step 2: recompute all CN->VN messages from those beliefs.
        new_msg = {}
        for (t, r), edges in edge_map.items():
            dc = len(edges)
            if dc < 2:
                # A check with fewer than two edges gets zero messages.
                for ei in range(dc):
                    new_msg[(t, r, ei)] = np.zeros(z_pop, dtype=np.float64)
                continue

            vn_to_cn = []
            for ei, (_, vn_pos, c) in enumerate(edges):
                # Extrinsic = belief minus this edge's own contribution,
                # both read through the same fresh permutation, then
                # quantized.
                perm = np.random.permutation(z_pop)
                ext = beliefs[vn_pos, c][perm] - permuted_msg[(t, r, ei)][perm]
                vn_to_cn.append(np.clip(np.round(ext), Q_MIN, Q_MAX))

            cn_to_vn = de_cn_update_vectorized(
                vn_to_cn, offset=OFFSET,
                cn_mode=cn_mode, alpha=alpha
            )
            for ei in range(dc):
                new_msg[(t, r, ei)] = cn_to_vn[ei]

        msg = new_msg
        iterations_used += 1

        # Error rates come from the beliefs formed in step 1.
        per_position_errors = _compute_position_errors(
            beliefs, n_vn_positions, n_base, z_pop
        )
        interior_errors = per_position_errors[interior_start:interior_end]
        if all(e < conv_threshold for e in interior_errors):
            return True, per_position_errors, iterations_used

    if per_position_errors is None:
        # No iteration ran (max_iter <= 0): report the raw channel error
        # rates instead of failing on an undefined belief array.
        per_position_errors = _compute_position_errors(
            channel_llr, n_vn_positions, n_base, z_pop
        )

    return False, per_position_errors, iterations_used
|
|
|
|
|
|
def _compute_position_errors(beliefs, n_vn_positions, n_base, z_pop):
|
|
"""Compute per-position error fractions from belief arrays."""
|
|
per_position_errors = []
|
|
for p in range(n_vn_positions):
|
|
wrong = 0
|
|
total = 0
|
|
for c in range(n_base):
|
|
wrong += np.sum(beliefs[p, c] < 0)
|
|
total += z_pop
|
|
per_position_errors.append(wrong / total)
|
|
return per_position_errors
|
|
|
|
|
|
def compute_sc_threshold(B, L=50, w=2, lam_b=0.1, z_pop=10000, tol=0.25,
                         cn_mode='normalized', alpha=0.75):
    """
    Bisect for the smallest lam_s at which SC density evolution converges.

    The search runs on the bracket [0.1, 20.0]; each probe runs
    sc_density_evolution with max_iter=100. The upper end is probed first,
    then the lower end, then midpoints until the bracket is no wider than
    tol.

    Args:
        B: Base matrix (m_base x n_base).
        L: Chain length.
        w: Coupling width.
        lam_b: Background photon rate.
        z_pop: DE population size.
        tol: Search tolerance (stop when hi - lo <= tol).
        cn_mode: 'offset' or 'normalized'.
        alpha: Scaling factor for normalized mode.

    Returns:
        Threshold lam_s* (upper end of the final bracket). If DE fails even
        at 20.0, returns 20.0; if it already converges at 0.1, returns 0.1.
    """
    def decodes_at(lam_s):
        """Run one DE probe and report only whether it converged."""
        converged, _, _ = sc_density_evolution(
            B, L=L, w=w, lam_s=lam_s, lam_b=lam_b,
            z_pop=z_pop, max_iter=100, cn_mode=cn_mode, alpha=alpha
        )
        return converged

    lo, hi = 0.1, 20.0

    # The bracket is only valid if hi converges and lo does not.
    if not decodes_at(hi):
        return hi
    if decodes_at(lo):
        return lo

    while hi - lo > tol:
        mid = (lo + hi) / 2
        if decodes_at(mid):
            hi = mid
        else:
            lo = mid

    return hi
|
|
|
|
|
|
# =============================================================================
|
|
# CLI
|
|
# =============================================================================
|
|
|
|
def run_threshold_comparison(seed=42, z_pop=5000, tol=0.5, L=20):
    """
    Print and return DE thresholds for uncoupled and SC-LDPC codes.

    Seeds NumPy's global random state, then computes three uncoupled
    thresholds followed by the SC thresholds of the original base matrix and
    of a freshly constructed base matrix with the optimized degree profile.

    Args:
        seed: Seed passed to np.random.seed.
        z_pop: DE population size.
        tol: Threshold search tolerance.
        L: Chain length for the SC thresholds.

    Returns:
        Dict with keys 'uncoupled_thresholds', 'sc_thresholds',
        'shannon_limit' and 'params'.
    """
    from ldpc_sim import H_BASE
    from density_evolution import (
        compute_threshold, build_de_profile, make_profile
    )
    np.random.seed(seed)

    rule = "=" * 60
    print(rule)
    print("SC-LDPC vs Uncoupled Threshold Comparison")
    print(rule)

    # Keyword arguments shared by every threshold search below.
    search = {'lam_b': 0.1, 'z_pop': z_pop, 'tol': tol}
    normalized = {'cn_mode': 'normalized', 'alpha': 0.875}

    # Uncoupled thresholds
    degrees_opt = [7, 4, 4, 4, 4, 3, 3, 3]
    profile_opt = build_de_profile(degrees_opt, m_base=7)
    profile_orig = make_profile(H_BASE)

    print("\nUncoupled thresholds:")
    thresh_opt_offset = compute_threshold(
        profile_opt, cn_mode='offset', **search)
    thresh_opt_norm = compute_threshold(
        profile_opt, **search, **normalized)
    thresh_orig = compute_threshold(
        profile_orig, cn_mode='offset', **search)
    print(f" Original staircase (offset): {thresh_orig:.2f} photons/slot")
    print(f" DE-optimized (offset): {thresh_opt_offset:.2f} photons/slot")
    print(f" DE-optimized (normalized 0.875): {thresh_opt_norm:.2f} photons/slot")

    # SC-LDPC thresholds
    print(f"\nSC-LDPC thresholds (L={L}, w=2, normalized 0.875):")
    sc_thresh_orig = compute_sc_threshold(
        H_BASE, L=L, w=2, **search, **normalized)
    print(f" SC original staircase: {sc_thresh_orig:.2f} photons/slot")

    from density_evolution import construct_base_matrix
    # The second return value of construct_base_matrix is not needed here.
    H_opt, _ = construct_base_matrix(degrees_opt, z=32, n_trials=500)
    sc_thresh_opt = compute_sc_threshold(
        H_opt, L=L, w=2, **search, **normalized)
    print(f" SC DE-optimized: {sc_thresh_opt:.2f} photons/slot")

    shannon_limit = 0.47
    print(f"\n Shannon limit (rate 1/8): {shannon_limit} photons/slot")

    uncoupled = {
        'original_offset': float(thresh_orig),
        'optimized_offset': float(thresh_opt_offset),
        'optimized_normalized': float(thresh_opt_norm),
    }
    coupled = {
        'sc_original': float(sc_thresh_orig),
        'sc_optimized': float(sc_thresh_opt),
    }
    return {
        'uncoupled_thresholds': uncoupled,
        'sc_thresholds': coupled,
        'shannon_limit': shannon_limit,
        'params': {'L': L, 'w': 2, 'z_pop': z_pop, 'tol': tol},
    }
|
|
|
|
|
|
def run_fer_comparison(seed=42, n_frames=50, L=10, z=32):
    """
    Monte Carlo FER/BER of the windowed SC-LDPC decoder.

    Builds an SC chain from H_BASE (w=2), then for each signal level sends
    n_frames all-zero codewords through poisson_channel with background
    rate 0.1, quantizes the LLRs and decodes with windowed_decode (W=5,
    20 iterations, normalized min-sum with alpha=0.875). Only SC-LDPC
    frames are simulated by this function.

    Args:
        seed: Seed for np.random.seed and for the chain construction.
        n_frames: Frames simulated per signal level. With 0 frames the
            reported FER and BER are 0 instead of raising.
        L: Chain length.
        z: Lifting factor.

    Returns:
        Dict with 'lam_s_points', 'sc_fer' (keyed by str(lam_s), each value
        holding 'fer' and 'ber') and 'params'.
    """
    from ldpc_sim import H_BASE, poisson_channel, quantize_llr
    np.random.seed(seed)

    print("=" * 60)
    print(f"SC-LDPC vs Uncoupled FER Comparison (Z={z}, L={L})")
    print("=" * 60)

    m_base, n_base = H_BASE.shape

    # Build SC chain; only the parity-check matrix is needed here.
    H_sc, _, _ = build_sc_chain(
        H_BASE, L=L, w=2, z=z, seed=seed)
    n_total = H_sc.shape[1]

    lam_s_points = [2.0, 3.0, 4.0, 5.0, 7.0, 10.0]
    sc_results = {}

    print(f"\nSC-LDPC (L={L}, w=2, windowed W=5, normalized alpha=0.875):")
    print(f"{'lam_s':>8s} {'FER':>10s} {'BER':>10s}")
    print("-" * 30)

    for lam_s in lam_s_points:
        frame_errors = 0
        bit_errors = 0
        total_bits = 0

        for _ in range(n_frames):
            # All-zero codeword: every nonzero decoded bit is an error.
            codeword = np.zeros(n_total, dtype=np.int8)
            llr_float, _ = poisson_channel(codeword, lam_s, 0.1)
            llr_q = quantize_llr(llr_float)

            decoded, _, _ = windowed_decode(
                llr_q, H_sc, L=L, w=2, z=z, n_base=n_base, m_base=m_base,
                W=5, max_iter=20, cn_mode='normalized', alpha=0.875)

            errs = np.sum(decoded != 0)
            bit_errors += errs
            total_bits += n_total
            if errs > 0:
                frame_errors += 1

        # Guard both ratios so that n_frames == 0 yields zeros rather than
        # a ZeroDivisionError.
        fer = frame_errors / n_frames if n_frames > 0 else 0.0
        ber = bit_errors / total_bits if total_bits > 0 else 0
        sc_results[lam_s] = {'fer': float(fer), 'ber': float(ber)}
        print(f"{lam_s:8.1f} {fer:10.3f} {ber:10.6f}")

    return {
        'lam_s_points': lam_s_points,
        'sc_fer': {str(k): v for k, v in sc_results.items()},
        'params': {'L': L, 'w': 2, 'z': z, 'n_frames': n_frames},
    }
|
|
|
|
|
|
def run_full_pipeline(seed=42):
    """
    Run the threshold comparison and the FER comparison, then save both.

    The combined results are written as JSON to
    <directory of this file>/../data/sc_ldpc_results.json (the directory is
    created if missing).

    Args:
        seed: Seed forwarded to both stages.

    Returns:
        The combined result dict: all keys of the threshold comparison plus
        'fer_comparison'.
    """
    rule = "=" * 70
    print(rule)
    print("SC-LDPC FULL PIPELINE")
    print(rule)

    # Step 1: Threshold comparison
    print("\n--- Step 1: Threshold Comparison ---")
    threshold_results = run_threshold_comparison(
        seed=seed, z_pop=5000, tol=0.5, L=20)

    # Step 2: FER comparison
    print("\n--- Step 2: FER Comparison ---")
    fer_results = run_fer_comparison(
        seed=seed, n_frames=50, L=10, z=32)

    # Threshold keys first, FER results appended under their own key.
    output = dict(threshold_results)
    output['fer_comparison'] = fer_results

    here = os.path.dirname(os.path.abspath(__file__))
    out_dir = os.path.join(here, '..', 'data')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'sc_ldpc_results.json')
    with open(out_path, 'w') as f:
        # default=str stringifies any value json cannot encode natively.
        json.dump(output, f, indent=2, default=str)
    print(f"\nResults saved to {out_path}")

    return output
|
|
|
|
|
|
def main():
    """
    Command-line entry point.

    Subcommands: 'threshold' (options --seed, --z-pop, --tol, --L),
    'fer-compare' (options --seed, --n-frames, --L) and 'full'
    (option --seed). Without a subcommand the help text is printed.
    """
    parser = argparse.ArgumentParser(
        description='SC-LDPC Code Construction and Analysis',
    )
    subparsers = parser.add_subparsers(dest='command')

    threshold_cmd = subparsers.add_parser(
        'threshold', help='SC-DE threshold comparison')
    threshold_cmd.add_argument('--seed', type=int, default=42)
    threshold_cmd.add_argument('--z-pop', type=int, default=5000)
    threshold_cmd.add_argument('--tol', type=float, default=0.5)
    threshold_cmd.add_argument('--L', type=int, default=20)

    fer_cmd = subparsers.add_parser(
        'fer-compare', help='FER: SC vs uncoupled')
    fer_cmd.add_argument('--seed', type=int, default=42)
    fer_cmd.add_argument('--n-frames', type=int, default=50)
    fer_cmd.add_argument('--L', type=int, default=10)

    full_cmd = subparsers.add_parser('full', help='Full pipeline')
    full_cmd.add_argument('--seed', type=int, default=42)

    args = parser.parse_args()

    # Each handler reads only the options its own subparser defines.
    handlers = {
        'threshold': lambda: run_threshold_comparison(
            seed=args.seed, z_pop=args.z_pop, tol=args.tol, L=args.L),
        'fer-compare': lambda: run_fer_comparison(
            seed=args.seed, n_frames=args.n_frames, L=args.L),
        'full': lambda: run_full_pipeline(seed=args.seed),
    }

    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
        return
    handler()
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|