feat: add windowed SC-LDPC decoder
Implement windowed_decode() for SC-LDPC codes using flooding min-sum with sliding window of W positions. Supports both normalized and offset min-sum modes. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
174
model/sc_ldpc.py
174
model/sc_ldpc.py
@@ -15,6 +15,8 @@ import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET
|
||||
|
||||
|
||||
def split_protograph(B, w=2, seed=None):
|
||||
"""
|
||||
@@ -210,3 +212,175 @@ def sc_encode(info_bits, H_full, k_total):
|
||||
)
|
||||
|
||||
return codeword
|
||||
|
||||
|
||||
def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20,
                    cn_mode='normalized', alpha=0.75):
    """
    Windowed decoding for SC-LDPC codes.

    Decode a sliding window of W positions at a time, fixing decoded positions
    as the window advances. Uses flooding schedule within each window iteration
    to avoid message staleness on the expanded binary H matrix.

    Args:
        llr_q: quantized channel LLRs for entire SC codeword
        H_full: full SC-LDPC parity check matrix (binary)
        L: chain length (number of CN positions)
        w: coupling width
        z: lifting factor
        n_base: base matrix columns
        m_base: base matrix rows
        W: window size in positions (default 5)
        max_iter: iterations per window position
        cn_mode: 'offset' or 'normalized'
        alpha: scaling factor for normalized mode

    Returns:
        (decoded_bits, converged, total_iterations) where decoded_bits is an
        int8 array of hard decisions, converged is a Python bool (True iff the
        final hard decisions satisfy all parity checks), and total_iterations
        is the total number of flooding iterations run across all windows.
    """
    total_rows, total_cols = H_full.shape
    # VN positions outnumber CN positions by w-1 due to the coupling tail.
    n_vn_positions = L + w - 1

    def sat_clip(v):
        # Saturate to the quantizer range [Q_MIN, Q_MAX].
        return max(Q_MIN, min(Q_MAX, int(v)))

    def cn_update_row(msgs_in):
        """Min-sum CN update for a list of incoming VN->CN messages."""
        dc = len(msgs_in)
        if dc == 0:
            return []
        signs = [1 if m < 0 else 0 for m in msgs_in]
        mags = [abs(m) for m in msgs_in]
        sign_xor = sum(signs) % 2

        # Track the two smallest magnitudes; the output to edge j uses min2
        # when j carried min1 (extrinsic rule), otherwise min1.
        min1 = Q_MAX
        min2 = Q_MAX
        min1_idx = 0
        for i in range(dc):
            if mags[i] < min1:
                min2 = min1
                min1 = mags[i]
                min1_idx = i
            elif mags[i] < min2:
                min2 = mags[i]

        msgs_out = []
        for j in range(dc):
            mag = min2 if j == min1_idx else min1
            if cn_mode == 'normalized':
                # Normalized min-sum: scale magnitude (truncation toward zero).
                mag = int(mag * alpha)
            else:
                # Offset min-sum: subtract OFFSET, floored at zero.
                mag = max(0, mag - OFFSET)
            # Extrinsic sign excludes edge j's own sign via XOR.
            sgn = sign_xor ^ signs[j]
            val = -mag if sgn else mag
            msgs_out.append(val)
        return msgs_out

    # Precompute CN->VN adjacency: for each row, list of connected column indices
    cn_neighbors = []
    for row in range(total_rows):
        cn_neighbors.append(np.where(H_full[row] == 1)[0].tolist())

    # Precompute VN->CN adjacency: for each column, list of connected row indices
    vn_neighbors = []
    for col in range(total_cols):
        vn_neighbors.append(np.where(H_full[:, col] == 1)[0].tolist())

    # Channel LLRs (fixed, never modified)
    channel_llr = np.array([int(x) for x in llr_q], dtype=np.int32)

    # CN->VN message memory: msg_mem[(row, col)] = last CN->VN message.
    # Persists across windows so later windows see earlier extrinsic info.
    msg_mem = {}
    for row in range(total_rows):
        for col in cn_neighbors[row]:
            msg_mem[(row, col)] = 0

    # Output array for hard decisions
    decoded = np.zeros(total_cols, dtype=np.int8)
    total_iterations = 0

    def hard_decide(col_start, col_end):
        # APP belief = channel LLR + sum of stored CN->VN messages;
        # negative belief decodes to bit 1.
        for c in range(col_start, col_end):
            belief = int(channel_llr[c])
            for row in vn_neighbors[c]:
                belief += msg_mem[(row, c)]
            decoded[c] = 1 if belief < 0 else 0

    # Process each target VN position
    for p in range(n_vn_positions):
        # Define window CN positions: max(0, p-W+1) to min(p, L-1)
        cn_pos_start = max(0, p - W + 1)
        cn_pos_end = min(p, L - 1)

        # Collect all CN rows in the window
        window_cn_rows = []
        for cn_pos in range(cn_pos_start, cn_pos_end + 1):
            row_start = cn_pos * m_base * z
            row_end = (cn_pos + 1) * m_base * z
            for r in range(row_start, row_end):
                window_cn_rows.append(r)

        if len(window_cn_rows) == 0:
            # No CN rows cover this position; decide from channel LLR plus
            # accumulated CN messages. (Defensive: unreachable when W >= w.)
            hard_decide(p * n_base * z, min((p + 1) * n_base * z, total_cols))
            continue

        # Collect all VN columns that are touched by the window CN rows
        window_vn_cols_set = set()
        for row in window_cn_rows:
            for col in cn_neighbors[row]:
                window_vn_cols_set.add(col)
        window_vn_cols = sorted(window_vn_cols_set)

        # Run max_iter flooding iterations on the window CN rows
        for it in range(max_iter):
            # Step 1: Compute beliefs for all VN columns in window
            # belief[col] = channel_llr[col] + sum of all CN->VN messages to col
            beliefs = {}
            for col in window_vn_cols:
                b = int(channel_llr[col])
                for row in vn_neighbors[col]:
                    b += msg_mem[(row, col)]
                beliefs[col] = sat_clip(b)

            # Step 2: For each CN row in the window, compute VN->CN and CN->VN
            new_msgs = {}
            for row in window_cn_rows:
                cols = cn_neighbors[row]
                dc = len(cols)
                if dc == 0:
                    continue

                # VN->CN messages: belief - old CN->VN message from this row
                vn_to_cn = []
                for col in cols:
                    vn_to_cn.append(sat_clip(beliefs[col] - msg_mem[(row, col)]))

                # CN update
                cn_to_vn = cn_update_row(vn_to_cn)

                # Store new messages (apply after all rows computed, so the
                # flooding schedule uses only last-iteration messages)
                for ci, col in enumerate(cols):
                    new_msgs[(row, col)] = cn_to_vn[ci]

            # Step 3: Update message memory
            for (row, col), val in new_msgs.items():
                msg_mem[(row, col)] = val

            total_iterations += 1

        # Make hard decisions for VN position p's bits
        hard_decide(p * n_base * z, min((p + 1) * n_base * z, total_cols))

    # Check if all decoded bits form a valid codeword. Cast to a wide integer
    # dtype first: an int8 matmul could silently overflow for high-degree rows.
    syndrome = (H_full.astype(np.int64) @ decoded.astype(np.int64)) % 2
    # bool(...) so callers get a plain Python bool, not np.bool_.
    converged = bool(np.all(syndrome == 0))

    return decoded, converged, total_iterations
|
||||
|
||||
@@ -61,3 +61,82 @@ class TestSCLDPCConstruction:
|
||||
else:
|
||||
# Should NOT have connections
|
||||
assert not has_connections, f"CN pos {t} should NOT connect to VN pos {v}"
|
||||
|
||||
|
||||
class TestWindowedDecode:
    """Tests for windowed SC-LDPC decoder."""

    def test_windowed_decode_trivial(self):
        """Build chain L=5, transmit all-zeros, decode at lam_s=10. Verify correct decode."""
        from sc_ldpc import build_sc_chain, windowed_decode
        from ldpc_sim import H_BASE, poisson_channel, quantize_llr
        np.random.seed(42)
        L, w, z = 5, 2, 32
        m_base, n_base = H_BASE.shape
        H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)
        n_total = H_full.shape[1]
        # All-zeros codeword (always valid)
        codeword = np.zeros(n_total, dtype=np.int8)
        llr_float, _ = poisson_channel(codeword, lam_s=10.0, lam_b=0.1)
        llr_q = quantize_llr(llr_float)
        decoded, converged, iters = windowed_decode(
            llr_q, H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
            W=5, max_iter=20, cn_mode='normalized', alpha=0.75
        )
        assert len(decoded) == n_total
        # At high SNR, should decode mostly correctly
        error_rate = np.mean(decoded != 0)
        assert error_rate < 0.05, f"Error rate {error_rate} too high at lam_s=10"

    def test_windowed_decode_with_noise(self):
        """Transmit all-zeros at moderate SNR (lam_s=5), decode. Verify low BER."""
        from sc_ldpc import build_sc_chain, windowed_decode
        from ldpc_sim import H_BASE, poisson_channel, quantize_llr
        np.random.seed(42)
        L, w, z = 3, 2, 32
        m_base, n_base = H_BASE.shape
        H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)
        n_total = H_full.shape[1]
        # Use all-zeros codeword for simplicity (always valid)
        codeword = np.zeros(n_total, dtype=np.int8)
        llr_float, _ = poisson_channel(codeword, lam_s=5.0, lam_b=0.1)
        llr_q = quantize_llr(llr_float)
        decoded, converged, iters = windowed_decode(
            llr_q, H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
            W=3, max_iter=20, cn_mode='normalized', alpha=0.75
        )
        error_rate = np.mean(decoded != 0)
        assert error_rate < 0.15, f"Error rate {error_rate} too high at lam_s=5"

    def test_window_size_effect(self):
        """Larger window should decode at least as well as smaller window."""
        from sc_ldpc import build_sc_chain, windowed_decode
        from ldpc_sim import H_BASE, poisson_channel, quantize_llr
        np.random.seed(42)
        L, w, z = 5, 2, 32
        m_base, n_base = H_BASE.shape
        H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)
        n_total = H_full.shape[1]
        codeword = np.zeros(n_total, dtype=np.int8)
        llr_float, _ = poisson_channel(codeword, lam_s=3.0, lam_b=0.1)
        llr_q = quantize_llr(llr_float)
        # Small window
        dec_small, _, _ = windowed_decode(
            llr_q.copy(), H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
            W=2, max_iter=15, cn_mode='normalized', alpha=0.75
        )
        err_small = np.mean(dec_small != 0)
        # Large window
        dec_large, _, _ = windowed_decode(
            llr_q.copy(), H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
            W=5, max_iter=15, cn_mode='normalized', alpha=0.75
        )
        err_large = np.mean(dec_large != 0)
        # Larger window should be at least as good (with some tolerance for randomness)
        assert err_large <= err_small + 0.05, (
            f"Large window error {err_large} should be <= small window {err_small} + tolerance"
        )
|
||||
|
||||
Reference in New Issue
Block a user