#!/usr/bin/env python3
"""
Spatially-Coupled LDPC (SC-LDPC) Code Construction and Decoding

SC-LDPC codes achieve threshold saturation: the BP threshold approaches
the MAP threshold, closing a significant portion of the gap to the Shannon
limit.

Construction: replicate a protograph base matrix along a chain of L positions
with coupling width w, creating a convolutional-like structure.
"""

import argparse
import json
import os
import sys

import numpy as np

sys.path.insert(0, os.path.dirname(__file__))
from ldpc_sim import Q_BITS, Q_MAX, Q_MIN, OFFSET  # noqa: E402
from density_evolution import de_cn_update_vectorized  # noqa: E402


def split_protograph(B, w=2, seed=None):
    """
    Split a protograph base matrix into w component matrices.

    Each edge in B (entry >= 0) is randomly assigned to exactly one of the
    w component matrices. The component that receives the edge gets value 0
    (the circulant shift is assigned later, during chain construction), while
    all other components get -1 (no connection) at that position.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        w: Coupling width (number of component matrices).
        seed: Random seed for reproducibility.

    Returns:
        List of w component matrices, each with shape (m_base, n_base).
        Component values: 0 where the edge is assigned, -1 otherwise.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape

    # Start with every component disconnected everywhere.
    components = [np.full((m_base, n_base), -1, dtype=np.int16)
                  for _ in range(w)]

    # np.argwhere yields connected entries in row-major order, so the RNG is
    # consumed in the same (r, c) order as a nested row/column loop would.
    for r, c in np.argwhere(B >= 0):
        chosen = rng.integers(0, w)
        components[chosen][r, c] = 0

    return components


def build_sc_chain(B, L=20, w=2, z=32, seed=None):
    """
    Build the full SC-LDPC parity-check matrix as a dense binary matrix.

    The chain has L CN positions and L+w-1 VN positions. Position t's CNs
    connect to VN positions t, t+1, ..., t+w-1 using the w component matrices.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        L: Chain length (number of CN positions).
        w: Coupling width.
        z: Lifting factor (circulant size).
        seed: Random seed for reproducibility.

    Returns:
        H_full: Binary parity-check matrix of shape
            (L * m_base * z) x ((L + w - 1) * n_base * z).
        components: List of w component matrices from split_protograph.
        meta: Dictionary with construction metadata.
    """
    rng = np.random.default_rng(seed)
    m_base, n_base = B.shape

    # Derive the splitting seed from the main RNG so one seed fixes both the
    # edge split and the circulant shifts.
    split_seed = int(rng.integers(0, 2**31)) if seed is not None else None
    components = split_protograph(B, w=w, seed=split_seed)

    n_cn_positions = L
    n_vn_positions = L + w - 1
    total_rows = n_cn_positions * m_base * z
    total_cols = n_vn_positions * n_base * z

    H_full = np.zeros((total_rows, total_cols), dtype=np.int8)

    # Offsets 0..z-1 inside one z x z circulant block.
    offsets = np.arange(z)

    # For each CN position t and component i, connect CN group t to VN group
    # t+i using component B_i, lifting each edge with a random circulant
    # permutation. RNG draws happen in (t, i, r, c) order.
    for t in range(L):
        row_base = t * m_base * z
        for i, comp in enumerate(components):
            col_base = (t + i) * n_base * z  # VN position t+i
            for r in range(m_base):
                for c in range(n_base):
                    if comp[r, c] < 0:
                        continue
                    shift = int(rng.integers(0, z))
                    # Row zz of the block connects to column (zz + shift) % z.
                    H_full[row_base + r * z + offsets,
                           col_base + c * z + (offsets + shift) % z] = 1

    meta = {
        'L': L,
        'w': w,
        'z': z,
        'm_base': m_base,
        'n_base': n_base,
        'total_rows': total_rows,
        'total_cols': total_cols,
        'n_cn_positions': n_cn_positions,
        'n_vn_positions': n_vn_positions,
        'rate_design': 1.0 - (total_rows / total_cols),
    }

    return H_full, components, meta


def sc_encode(info_bits, H_full, k_total):
    """
    Encode using GF(2) Gaussian elimination on the SC-LDPC parity-check matrix.

    Places info_bits in the first k_total positions of the codeword and solves
    for the remaining parity bits such that H_full * codeword = 0 (mod 2).
    Handles rank-deficient H matrices (common with SC-LDPC boundary effects)
    by leaving free variables as zero.

    Args:
        info_bits: Information bits (array-like of 0/1) of length k_total.
        H_full: Binary parity-check matrix (m x n).
        k_total: Number of information bits.

    Returns:
        codeword: Binary int8 codeword array of length n (= H_full.shape[1]).

    Raises:
        ValueError: If info_bits has the wrong length, or if encoding fails
            (syndrome is nonzero, i.e. the parity system is inconsistent).
    """
    m_full, n_full = H_full.shape
    n_parity = n_full - k_total

    info_bits = np.asarray(info_bits)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(info_bits) != k_total:
        raise ValueError(
            f"info_bits length {len(info_bits)} != k_total {k_total}"
        )

    # H_full = [H_info | H_parity]
    H_info = H_full[:, :k_total]
    H_parity = H_full[:, k_total:]

    # RHS = H_info * info_bits (mod 2). An int8 product may wrap around, but
    # the wrap is modulo 256 (even), so the parity survives the final % 2.
    rhs = (H_info @ info_bits) % 2

    # Augmented matrix [H_parity | rhs]; entries stay in {0, 1} throughout.
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_parity
    aug[:, n_parity] = rhs

    # Gauss-Jordan elimination over GF(2). Eliminating every other row (not
    # just rows below the pivot) leaves pivot columns in reduced form, so each
    # pivot variable can be read directly off the RHS afterwards.
    pivot_row = 0
    pivot_cols = []
    for col in range(n_parity):
        if pivot_row == m_full:
            break  # every row already holds a pivot; no more can be found
        candidates = np.flatnonzero(aug[pivot_row:, col] == 1)
        if candidates.size == 0:
            continue  # no pivot for this column (rank deficient) - free var
        found = pivot_row + int(candidates[0])
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]
        # Vectorized row reduction: XOR is addition mod 2 for 0/1 entries.
        hits = np.flatnonzero(aug[:, col] == 1)
        hits = hits[hits != pivot_row]
        if hits.size:
            aug[hits] ^= aug[pivot_row]
        pivot_cols.append(col)
        pivot_row += 1

    # Free variables are zero, so each pivot variable equals its row's RHS.
    parity = np.zeros(n_parity, dtype=np.int8)
    for i, col in enumerate(pivot_cols):
        parity[col] = aug[i, n_parity]

    codeword = np.concatenate([info_bits.astype(np.int8), parity])

    # Verify the syndrome (int8 wrap-around is parity-preserving, see above).
    syndrome = (H_full @ codeword) % 2
    if not np.all(syndrome == 0):
        raise ValueError(
            f"SC-LDPC encoding failed: syndrome weight = {syndrome.sum()}, "
            f"H rank ~{len(pivot_cols)}, H rows = {m_full}"
        )

    return codeword


def windowed_decode(llr_q, H_full, L, w, z, n_base, m_base, W=5, max_iter=20,
                    cn_mode='normalized', alpha=0.75):
    """
    Windowed decoding for SC-LDPC codes.

    For each target VN position p, a window of CN positions
    p .. min(p + W - 1, L - 1) is decoded with max_iter flooding iterations.
    The target position sits at the leading edge of the window, so the W-1
    following CN positions provide look-ahead information. CN positions
    before p (which also touch VN p) were processed by earlier windows; their
    stored CN->VN messages still contribute to VN p's posterior. After the
    iterations, VN position p's bits are hard-decided and never revisited.
    VN positions p >= L have no CN position of their own; they are decided
    from the messages accumulated by the earlier windows.

    Posteriors (channel LLR + all CN->VN messages) are accumulated at full
    integer precision. Only the VN->CN and CN->VN messages are saturated to
    [Q_MIN, Q_MAX]. This avoids clipping a posterior and then subtracting a
    large message from it, which would corrupt the extrinsic value. It also
    matches the quantization model used by sc_density_evolution.

    Args:
        llr_q: quantized channel LLRs for the entire SC codeword
        H_full: full SC-LDPC parity check matrix (binary)
        L: chain length (number of CN positions)
        w: coupling width
        z: lifting factor
        n_base: base matrix columns
        m_base: base matrix rows
        W: window size in CN positions (default 5)
        max_iter: flooding iterations per window position
        cn_mode: 'offset' or 'normalized'
        alpha: scaling factor for normalized mode

    Returns:
        (decoded_bits, converged, total_iterations)
        converged is True when the final hard decisions satisfy every check.
    """
    total_rows, total_cols = H_full.shape
    n_vn_positions = L + w - 1
    rows_per_cn_pos = m_base * z
    cols_per_vn_pos = n_base * z

    def sat_clip(v):
        return max(Q_MIN, min(Q_MAX, int(v)))

    def cn_update_row(msgs_in):
        """Min-sum CN update for a list of incoming VN->CN messages."""
        dc = len(msgs_in)
        if dc == 0:
            return []

        signs = [1 if m < 0 else 0 for m in msgs_in]
        mags = [abs(m) for m in msgs_in]
        sign_xor = sum(signs) % 2

        # Track the two smallest magnitudes: every output except the argmin
        # receives min1; the argmin receives min2.
        min1 = Q_MAX
        min2 = Q_MAX
        min1_idx = 0
        for i in range(dc):
            if mags[i] < min1:
                min2 = min1
                min1 = mags[i]
                min1_idx = i
            elif mags[i] < min2:
                min2 = mags[i]

        msgs_out = []
        for j in range(dc):
            mag = min2 if j == min1_idx else min1
            if cn_mode == 'normalized':
                mag = int(mag * alpha)
            else:
                mag = max(0, mag - OFFSET)
            sgn = sign_xor ^ signs[j]
            val = -mag if sgn else mag
            msgs_out.append(val)
        return msgs_out

    # Adjacency lists (ascending indices) for check rows and variable columns.
    H_bin = H_full == 1
    cn_neighbors = [np.flatnonzero(H_bin[row]).tolist()
                    for row in range(total_rows)]
    vn_neighbors = [np.flatnonzero(H_bin[:, col]).tolist()
                    for col in range(total_cols)]

    # Channel LLRs (fixed, never modified).
    channel_llr = np.array([int(x) for x in llr_q], dtype=np.int32)

    # CN->VN message memory: msg_mem[(row, col)] = last CN->VN message.
    msg_mem = {(row, col): 0
               for row in range(total_rows)
               for col in cn_neighbors[row]}

    decoded = np.zeros(total_cols, dtype=np.int8)
    total_iterations = 0

    def posterior(col):
        """Unclipped posterior: channel LLR plus every stored CN->VN message."""
        return int(channel_llr[col]) + sum(
            msg_mem[(row, col)] for row in vn_neighbors[col])

    def decide_position(p):
        """Hard-decide every bit of VN position p from its current posterior."""
        col_start = p * cols_per_vn_pos
        col_end = min((p + 1) * cols_per_vn_pos, total_cols)
        for c in range(col_start, col_end):
            decoded[c] = 1 if posterior(c) < 0 else 0

    for p in range(n_vn_positions):
        # Target p at the start of the window; later CN positions are
        # look-ahead. Empty when p >= L (terminal VN positions) or W < 1.
        cn_pos_start = p
        cn_pos_end = min(p + W - 1, L - 1)
        window_cn_rows = list(range(cn_pos_start * rows_per_cn_pos,
                                    (cn_pos_end + 1) * rows_per_cn_pos))

        if not window_cn_rows:
            # Nothing left to iterate on: decide from accumulated messages.
            decide_position(p)
            continue

        window_vn_cols = sorted({col for row in window_cn_rows
                                 for col in cn_neighbors[row]})

        for _ in range(max_iter):
            # Step 1: posteriors of every VN touched by the window.
            beliefs = {col: posterior(col) for col in window_vn_cols}

            # Step 2: flooding - compute all new CN->VN messages from the
            # same snapshot of beliefs, then commit them together.
            new_msgs = {}
            for row in window_cn_rows:
                cols = cn_neighbors[row]
                if not cols:
                    continue
                # Extrinsic VN->CN message, saturated on the "wire".
                vn_to_cn = [sat_clip(beliefs[col] - msg_mem[(row, col)])
                            for col in cols]
                for col, val in zip(cols, cn_update_row(vn_to_cn)):
                    new_msgs[(row, col)] = val

            # Step 3: commit.
            msg_mem.update(new_msgs)
            total_iterations += 1

        decide_position(p)

    syndrome = (H_full @ decoded) % 2
    converged = bool(np.all(syndrome == 0))

    return decoded, converged, total_iterations


def sc_density_evolution(B, L=50, w=2, lam_s=5.0, lam_b=0.1, z_pop=10000,
                         max_iter=200, cn_mode='normalized', alpha=0.75,
                         conv_threshold=1e-3):
    """
    Position-aware density evolution for SC-LDPC codes.

    Uses a flooding schedule: compute total beliefs from the channel plus all
    CN->VN messages, then compute all new CN->VN messages simultaneously.
    This avoids the layered-update instability that occurs when multiple CN
    rows modify the same VN belief within one iteration via different random
    permutations.

    Tracks belief populations at each chain position to observe threshold
    saturation and the wave decoding effect, where boundary positions
    converge before interior positions. Random sampling uses NumPy's global
    RNG (np.random); seed it beforehand for reproducibility.

    Args:
        B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
        L: Chain length (number of CN positions).
        w: Coupling width (number of component matrices).
        lam_s: Signal photons per slot.
        lam_b: Background photons per slot.
        z_pop: Population size for Monte Carlo DE.
        max_iter: Maximum number of decoding iterations.
        cn_mode: 'offset' or 'normalized' min-sum variant.
        alpha: Scaling factor for normalized mode.
        conv_threshold: Error-rate threshold every interior position must
            fall below for the run to count as converged.

    Returns:
        (converged, per_position_errors, iterations_used)
        converged: True if every interior position (the first and last w-1
            VN positions are excluded, unless the chain is too short to have
            an interior) has error rate < conv_threshold.
        per_position_errors: list of (L+w-1) error fractions from the last
            belief evaluation (the channel-only profile if max_iter <= 0).
        iterations_used: number of iterations actually performed.
    """
    m_base, n_base = B.shape
    n_vn_positions = L + w - 1

    # Split protograph into w components (deterministic with seed=0).
    components = split_protograph(B, w=w, seed=0)

    # edge_map[(t, r)] = list of (i, vn_pos, c): component i connects
    # CN(t, r) to VN(vn_pos, c). The list index is the edge index `ei`.
    edge_map = {}
    for t in range(L):
        for r in range(m_base):
            edge_map[(t, r)] = [
                (i, t + i, c)
                for i in range(w)
                for c in range(n_base)
                if components[i][r, c] >= 0
            ]

    # Reverse map: for each VN group (p, c), the (t, r, ei) edges into it.
    vn_edges = {}
    for (t, r), edges in edge_map.items():
        for ei, (_, vn_pos, c) in enumerate(edges):
            vn_edges.setdefault((vn_pos, c), []).append((t, r, ei))

    # Channel LLR populations per (position, base_col); all-zeros codeword
    # assumed (standard for DE), so observations carry background only.
    log_ratio = np.log((lam_s + lam_b) / lam_b) if lam_b > 0 else 100.0
    scale = Q_MAX / 5.0
    channel_llr = np.zeros((n_vn_positions, n_base, z_pop), dtype=np.float64)
    for p in range(n_vn_positions):
        for c in range(n_base):
            y = np.random.poisson(lam_b, size=z_pop)
            llr_float = lam_s - y * log_ratio
            channel_llr[p, c] = np.clip(np.round(llr_float * scale),
                                        Q_MIN, Q_MAX)

    # CN->VN message memory: msg[(t, r, ei)] = z_pop samples.
    msg = {
        (t, r, ei): np.zeros(z_pop, dtype=np.float64)
        for (t, r), edges in edge_map.items()
        for ei in range(len(edges))
    }

    # Convergence uses interior positions (excluding w-1 boundary positions
    # on each side) to avoid the boundary rate effect. Boundary positions
    # have fewer CN connections and genuinely higher error rates.
    interior_start = w - 1
    interior_end = n_vn_positions - (w - 1)
    if interior_end <= interior_start:
        # Chain too short for an interior; use all positions.
        interior_start = 0
        interior_end = n_vn_positions

    per_position_errors = None
    iterations_used = 0
    for _ in range(max_iter):
        # Step 1: beliefs[p, c] = channel + sum of randomly permuted CN->VN
        # messages. Remember each permuted message so Step 2 can remove
        # exactly this CN's contribution (extrinsic rule).
        beliefs = channel_llr.copy()
        permuted_msg = {}
        for (p, c), edge_list in vn_edges.items():
            for key in edge_list:
                pmsg = msg[key][np.random.permutation(z_pop)]
                permuted_msg[key] = pmsg
                beliefs[p, c] += pmsg

        # Beliefs are NOT clipped: they accumulate in full float64 precision
        # to avoid saturation artifacts. Only the VN->CN and CN->VN messages
        # are quantized, mirroring node accumulation vs wire quantization.

        # Step 2: compute all new CN->VN messages at once (flooding).
        new_msg = {}
        for (t, r), edges in edge_map.items():
            dc = len(edges)
            if dc < 2:
                for ei in range(dc):
                    new_msg[(t, r, ei)] = np.zeros(z_pop, dtype=np.float64)
                continue

            vn_to_cn = []
            for ei, (_, vn_pos, c) in enumerate(edges):
                # Extrinsic = belief - this CN's own contribution, both
                # passed through the same random interleaver.
                perm = np.random.permutation(z_pop)
                ext = beliefs[vn_pos, c][perm] - permuted_msg[(t, r, ei)][perm]
                vn_to_cn.append(np.clip(np.round(ext), Q_MIN, Q_MAX))

            cn_to_vn = de_cn_update_vectorized(
                vn_to_cn, offset=OFFSET, cn_mode=cn_mode, alpha=alpha
            )
            for ei in range(dc):
                new_msg[(t, r, ei)] = cn_to_vn[ei]

        msg = new_msg
        iterations_used += 1

        # Convergence check on this iteration's beliefs, interior only.
        per_position_errors = _compute_position_errors(
            beliefs, n_vn_positions, n_base, z_pop
        )
        interior_errors = per_position_errors[interior_start:interior_end]
        if all(e < conv_threshold for e in interior_errors):
            return True, per_position_errors, iterations_used

    if per_position_errors is None:
        # max_iter <= 0: no iteration ran; report the channel-only profile
        # instead of failing on an unbound `beliefs`.
        per_position_errors = _compute_position_errors(
            channel_llr, n_vn_positions, n_base, z_pop
        )
    return False, per_position_errors, iterations_used


def _compute_position_errors(beliefs, n_vn_positions, n_base, z_pop):
    """
    Compute per-position error fractions from belief arrays.

    A sample counts as an error when its belief is strictly negative (the
    all-zeros codeword is assumed). Returns a list with one fraction per VN
    position: errors over positions' n_base columns divided by n_base * z_pop.
    """
    wrong = np.count_nonzero(beliefs[:n_vn_positions, :n_base] < 0,
                             axis=(1, 2))
    return (wrong / (n_base * z_pop)).tolist()


def compute_sc_threshold(B, L=50, w=2, lam_b=0.1, z_pop=10000, tol=0.25,
                         cn_mode='normalized', alpha=0.75,
                         lo=0.1, hi=20.0, max_iter=100):
    """
    Binary search for the minimum lam_s where SC density evolution converges.

    Args:
        B: Base matrix (m_base x n_base).
        L: Chain length.
        w: Coupling width.
        lam_b: Background photon rate.
        z_pop: DE population size.
        tol: Search tolerance (stop when hi - lo <= tol).
        cn_mode: 'offset' or 'normalized'.
        alpha: Scaling factor for normalized mode.
        lo: Lower end of the search bracket (expected not to converge).
        hi: Upper end of the search bracket (expected to converge).
        max_iter: DE iteration budget per evaluation.

    Returns:
        Threshold lam_s* (upper bound from binary search). Returns `hi` if DE
        does not converge even at `hi`, and `lo` if it already converges at
        `lo`.
    """
    def converges(lam_s):
        converged, _, _ = sc_density_evolution(
            B, L=L, w=w, lam_s=lam_s, lam_b=lam_b, z_pop=z_pop,
            max_iter=max_iter, cn_mode=cn_mode, alpha=alpha
        )
        return converged

    if not converges(hi):
        return hi  # Doesn't converge even at hi
    if converges(lo):
        return lo  # Converges even at lo

    while hi - lo > tol:
        mid = (lo + hi) / 2
        if converges(mid):
            hi = mid
        else:
            lo = mid

    return hi


# =============================================================================
# CLI
# =============================================================================

def run_threshold_comparison(seed=42, z_pop=5000, tol=0.5, L=20):
    """
    Compare SC-LDPC and uncoupled DE thresholds.

    Seeds NumPy's global RNG, prints the thresholds and returns them in a
    JSON-serializable dict.
    """
    from ldpc_sim import H_BASE
    from density_evolution import (
        build_de_profile,
        compute_threshold,
        construct_base_matrix,
        make_profile,
    )

    np.random.seed(seed)

    print("=" * 60)
    print("SC-LDPC vs Uncoupled Threshold Comparison")
    print("=" * 60)

    # Uncoupled thresholds
    degrees_opt = [7, 4, 4, 4, 4, 3, 3, 3]
    profile_opt = build_de_profile(degrees_opt, m_base=7)
    profile_orig = make_profile(H_BASE)

    print("\nUncoupled thresholds:")
    thresh_opt_offset = compute_threshold(
        profile_opt, lam_b=0.1, z_pop=z_pop, tol=tol, cn_mode='offset')
    thresh_opt_norm = compute_threshold(
        profile_opt, lam_b=0.1, z_pop=z_pop, tol=tol,
        cn_mode='normalized', alpha=0.875)
    thresh_orig = compute_threshold(
        profile_orig, lam_b=0.1, z_pop=z_pop, tol=tol, cn_mode='offset')
    print(f" Original staircase (offset): {thresh_orig:.2f} photons/slot")
    print(f" DE-optimized (offset): {thresh_opt_offset:.2f} photons/slot")
    print(f" DE-optimized (normalized 0.875): {thresh_opt_norm:.2f} photons/slot")

    # SC-LDPC thresholds
    print(f"\nSC-LDPC thresholds (L={L}, w=2, normalized 0.875):")
    sc_thresh_orig = compute_sc_threshold(
        H_BASE, L=L, w=2, lam_b=0.1, z_pop=z_pop, tol=tol,
        cn_mode='normalized', alpha=0.875)
    print(f" SC original staircase: {sc_thresh_orig:.2f} photons/slot")

    H_opt, _girth = construct_base_matrix(degrees_opt, z=32, n_trials=500)
    sc_thresh_opt = compute_sc_threshold(
        H_opt, L=L, w=2, lam_b=0.1, z_pop=z_pop, tol=tol,
        cn_mode='normalized', alpha=0.875)
    print(f" SC DE-optimized: {sc_thresh_opt:.2f} photons/slot")

    shannon_limit = 0.47
    print(f"\n Shannon limit (rate 1/8): {shannon_limit} photons/slot")

    return {
        'uncoupled_thresholds': {
            'original_offset': float(thresh_orig),
            'optimized_offset': float(thresh_opt_offset),
            'optimized_normalized': float(thresh_opt_norm),
        },
        'sc_thresholds': {
            'sc_original': float(sc_thresh_orig),
            'sc_optimized': float(sc_thresh_opt),
        },
        'shannon_limit': shannon_limit,
        'params': {'L': L, 'w': 2, 'z_pop': z_pop, 'tol': tol},
    }


def run_fer_comparison(seed=42, n_frames=50, L=10, z=32):
    """
    Monte Carlo FER/BER of the SC-LDPC chain with windowed decoding.

    Builds an SC chain from H_BASE (w=2, lifting z), transmits the all-zeros
    codeword over the Poisson channel at several lam_s points, and decodes
    with windowed_decode (W=5, normalized alpha=0.875).

    NOTE: despite the banner text, only the SC-LDPC side is simulated here;
    no uncoupled-code FER is computed.
    """
    from ldpc_sim import H_BASE, poisson_channel, quantize_llr

    np.random.seed(seed)

    print("=" * 60)
    print(f"SC-LDPC vs Uncoupled FER Comparison (Z={z}, L={L})")
    print("=" * 60)

    m_base, n_base = H_BASE.shape

    H_sc, _components, _meta = build_sc_chain(
        H_BASE, L=L, w=2, z=z, seed=seed)
    n_total = H_sc.shape[1]

    lam_s_points = [2.0, 3.0, 4.0, 5.0, 7.0, 10.0]
    sc_results = {}

    print(f"\nSC-LDPC (L={L}, w=2, windowed W=5, normalized alpha=0.875):")
    print(f"{'lam_s':>8s} {'FER':>10s} {'BER':>10s}")
    print("-" * 30)

    for lam_s in lam_s_points:
        frame_errors = 0
        bit_errors = 0
        total_bits = 0

        for _ in range(n_frames):
            # The all-zeros word is a codeword of any linear code.
            codeword = np.zeros(n_total, dtype=np.int8)
            llr_float, _ = poisson_channel(codeword, lam_s, 0.1)
            llr_q = quantize_llr(llr_float)

            decoded, _converged, _iters = windowed_decode(
                llr_q, H_sc, L=L, w=2, z=z, n_base=n_base, m_base=m_base,
                W=5, max_iter=20, cn_mode='normalized', alpha=0.875)

            errs = np.sum(decoded != 0)
            bit_errors += errs
            total_bits += n_total
            if errs > 0:
                frame_errors += 1

        # Guard n_frames == 0 (and empty codewords) against division by zero.
        fer = frame_errors / n_frames if n_frames > 0 else 0
        ber = bit_errors / total_bits if total_bits > 0 else 0
        sc_results[lam_s] = {'fer': float(fer), 'ber': float(ber)}
        print(f"{lam_s:8.1f} {fer:10.3f} {ber:10.6f}")

    return {
        'lam_s_points': lam_s_points,
        'sc_fer': {str(k): v for k, v in sc_results.items()},
        'params': {'L': L, 'w': 2, 'z': z, 'n_frames': n_frames},
    }


def run_full_pipeline(seed=42):
    """
    Full SC-LDPC pipeline: threshold comparison + FER.

    Writes the combined results to ../data/sc_ldpc_results.json relative to
    this file and returns them.
    """
    print("=" * 70)
    print("SC-LDPC FULL PIPELINE")
    print("=" * 70)

    print("\n--- Step 1: Threshold Comparison ---")
    threshold_results = run_threshold_comparison(
        seed=seed, z_pop=5000, tol=0.5, L=20)

    print("\n--- Step 2: FER Comparison ---")
    fer_results = run_fer_comparison(
        seed=seed, n_frames=50, L=10, z=32)

    output = {
        **threshold_results,
        'fer_comparison': fer_results,
    }

    out_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                           '..', 'data')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'sc_ldpc_results.json')
    with open(out_path, 'w') as f:
        json.dump(output, f, indent=2, default=str)
    print(f"\nResults saved to {out_path}")

    return output


def main():
    """Command-line entry point: threshold, fer-compare, or full pipeline."""
    parser = argparse.ArgumentParser(
        description='SC-LDPC Code Construction and Analysis',
    )
    subparsers = parser.add_subparsers(dest='command')

    p_thresh = subparsers.add_parser('threshold',
                                     help='SC-DE threshold comparison')
    p_thresh.add_argument('--seed', type=int, default=42)
    p_thresh.add_argument('--z-pop', type=int, default=5000)
    p_thresh.add_argument('--tol', type=float, default=0.5)
    p_thresh.add_argument('--L', type=int, default=20)

    p_fer = subparsers.add_parser('fer-compare',
                                  help='FER: SC vs uncoupled')
    p_fer.add_argument('--seed', type=int, default=42)
    p_fer.add_argument('--n-frames', type=int, default=50)
    p_fer.add_argument('--L', type=int, default=10)
    p_fer.add_argument('--z', type=int, default=32)

    p_full = subparsers.add_parser('full', help='Full pipeline')
    p_full.add_argument('--seed', type=int, default=42)

    args = parser.parse_args()

    if args.command == 'threshold':
        run_threshold_comparison(seed=args.seed, z_pop=args.z_pop,
                                 tol=args.tol, L=args.L)
    elif args.command == 'fer-compare':
        run_fer_comparison(seed=args.seed, n_frames=args.n_frames,
                           L=args.L, z=args.z)
    elif args.command == 'full':
        run_full_pipeline(seed=args.seed)
    else:
        parser.print_help()


if __name__ == '__main__':
    main()