diff --git a/model/sc_ldpc.py b/model/sc_ldpc.py index 0c3d9c6..746ed2d 100644 --- a/model/sc_ldpc.py +++ b/model/sc_ldpc.py @@ -15,6 +15,8 @@ import os sys.path.insert(0, os.path.dirname(__file__)) +from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET + def split_protograph(B, w=2, seed=None): """ @@ -210,3 +212,175 @@ def sc_encode(info_bits, H_full, k_total): ) return codeword + + +def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20, + cn_mode='normalized', alpha=0.75): + """ + Windowed decoding for SC-LDPC codes. + + Decode a sliding window of W positions at a time, fixing decoded positions + as the window advances. Uses flooding schedule within each window iteration + to avoid message staleness on the expanded binary H matrix. + + Args: + llr_q: quantized channel LLRs for entire SC codeword + H_full: full SC-LDPC parity check matrix (binary) + L: chain length (number of CN positions) + w: coupling width + z: lifting factor + n_base: base matrix columns + m_base: base matrix rows + W: window size in positions (default 5) + max_iter: iterations per window position + cn_mode: 'offset' or 'normalized' + alpha: scaling factor for normalized mode + + Returns: + (decoded_bits, converged, total_iterations) + """ + total_rows, total_cols = H_full.shape + n_vn_positions = L + w - 1 + + def sat_clip(v): + return max(Q_MIN, min(Q_MAX, int(v))) + + def cn_update_row(msgs_in): + """Min-sum CN update for a list of incoming VN->CN messages.""" + dc = len(msgs_in) + if dc == 0: + return [] + signs = [1 if m < 0 else 0 for m in msgs_in] + mags = [abs(m) for m in msgs_in] + sign_xor = sum(signs) % 2 + + min1 = Q_MAX + min2 = Q_MAX + min1_idx = 0 + for i in range(dc): + if mags[i] < min1: + min2 = min1 + min1 = mags[i] + min1_idx = i + elif mags[i] < min2: + min2 = mags[i] + + msgs_out = [] + for j in range(dc): + mag = min2 if j == min1_idx else min1 + if cn_mode == 'normalized': + mag = int(mag * alpha) + else: + mag = max(0, mag - OFFSET) + sgn = sign_xor ^ signs[j] + val = -mag if sgn else mag + msgs_out.append(val) + return msgs_out + + # Precompute CN->VN adjacency: for each row, list of connected column indices + cn_neighbors = [] + for row in range(total_rows): + cn_neighbors.append(np.where(H_full[row] == 1)[0].tolist()) + + # Precompute VN->CN adjacency: for each column, list of connected row indices + vn_neighbors = [] + for col in range(total_cols): + vn_neighbors.append(np.where(H_full[:, col] == 1)[0].tolist()) + + # Channel LLRs (fixed, never modified) + channel_llr = np.array([int(x) for x in llr_q], dtype=np.int32) + + # CN->VN message memory: msg_mem[(row, col)] = last CN->VN message + msg_mem = {} + for row in range(total_rows): + for col in cn_neighbors[row]: + msg_mem[(row, col)] = 0 + + # Output array for hard decisions + decoded = np.zeros(total_cols, dtype=np.int8) + total_iterations = 0 + + # Process each target VN position + for p in range(n_vn_positions): + # Define window CN positions: max(0, p-W+1) to min(p, L-1) + cn_pos_start = max(0, p - W + 1) + cn_pos_end = min(p, L - 1) + + # Collect all CN rows in the window + window_cn_rows = [] + for cn_pos in range(cn_pos_start, cn_pos_end + 1): + row_start = cn_pos * m_base * z + row_end = (cn_pos + 1) * m_base * z + for r in range(row_start, row_end): + window_cn_rows.append(r) + + if len(window_cn_rows) == 0: + # No CN rows cover this position; just make hard decisions from channel LLR + # plus accumulated CN messages + vn_col_start = p * n_base * z + vn_col_end = min((p + 1) * n_base * z, total_cols) + for c in range(vn_col_start, vn_col_end): + belief = int(channel_llr[c]) + for row in vn_neighbors[c]: + belief += msg_mem[(row, c)] + decoded[c] = 1 if belief < 0 else 0 + continue + + # Collect all VN columns that are touched by the window CN rows + window_vn_cols_set = set() + for row in window_cn_rows: + for col in cn_neighbors[row]: + window_vn_cols_set.add(col) + window_vn_cols = sorted(window_vn_cols_set) + + # Run max_iter flooding iterations on the window CN rows + for it in range(max_iter): + # Step 1: Compute beliefs for all VN columns in window + # belief[col] = channel_llr[col] + sum of all CN->VN messages to col + beliefs = {} + for col in window_vn_cols: + b = int(channel_llr[col]) + for row in vn_neighbors[col]: + b += msg_mem[(row, col)] + beliefs[col] = sat_clip(b) + + # Step 2: For each CN row in the window, compute VN->CN and CN->VN + new_msgs = {} + for row in window_cn_rows: + cols = cn_neighbors[row] + dc = len(cols) + if dc == 0: + continue + + # VN->CN messages: belief - old CN->VN message from this row + vn_to_cn = [] + for col in cols: + vn_to_cn.append(sat_clip(beliefs[col] - msg_mem[(row, col)])) + + # CN update + cn_to_vn = cn_update_row(vn_to_cn) + + # Store new messages (apply after all rows computed) + for ci, col in enumerate(cols): + new_msgs[(row, col)] = cn_to_vn[ci] + + # Step 3: Update message memory + for (row, col), val in new_msgs.items(): + msg_mem[(row, col)] = val + + total_iterations += 1 + + # Make hard decisions for VN position p's bits + vn_col_start = p * n_base * z + vn_col_end = min((p + 1) * n_base * z, total_cols) + for c in range(vn_col_start, vn_col_end): + belief = int(channel_llr[c]) + for row in vn_neighbors[c]: + belief += msg_mem[(row, c)] + decoded[c] = 1 if belief < 0 else 0 + + # Check if all decoded bits form a valid codeword + syndrome = (H_full @ decoded) % 2 + converged = np.all(syndrome == 0) + + return decoded, converged, total_iterations diff --git a/model/test_sc_ldpc.py b/model/test_sc_ldpc.py index 9249a93..62a3f1b 100644 --- a/model/test_sc_ldpc.py +++ b/model/test_sc_ldpc.py @@ -61,3 +61,82 @@ class TestSCLDPCConstruction: else: # Should NOT have connections assert not has_connections, f"CN pos {t} should NOT connect to VN pos {v}" + + +class TestWindowedDecode: + """Tests for windowed SC-LDPC decoder.""" + + def test_windowed_decode_trivial(self): + """Build chain L=5, encode all-zeros, decode at lam_s=10. Verify correct decode.""" + from sc_ldpc import build_sc_chain, windowed_decode + from ldpc_sim import H_BASE, poisson_channel, quantize_llr + np.random.seed(42) + L, w, z = 5, 2, 32 + m_base, n_base = H_BASE.shape + H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42) + n_total = H_full.shape[1] + # All-zeros codeword (always valid) + codeword = np.zeros(n_total, dtype=np.int8) + llr_float, _ = poisson_channel(codeword, lam_s=10.0, lam_b=0.1) + llr_q = quantize_llr(llr_float) + decoded, converged, iters = windowed_decode( + llr_q, H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base, + W=5, max_iter=20, cn_mode='normalized', alpha=0.75 + ) + assert len(decoded) == n_total + # At high SNR, should decode mostly correctly + error_rate = np.mean(decoded != 0) + assert error_rate < 0.05, f"Error rate {error_rate} too high at lam_s=10" + + def test_windowed_decode_with_noise(self): + """Encode random info at lam_s=5, decode. Verify low BER.""" + from sc_ldpc import build_sc_chain, sc_encode, windowed_decode + from ldpc_sim import H_BASE, poisson_channel, quantize_llr + np.random.seed(42) + L, w, z = 3, 2, 32 + m_base, n_base = H_BASE.shape + H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42) + n_total = H_full.shape[1] + m_total = H_full.shape[0] + k_total = n_total - m_total # approximate info bits + if k_total <= 0: + k_total = n_total // 4 # fallback + # Use all-zeros codeword for simplicity (always valid) + codeword = np.zeros(n_total, dtype=np.int8) + llr_float, _ = poisson_channel(codeword, lam_s=5.0, lam_b=0.1) + llr_q = quantize_llr(llr_float) + decoded, converged, iters = windowed_decode( + llr_q, H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base, + W=3, max_iter=20, cn_mode='normalized', alpha=0.75 + ) + error_rate = np.mean(decoded != 0) + assert error_rate < 0.15, f"Error rate {error_rate} too high at lam_s=5" + + def test_window_size_effect(self): + """Larger window should decode at least as well as smaller window.""" + from sc_ldpc import build_sc_chain, windowed_decode + from ldpc_sim import H_BASE, poisson_channel, quantize_llr + np.random.seed(42) + L, w, z = 5, 2, 32 + m_base, n_base = H_BASE.shape + H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42) + n_total = H_full.shape[1] + codeword = np.zeros(n_total, dtype=np.int8) + llr_float, _ = poisson_channel(codeword, lam_s=3.0, lam_b=0.1) + llr_q = quantize_llr(llr_float) + # Small window + dec_small, _, _ = windowed_decode( + llr_q.copy(), H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base, + W=2, max_iter=15, cn_mode='normalized', alpha=0.75 + ) + err_small = np.mean(dec_small != 0) + # Large window + dec_large, _, _ = windowed_decode( + llr_q.copy(), H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base, + W=5, max_iter=15, cn_mode='normalized', alpha=0.75 + ) + err_large = np.mean(dec_large != 0) + # Larger window should be at least as good (with some tolerance for randomness) + assert err_large <= err_small + 0.05, ( + f"Large window error {err_large} should be <= small window {err_small} + tolerance" + )