Files
ldpc_optical/model/sc_ldpc.py
cah 5f69de6cb8 feat: add windowed SC-LDPC decoder
Implement windowed_decode() for SC-LDPC codes using flooding
min-sum with sliding window of W positions. Supports both
normalized and offset min-sum modes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 17:08:06 -07:00

387 lines
13 KiB
Python

#!/usr/bin/env python3
"""
Spatially-Coupled LDPC (SC-LDPC) Code Construction and Decoding
SC-LDPC codes achieve threshold saturation: the BP threshold approaches
the MAP threshold, closing a significant portion of the gap to Shannon limit.
Construction: replicate a protograph base matrix along a chain of L positions
with coupling width w, creating a convolutional-like structure.
"""
import numpy as np
import sys
import os
sys.path.insert(0, os.path.dirname(__file__))
from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET
def split_protograph(B, w=2, seed=None):
    """
    Randomly partition the edges of a protograph base matrix among w components.

    Every connected entry of B (value >= 0) is handed to exactly one of the
    w component matrices, which records it as 0 — a placeholder circulant
    shift to be assigned later during chain construction. A component that
    does not receive the edge keeps -1 ("no connection") at that position.

    Args:
        B: Base matrix (m_base x n_base); B[r, c] >= 0 marks a connection.
        w: Coupling width, i.e. the number of component matrices produced.
        seed: Optional seed for reproducible edge assignment.

    Returns:
        List of w (m_base x n_base) int16 matrices with entries in {0, -1}.
    """
    rng = np.random.default_rng(seed)
    rows, cols = B.shape
    # Start every component fully disconnected.
    parts = [np.full((rows, cols), -1, dtype=np.int16) for _ in range(w)]
    # Visit entries in row-major order so the RNG draw sequence is fixed.
    for r in range(rows):
        for c in range(cols):
            if B[r, c] < 0:
                continue  # not an edge in the protograph
            target = rng.integers(0, w)
            parts[target][r, c] = 0
    return parts
def build_sc_chain(B, L=20, w=2, z=32, seed=None):
    """
    Build the full SC-LDPC parity-check matrix as a dense binary matrix.

    The chain has L CN positions and L+w-1 VN positions. Position t's CNs
    connect to VN positions t, t+1, ..., t+w-1 using the w component matrices
    obtained from split_protograph, with a random circulant shift per edge.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        L: Chain length (number of CN positions).
        w: Coupling width.
        z: Lifting factor (circulant size).
        seed: Random seed for reproducibility.

    Returns:
        H_full: Binary parity-check matrix of shape
                (L * m_base * z) x ((L + w - 1) * n_base * z).
        components: List of w component matrices from split_protograph.
        meta: Dictionary with construction metadata.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape
    # Split the protograph into w components.
    # Use a sub-seed derived from the main seed for splitting.
    split_seed = int(rng.integers(0, 2**31)) if seed is not None else None
    components = split_protograph(B, w=w, seed=split_seed)
    n_cn_positions = L
    n_vn_positions = L + w - 1
    total_rows = n_cn_positions * m_base * z
    total_cols = n_vn_positions * n_base * z
    H_full = np.zeros((total_rows, total_cols), dtype=np.int8)
    # Reused index vector for placing each z x z circulant in one shot.
    rot = np.arange(z)
    # For each CN position t, for each component i, connect CN group t
    # to VN group t+i using component B_i with random circulant shifts.
    # Loop nesting (t, i, r, c) must stay in this order so the RNG draw
    # sequence — and hence the constructed matrix for a given seed — is
    # identical to previous builds.
    for t in range(L):
        for i in range(w):
            vn_pos = t + i  # VN position this component connects to
            comp = components[i]
            for r in range(m_base):
                for c in range(n_base):
                    if comp[r, c] >= 0:
                        # This entry is connected; assign a random circulant shift.
                        shift = int(rng.integers(0, z))
                        # Place the z x z circulant permutation matrix:
                        # CN rows  [t * m_base * z + r * z, ... + z)
                        # VN cols  [vn_pos * n_base * z + c * z, ... + z)
                        # Vectorized: row zz maps to column (zz + shift) % z.
                        row_base = t * m_base * z + r * z
                        col_base = vn_pos * n_base * z + c * z
                        H_full[row_base + rot, col_base + (rot + shift) % z] = 1
    meta = {
        'L': L,
        'w': w,
        'z': z,
        'm_base': m_base,
        'n_base': n_base,
        'total_rows': total_rows,
        'total_cols': total_cols,
        'n_cn_positions': n_cn_positions,
        'n_vn_positions': n_vn_positions,
        'rate_design': 1.0 - (total_rows / total_cols),
    }
    return H_full, components, meta
def sc_encode(info_bits, H_full, k_total):
    """
    Encode via GF(2) Gaussian elimination on the SC-LDPC parity-check matrix.

    The codeword is laid out as [info_bits | parity]; the parity bits are
    found by solving H_parity * p = H_info * info_bits (mod 2). Rank-deficient
    H matrices (common with SC-LDPC boundary effects) are tolerated by
    leaving free parity variables at zero.

    Args:
        info_bits: Information bits array of length k_total.
        H_full: Binary parity-check matrix (m x n).
        k_total: Number of information bits.

    Returns:
        codeword: Binary int8 codeword of length n (= H_full.shape[1]) with
                  H_full @ codeword == 0 (mod 2).

    Raises:
        ValueError: If encoding fails (syndrome is nonzero, i.e. the linear
                    system had no solution with free variables at zero).
    """
    m_full, n_full = H_full.shape
    n_parity = n_full - k_total
    assert len(info_bits) == k_total, (
        f"info_bits length {len(info_bits)} != k_total {k_total}"
    )
    # Split H into information and parity parts: H_full = [H_info | H_parity].
    H_info = H_full[:, :k_total]
    H_parity = H_full[:, k_total:]
    # RHS of the parity system: H_info @ info_bits mod 2.
    rhs = (H_info @ info_bits) % 2
    # Augmented matrix [H_parity | rhs], reduced to RREF over GF(2).
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_parity
    aug[:, n_parity] = rhs
    pivot_row = 0
    pivot_cols = []
    for col in range(n_parity):
        # Find the first row at or below pivot_row with a 1 in this column.
        candidates = np.nonzero(aug[pivot_row:, col])[0]
        if candidates.size == 0:
            # No pivot for this column (rank deficient) — variable stays free.
            continue
        found = pivot_row + int(candidates[0])
        # Swap pivot row into position.
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]
        # Clear this column in every other row with one vectorized XOR
        # (addition mod 2); elimination order is irrelevant over GF(2).
        others = np.nonzero(aug[:, col])[0]
        others = others[others != pivot_row]
        aug[others] ^= aug[pivot_row]
        pivot_cols.append(col)
        pivot_row += 1
    # Matrix is in RREF restricted to pivot columns: with free variables
    # fixed at zero, each pivot variable equals its row's RHS directly.
    parity = np.zeros(n_parity, dtype=np.int8)
    for i, col in enumerate(pivot_cols):
        parity[col] = aug[i, n_parity]
    # Assemble codeword: [info_bits | parity].
    codeword = np.concatenate([info_bits.astype(np.int8), parity])
    # Verify syndrome; an inconsistent system surfaces here.
    syndrome = (H_full @ codeword) % 2
    if not np.all(syndrome == 0):
        raise ValueError(
            f"SC-LDPC encoding failed: syndrome weight = {syndrome.sum()}, "
            f"H rank ~{len(pivot_cols)}, H rows = {m_full}"
        )
    return codeword
def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20,
                    cn_mode='normalized', alpha=0.75):
    """
    Windowed decoding for SC-LDPC codes.

    Decode a sliding window of W positions at a time, fixing decoded positions
    as the window advances. Uses flooding schedule within each window iteration
    to avoid message staleness on the expanded binary H matrix. CN->VN message
    memory persists across windows, so later windows benefit from messages
    computed in earlier ones.

    Args:
        llr_q: quantized channel LLRs for entire SC codeword
        H_full: full SC-LDPC parity check matrix (binary)
        L: chain length (number of CN positions)
        w: coupling width
        z: lifting factor
        n_base: base matrix columns
        m_base: base matrix rows
        W: window size in positions (default 5)
        max_iter: iterations per window position
        cn_mode: 'offset' or 'normalized'
        alpha: scaling factor for normalized mode

    Returns:
        (decoded_bits, converged, total_iterations)
    """
    total_rows, total_cols = H_full.shape
    n_vn_positions = L + w - 1

    def sat_clip(v):
        # Saturate to the quantizer range [Q_MIN, Q_MAX].
        return max(Q_MIN, min(Q_MAX, int(v)))

    def cn_update_row(msgs_in):
        """Min-sum CN update for a list of incoming VN->CN messages."""
        dc = len(msgs_in)
        if dc == 0:
            return []
        signs = [1 if m < 0 else 0 for m in msgs_in]
        mags = [abs(m) for m in msgs_in]
        sign_xor = sum(signs) % 2  # overall parity of negative inputs
        # Track the two smallest magnitudes: each outgoing message uses the
        # minimum over all OTHER edges, which is min2 for the edge that holds
        # min1 and min1 for every other edge.
        min1 = Q_MAX
        min2 = Q_MAX
        min1_idx = 0
        for i in range(dc):
            if mags[i] < min1:
                min2 = min1
                min1 = mags[i]
                min1_idx = i
            elif mags[i] < min2:
                min2 = mags[i]
        msgs_out = []
        for j in range(dc):
            mag = min2 if j == min1_idx else min1
            if cn_mode == 'normalized':
                mag = int(mag * alpha)  # normalized min-sum scaling
            else:
                mag = max(0, mag - OFFSET)  # offset min-sum correction
            # Outgoing sign excludes the receiving edge's own contribution.
            sgn = sign_xor ^ signs[j]
            val = -mag if sgn else mag
            msgs_out.append(val)
        return msgs_out

    # Precompute CN->VN adjacency: for each row, list of connected column indices
    cn_neighbors = []
    for row in range(total_rows):
        cn_neighbors.append(np.where(H_full[row] == 1)[0].tolist())
    # Precompute VN->CN adjacency: for each column, list of connected row indices
    vn_neighbors = []
    for col in range(total_cols):
        vn_neighbors.append(np.where(H_full[:, col] == 1)[0].tolist())
    # Channel LLRs (fixed, never modified)
    channel_llr = np.array([int(x) for x in llr_q], dtype=np.int32)
    # CN->VN message memory: msg_mem[(row, col)] = last CN->VN message.
    # Initialized to 0 (no information) and persisted across windows.
    msg_mem = {}
    for row in range(total_rows):
        for col in cn_neighbors[row]:
            msg_mem[(row, col)] = 0
    # Output array for hard decisions
    decoded = np.zeros(total_cols, dtype=np.int8)
    total_iterations = 0
    # Process each target VN position; position p is hard-decided after its
    # window has been iterated (trailing window of CN positions behind p).
    for p in range(n_vn_positions):
        # Define window CN positions: max(0, p-W+1) to min(p, L-1)
        cn_pos_start = max(0, p - W + 1)
        cn_pos_end = min(p, L - 1)
        # Collect all CN rows in the window
        window_cn_rows = []
        for cn_pos in range(cn_pos_start, cn_pos_end + 1):
            row_start = cn_pos * m_base * z
            row_end = (cn_pos + 1) * m_base * z
            for r in range(row_start, row_end):
                window_cn_rows.append(r)
        if len(window_cn_rows) == 0:
            # No CN rows cover this position; just make hard decisions from channel LLR
            # plus accumulated CN messages
            vn_col_start = p * n_base * z
            vn_col_end = min((p + 1) * n_base * z, total_cols)
            for c in range(vn_col_start, vn_col_end):
                belief = int(channel_llr[c])
                for row in vn_neighbors[c]:
                    belief += msg_mem[(row, c)]
                decoded[c] = 1 if belief < 0 else 0
            continue
        # Collect all VN columns that are touched by the window CN rows
        window_vn_cols_set = set()
        for row in window_cn_rows:
            for col in cn_neighbors[row]:
                window_vn_cols_set.add(col)
        window_vn_cols = sorted(window_vn_cols_set)
        # Run max_iter flooding iterations on the window CN rows
        for it in range(max_iter):
            # Step 1: Compute beliefs for all VN columns in window.
            # belief[col] = channel_llr[col] + sum of all CN->VN messages to col
            beliefs = {}
            for col in window_vn_cols:
                b = int(channel_llr[col])
                for row in vn_neighbors[col]:
                    b += msg_mem[(row, col)]
                beliefs[col] = sat_clip(b)
            # Step 2: For each CN row in the window, compute VN->CN and CN->VN
            new_msgs = {}
            for row in window_cn_rows:
                cols = cn_neighbors[row]
                dc = len(cols)
                if dc == 0:
                    continue
                # VN->CN messages: belief minus this row's old CN->VN message
                # (extrinsic information: exclude the receiving CN's own input)
                vn_to_cn = []
                for col in cols:
                    vn_to_cn.append(sat_clip(beliefs[col] - msg_mem[(row, col)]))
                # CN update
                cn_to_vn = cn_update_row(vn_to_cn)
                # Store new messages (apply after all rows computed, so every
                # row in this iteration sees the same pre-update memory)
                for ci, col in enumerate(cols):
                    new_msgs[(row, col)] = cn_to_vn[ci]
            # Step 3: Update message memory atomically (true flooding schedule)
            for (row, col), val in new_msgs.items():
                msg_mem[(row, col)] = val
            total_iterations += 1
        # Make hard decisions for VN position p's bits (negative belief -> 1)
        vn_col_start = p * n_base * z
        vn_col_end = min((p + 1) * n_base * z, total_cols)
        for c in range(vn_col_start, vn_col_end):
            belief = int(channel_llr[c])
            for row in vn_neighbors[c]:
                belief += msg_mem[(row, c)]
            decoded[c] = 1 if belief < 0 else 0
    # Check if all decoded bits form a valid codeword
    syndrome = (H_full @ decoded) % 2
    converged = np.all(syndrome == 0)
    return decoded, converged, total_iterations