#!/usr/bin/env python3
"""
Spatially-Coupled LDPC (SC-LDPC) Code Construction and Decoding

SC-LDPC codes achieve threshold saturation: the BP threshold approaches the
MAP threshold, closing a significant portion of the gap to the Shannon limit.

Construction: replicate a protograph base matrix along a chain of L positions
with coupling width w, creating a convolutional-like structure.
"""

import numpy as np
import sys
import os

sys.path.insert(0, os.path.dirname(__file__))
from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET


def split_protograph(B, w=2, seed=None):
    """
    Split a protograph base matrix into w component matrices.

    Each edge in B (entry >= 0) is randomly assigned to exactly one of the w
    component matrices. The component that receives the edge gets value 0
    (circulant shift assigned later during chain construction), while all
    other components get -1 (no connection) at that position.

    Edges are visited in row-major order and each draws exactly one integer
    from the generator, so a given seed always yields the same split.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        w: Coupling width (number of component matrices).
        seed: Random seed for reproducibility.

    Returns:
        List of w component matrices, each with shape (m_base, n_base).
        Component values: 0 where edge is assigned, -1 otherwise.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape

    # Every component starts empty (-1 = no connection).
    components = [np.full((m_base, n_base), -1, dtype=np.int16)
                  for _ in range(w)]

    # np.argwhere lists indices in row-major order (same as an r-then-c
    # double loop), which keeps the RNG draw sequence stable per seed.
    for r, c in np.argwhere(B >= 0):
        chosen = rng.integers(0, w)
        components[chosen][r, c] = 0

    return components


def build_sc_chain(B, L=20, w=2, z=32, seed=None):
    """
    Build the full SC-LDPC parity-check matrix as a dense binary matrix.

    The chain has L CN positions and L+w-1 VN positions. Position t's CNs
    connect to VN positions t, t+1, ..., t+w-1 using the w component
    matrices (component i links CN position t to VN position t+i). Hence VN
    position v is checked only by CN positions max(0, v-w+1) .. min(v, L-1);
    the two boundary VN positions (0 and L+w-2) are reached by a single
    component each (component 0 and component w-1 respectively).

    Because split_protograph assigns edges at random, a base-matrix column
    may receive no edge in component 0 (or in component w-1); the matching
    columns of VN position 0 (or L+w-2) then have no check connections at
    all and are effectively uncoded.
    NOTE(review): conventional terminated SC-LDPC ensembles put the
    reduced-degree nodes on the check side (more CN than VN positions,
    rate loss); confirm that this orientation (rate gain) is intended.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        L: Chain length (number of CN positions).
        w: Coupling width.
        z: Lifting factor (circulant size).
        seed: Random seed for reproducibility.

    Returns:
        H_full: Binary parity-check matrix of shape
            (L * m_base * z) x ((L + w - 1) * n_base * z).
        components: List of w component matrices from split_protograph.
        meta: Dictionary with construction metadata.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape

    # Derive the split seed from the main generator so a single seed
    # reproduces both the split and all circulant shifts.
    split_seed = int(rng.integers(0, 2**31)) if seed is not None else None
    components = split_protograph(B, w=w, seed=split_seed)

    n_cn_positions = L
    n_vn_positions = L + w - 1
    total_rows = n_cn_positions * m_base * z
    total_cols = n_vn_positions * n_base * z

    H_full = np.zeros((total_rows, total_cols), dtype=np.int8)

    # Row/column offsets inside one z x z circulant block.
    offsets = np.arange(z)

    # For each CN position t and component i, connect CN group t to VN
    # group t+i. The loop order (t, i, r, c) fixes the order of the shift
    # draws, so a seed always produces the same matrix.
    for t in range(L):
        row_base = t * m_base * z
        for i, comp in enumerate(components):
            col_base = (t + i) * n_base * z
            for r, c in np.argwhere(comp >= 0):
                shift = int(rng.integers(0, z))
                # Circulant permutation: local row k -> local column (k + shift) % z.
                rows = row_base + r * z + offsets
                cols = col_base + c * z + (offsets + shift) % z
                H_full[rows, cols] = 1

    meta = {
        'L': L,
        'w': w,
        'z': z,
        'm_base': m_base,
        'n_base': n_base,
        'total_rows': total_rows,
        'total_cols': total_cols,
        'n_cn_positions': n_cn_positions,
        'n_vn_positions': n_vn_positions,
        'rate_design': 1.0 - (total_rows / total_cols),
    }

    return H_full, components, meta


def sc_encode(info_bits, H_full, k_total):
    """
    Encode using GF(2) Gaussian elimination on the SC-LDPC parity-check matrix.

    Places info_bits in the first k_total positions of the codeword and
    solves for the remaining parity bits such that H_full * codeword = 0
    (mod 2). Handles rank-deficient H matrices (common with SC-LDPC boundary
    effects) by leaving free variables as zero.

    Args:
        info_bits: Information bits (array-like of 0/1) of length k_total.
        H_full: Binary parity-check matrix (m x n).
        k_total: Number of information bits.

    Returns:
        codeword: Binary int8 codeword array of length n (= H_full.shape[1]).

    Raises:
        AssertionError: If len(info_bits) != k_total.
        ValueError: If encoding fails (syndrome is nonzero), i.e. the parity
            part of H cannot satisfy the checks for these information bits.
    """
    m_full, n_full = H_full.shape
    n_parity = n_full - k_total
    info_bits = np.asarray(info_bits)

    assert len(info_bits) == k_total, (
        f"info_bits length {len(info_bits)} != k_total {k_total}"
    )

    # H_full = [H_info | H_parity]; solve H_parity * p = H_info * info (mod 2).
    H_info = H_full[:, :k_total]
    H_parity = H_full[:, k_total:]
    rhs = (H_info @ info_bits) % 2

    # Augmented matrix [H_parity | rhs].
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_parity
    aug[:, n_parity] = rhs

    # Gauss-Jordan elimination over GF(2): each pivot column is cleared in
    # every other row, so the final matrix is in reduced row-echelon form.
    pivot_row = 0
    pivot_cols = []
    for col in range(n_parity):
        if pivot_row == m_full:
            break  # every row already holds a pivot
        hits = np.flatnonzero(aug[pivot_row:, col])
        if hits.size == 0:
            continue  # rank deficient: free variable, stays 0
        found = pivot_row + hits[0]
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]
        rows = np.flatnonzero(aug[:, col])
        rows = rows[rows != pivot_row]
        aug[rows] ^= aug[pivot_row]  # addition mod 2 == XOR
        pivot_cols.append(col)
        pivot_row += 1

    # In reduced row-echelon form, pivot row i directly gives the parity bit
    # of its pivot column; non-pivot (free) parity bits remain zero.
    parity = np.zeros(n_parity, dtype=np.int8)
    parity[pivot_cols] = aug[:len(pivot_cols), n_parity]

    codeword = np.concatenate([info_bits.astype(np.int8), parity])

    # An inconsistent system (rhs outside the column space of H_parity)
    # shows up here as a nonzero syndrome.
    syndrome = (H_full @ codeword) % 2
    if np.any(syndrome != 0):
        raise ValueError(
            f"SC-LDPC encoding failed: syndrome weight = {syndrome.sum()}, "
            f"H rank ~{len(pivot_cols)}, H rows = {m_full}"
        )

    return codeword


def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20,
                    cn_mode='normalized', alpha=0.75):
    """
    Windowed decoding for SC-LDPC codes.

    VN positions are decoded left to right. CN position t touches VN
    positions t .. t+w-1, so the CN positions adjacent to target VN position
    p are p-w+1 .. p. The window for target p is

        max(0, p - w + 1) .. min(L - 1, p - w + W_eff),  W_eff = max(W, w)

    i.e. every CN position adjacent to p plus W_eff - w positions of
    look-ahead. max_iter flooding iterations are run on the window's CN rows
    (all CN updates of one iteration read the same VN beliefs), then the bits
    of position p are hard-decided and never revisited. CN->VN messages of
    rows outside the current window keep their last values and still
    contribute to VN beliefs.

    Args:
        llr_q: quantized channel LLRs for entire SC codeword
        H_full: full SC-LDPC parity check matrix (binary)
        L: chain length (number of CN positions)
        w: coupling width
        z: lifting factor
        n_base: base matrix columns
        m_base: base matrix rows
        W: window size in CN positions (default 5); values below w are
           widened to w so the window covers every CN of the target
        max_iter: iterations per window position
        cn_mode: 'normalized' (scale by alpha); any other value selects
            offset min-sum (subtract OFFSET)
        alpha: scaling factor for normalized mode

    Returns:
        (decoded_bits, converged, total_iterations): int8 hard decisions of
        length H_full.shape[1], whether they satisfy every check, and the
        number of window iterations run in total.
    """
    total_rows, total_cols = H_full.shape
    n_vn_positions = L + w - 1
    rows_per_pos = m_base * z
    cols_per_pos = n_base * z
    span = max(W, w)

    def sat_clip(v):
        return max(Q_MIN, min(Q_MAX, int(v)))

    def cn_update_row(msgs_in):
        """Min-sum CN update for a list of incoming VN->CN messages."""
        dc = len(msgs_in)
        if dc == 0:
            return []
        signs = [1 if m < 0 else 0 for m in msgs_in]
        mags = [abs(m) for m in msgs_in]
        sign_xor = sum(signs) % 2

        # Two smallest magnitudes: each output excludes its own input, so
        # the minimum's owner receives the second minimum.
        min1 = min2 = Q_MAX
        min1_idx = 0
        for i, mag in enumerate(mags):
            if mag < min1:
                min1, min2, min1_idx = mag, min1, i
            elif mag < min2:
                min2 = mag

        msgs_out = []
        for j in range(dc):
            mag = min2 if j == min1_idx else min1
            if cn_mode == 'normalized':
                mag = int(mag * alpha)
            else:
                mag = max(0, mag - OFFSET)
            msgs_out.append(-mag if sign_xor ^ signs[j] else mag)
        return msgs_out

    # Adjacency lists: columns of each CN row, rows of each VN column.
    cn_neighbors = [np.flatnonzero(H_full[row] == 1).tolist()
                    for row in range(total_rows)]
    vn_neighbors = [np.flatnonzero(H_full[:, col] == 1).tolist()
                    for col in range(total_cols)]

    # Channel LLRs (fixed, never modified).
    channel_llr = np.array([int(x) for x in llr_q], dtype=np.int32)

    # CN->VN message memory: msg_mem[(row, col)] = last CN->VN message.
    msg_mem = {(row, col): 0
               for row in range(total_rows) for col in cn_neighbors[row]}

    decoded = np.zeros(total_cols, dtype=np.int8)
    total_iterations = 0

    def vn_total(col):
        """Channel LLR plus every stored CN->VN message into col (unclipped)."""
        return int(channel_llr[col]) + sum(
            msg_mem[(row, col)] for row in vn_neighbors[col])

    def hard_decide(pos):
        """Commit hard decisions for all bits of VN position pos."""
        col_end = min((pos + 1) * cols_per_pos, total_cols)
        for c in range(pos * cols_per_pos, col_end):
            decoded[c] = 1 if vn_total(c) < 0 else 0

    for p in range(n_vn_positions):
        cn_pos_start = max(0, p - w + 1)
        cn_pos_end = min(L - 1, p - w + span)
        # CN rows of consecutive positions are contiguous in H_full.
        window_cn_rows = range(cn_pos_start * rows_per_pos,
                               (cn_pos_end + 1) * rows_per_pos)

        if not window_cn_rows:
            # No CN rows cover this position (only possible for L == 0):
            # decide from channel LLR plus whatever messages are stored.
            hard_decide(p)
            continue

        window_vn_cols = {col for row in window_cn_rows
                          for col in cn_neighbors[row]}

        for _ in range(max_iter):
            # Step 1: beliefs from the messages of the previous iteration.
            beliefs = {col: sat_clip(vn_total(col)) for col in window_vn_cols}

            # Step 2: VN->CN (extrinsic) then CN->VN for every window row;
            # results are buffered so all rows see the same beliefs.
            new_msgs = {}
            for row in window_cn_rows:
                cols = cn_neighbors[row]
                if not cols:
                    continue
                vn_to_cn = [sat_clip(beliefs[col] - msg_mem[(row, col)])
                            for col in cols]
                for col, msg in zip(cols, cn_update_row(vn_to_cn)):
                    new_msgs[(row, col)] = msg

            # Step 3: publish the new messages.
            msg_mem.update(new_msgs)
            total_iterations += 1

        hard_decide(p)

    syndrome = (H_full @ decoded) % 2
    converged = bool(np.all(syndrome == 0))

    return decoded, converged, total_iterations