From ad7cb5098c8d3fd59846a38e5bdc5a1c50e7630e Mon Sep 17 00:00:00 2001 From: cah Date: Tue, 24 Feb 2026 04:46:39 -0700 Subject: [PATCH] feat: add code analysis tool with rate, matrix, quantization, and Shannon analyses Implements four LDPC code analyses for photon-starved optical channels: - Rate sweep: compare FER across 1/2, 1/3, 1/4, 1/6, 1/8 IRA staircase codes - Matrix comparison: original staircase vs improved staircase vs PEG ring - Quantization sweep: 4-16 bit and float precision impact on FER - Shannon gap: binary-input Poisson channel capacity limits via binary search Core infrastructure includes generic IRA staircase builder, GF(2) Gaussian elimination encoder for non-triangular matrices, parameterized layered min-sum decoder with variable check degree, and BFS girth computation. Co-Authored-By: Claude Opus 4.6 --- model/ldpc_analysis.py | 1052 +++++++++++++++++++++++++++++++++++ model/test_ldpc_analysis.py | 85 +++ 2 files changed, 1137 insertions(+) create mode 100644 model/ldpc_analysis.py create mode 100644 model/test_ldpc_analysis.py diff --git a/model/ldpc_analysis.py b/model/ldpc_analysis.py new file mode 100644 index 0000000..093af49 --- /dev/null +++ b/model/ldpc_analysis.py @@ -0,0 +1,1052 @@ +#!/usr/bin/env python3 +""" +LDPC Code Analysis Tool + +Four analyses for the photon-starved optical LDPC decoder: + 1. Rate comparison (--rate-sweep) + 2. Base matrix comparison (--matrix-compare) + 3. Quantization sweep (--quant-sweep) + 4. 
Shannon gap analysis (--shannon-gap) + +Usage: + python3 ldpc_analysis.py --rate-sweep + python3 ldpc_analysis.py --matrix-compare + python3 ldpc_analysis.py --quant-sweep + python3 ldpc_analysis.py --shannon-gap + python3 ldpc_analysis.py --all +""" + +import numpy as np +import argparse +import json +import os +import sys +from collections import deque + +from ldpc_sim import ( + poisson_channel, quantize_llr, decode_layered_min_sum, + compute_syndrome_weight, sat_add_q, sat_sub_q, min_sum_cn_update, + Q_BITS, Q_MAX, Q_MIN, OFFSET, + N, K, M, Z, N_BASE, M_BASE, H_BASE, + build_full_h_matrix, ldpc_encode, +) + + +# ============================================================================= +# Core infrastructure +# ============================================================================= + +def build_ira_staircase(m_base, z=32): + """ + Build generic IRA staircase QC-LDPC code for any rate. + + Rate = 1/(m_base+1), n_base = m_base+1. + Col 0 (info) connects to ALL m_base rows with spread shifts. + Lower-triangular staircase parity: row r connects to col r+1 (diagonal), + row r>0 also connects to col r (sub-diagonal). + + Returns: + (H_base, H_full) where H_full is the expanded binary matrix. 
+ """ + n_base = m_base + 1 + H_b = -np.ones((m_base, n_base), dtype=np.int16) + + # Column 0 (info) connects to all rows with spread shifts + for r in range(m_base): + H_b[r, 0] = (r * z // m_base) % z + + # Staircase parity structure + for r in range(m_base): + # Diagonal: row r connects to col r+1 + if r == 0: + H_b[r, r + 1] = (z // 4) % z # some spread for row 0 + else: + H_b[r, r + 1] = 0 # identity for sub-diagonal rows + + # Sub-diagonal: row r>0 connects to col r + if r > 0: + H_b[r, r] = (r * 7) % z # spread shift + + # Build full expanded matrix + m_full = m_base * z + n_full = n_base * z + H_full = np.zeros((m_full, n_full), dtype=np.int8) + for r in range(m_base): + for c in range(n_base): + shift = int(H_b[r, c]) + if shift < 0: + continue + for zz in range(z): + col_idx = c * z + (zz + shift) % z + H_full[r * z + zz, col_idx] = 1 + + return H_b, H_full + + +def ira_encode(info, H_base, H_full, z=32): + """ + Generic IRA encoder using sequential staircase solve. + + Row 0: solve for p[1]. + Rows 1..m_base-1: solve for p[2]..p[m_base]. + Asserts syndrome = 0. 
+ """ + m_base = H_base.shape[0] + n_base = H_base.shape[1] + k = z + assert len(info) == k + + def apply_p(x, s): + """Apply circulant P_s: result[i] = x[(i+s)%Z].""" + return np.roll(x, -int(s)) + + def inv_p(y, s): + """Apply P_s inverse.""" + return np.roll(y, int(s)) + + # Parity blocks: p[c] for c = 1..n_base-1 + p = [np.zeros(z, dtype=np.int8) for _ in range(n_base)] + + # Step 1: Solve row 0 for p[1] + # Accumulate contributions from all connected columns except col 1 (the new unknown) + accum = np.zeros(z, dtype=np.int8) + for c in range(n_base): + shift = int(H_base[0, c]) + if shift < 0: + continue + if c == 0: + accum = (accum + apply_p(info, shift)) % 2 + elif c == 1: + # This is the unknown we're solving for + pass + else: + accum = (accum + apply_p(p[c], shift)) % 2 + + # p[1] = P_{H[0][1]}^-1 * accum + p[1] = inv_p(accum, H_base[0, 1]) + + # Step 2: Solve rows 1..m_base-1 for p[2]..p[m_base] + for r in range(1, m_base): + accum = np.zeros(z, dtype=np.int8) + target_col = r + 1 # the new unknown column + for c in range(n_base): + shift = int(H_base[r, c]) + if shift < 0: + continue + if c == target_col: + continue # skip the unknown + if c == 0: + accum = (accum + apply_p(info, shift)) % 2 + else: + accum = (accum + apply_p(p[c], shift)) % 2 + + p[target_col] = inv_p(accum, H_base[r, target_col]) + + # Assemble codeword + codeword = np.concatenate([info] + [p[c] for c in range(1, n_base)]) + + # Verify + check = H_full @ codeword % 2 + assert np.all(check == 0), f"IRA encoding failed: syndrome weight = {check.sum()}" + + return codeword + + +def generic_decode(llr_q, H_base, z=32, max_iter=30, early_term=True, q_bits=6): + """ + Parameterized layered min-sum decoder for any QC-LDPC base matrix. + + Handles variable check node degree (different rows may have different + numbers of connected columns). 
+ + Args: + llr_q: quantized channel LLRs + H_base: base matrix (m_base x n_base), -1 = no connection + z: lifting factor + max_iter: maximum iterations + early_term: stop when syndrome is zero + q_bits: quantization bits + + Returns: + (decoded_info_bits, converged, iterations, syndrome_weight) + """ + m_base_local = H_base.shape[0] + n_base_local = H_base.shape[1] + n_local = n_base_local * z + k_local = z + q_max = 2**(q_bits-1) - 1 + q_min = -(2**(q_bits-1)) + + def sat_add(a, b): + s = int(a) + int(b) + return max(q_min, min(q_max, s)) + + def sat_sub(a, b): + return sat_add(a, -b) + + def cn_update(msgs_in, offset=1): + dc = len(msgs_in) + signs = [1 if m < 0 else 0 for m in msgs_in] + mags = [abs(m) for m in msgs_in] + sign_xor = sum(signs) % 2 + + min1 = q_max + min2 = q_max + min1_idx = 0 + for i in range(dc): + if mags[i] < min1: + min2 = min1 + min1 = mags[i] + min1_idx = i + elif mags[i] < min2: + min2 = mags[i] + + msgs_out = [] + for j in range(dc): + mag = min2 if j == min1_idx else min1 + mag = max(0, mag - offset) + sgn = sign_xor ^ signs[j] + val = -mag if sgn else mag + msgs_out.append(val) + return msgs_out + + # Initialize beliefs from channel LLRs + beliefs = [int(x) for x in llr_q] + + # Initialize CN->VN messages to zero + msg = [[[0 for _ in range(z)] for _ in range(n_base_local)] for _ in range(m_base_local)] + + for iteration in range(max_iter): + for row in range(m_base_local): + connected_cols = [c for c in range(n_base_local) if H_base[row, c] >= 0] + dc = len(connected_cols) + + # VN->CN messages + vn_to_cn = [[0]*z for _ in range(dc)] + for ci, col in enumerate(connected_cols): + shift = int(H_base[row, col]) + for zz in range(z): + shifted_z = (zz + shift) % z + bit_idx = col * z + shifted_z + old_msg = msg[row][col][zz] + vn_to_cn[ci][zz] = sat_sub(beliefs[bit_idx], old_msg) + + # CN update + cn_to_vn = [[0]*z for _ in range(dc)] + for zz in range(z): + cn_inputs = [vn_to_cn[ci][zz] for ci in range(dc)] + cn_outputs = 
cn_update(cn_inputs) + for ci in range(dc): + cn_to_vn[ci][zz] = cn_outputs[ci] + + # Update beliefs and store new messages + for ci, col in enumerate(connected_cols): + shift = int(H_base[row, col]) + for zz in range(z): + shifted_z = (zz + shift) % z + bit_idx = col * z + shifted_z + new_msg = cn_to_vn[ci][zz] + extrinsic = vn_to_cn[ci][zz] + beliefs[bit_idx] = sat_add(extrinsic, new_msg) + msg[row][col][zz] = new_msg + + # Syndrome check + hard = [1 if b < 0 else 0 for b in beliefs] + sw = _compute_syndrome_weight(hard, H_base, z) + + if early_term and sw == 0: + return np.array(hard[:k_local]), True, iteration + 1, 0 + + hard = [1 if b < 0 else 0 for b in beliefs] + sw = _compute_syndrome_weight(hard, H_base, z) + return np.array(hard[:k_local]), sw == 0, max_iter, sw + + +def _compute_syndrome_weight(hard_bits, H_base, z): + """Compute syndrome weight for a generic base matrix.""" + m_base_local = H_base.shape[0] + n_base_local = H_base.shape[1] + weight = 0 + for r in range(m_base_local): + for zz in range(z): + parity = 0 + for c in range(n_base_local): + shift = int(H_base[r, c]) + if shift < 0: + continue + shifted_z = (zz + shift) % z + bit_idx = c * z + shifted_z + parity ^= hard_bits[bit_idx] + if parity: + weight += 1 + return weight + + +# ============================================================================= +# Analysis 1: Rate Comparison +# ============================================================================= + +def rate_sweep_single(m_base, z, lam_s, lam_b, n_frames, max_iter, q_bits): + """ + Run n_frames through encode/channel/decode for one rate and one lambda_s. + + Returns dict with rate, rate_val, m_base, lam_s, ber, fer, avg_iter. 
+ """ + H_base, H_full = build_ira_staircase(m_base=m_base, z=z) + k = z + n = (m_base + 1) * z + rate_val = 1.0 / (m_base + 1) + + bit_errors = 0 + frame_errors = 0 + total_bits = 0 + total_iter = 0 + + for _ in range(n_frames): + info = np.random.randint(0, 2, k).astype(np.int8) + codeword = ira_encode(info, H_base, H_full, z=z) + + llr_float, _ = poisson_channel(codeword, lam_s, lam_b) + llr_q = quantize_llr(llr_float, q_bits=q_bits) + + decoded, converged, iters, sw = generic_decode( + llr_q, H_base, z=z, max_iter=max_iter, q_bits=q_bits + ) + total_iter += iters + + errs = np.sum(decoded != info) + bit_errors += errs + total_bits += k + if errs > 0: + frame_errors += 1 + + ber = bit_errors / total_bits if total_bits > 0 else 0 + fer = frame_errors / n_frames + avg_iter = total_iter / n_frames + + return { + 'rate': f'1/{m_base+1}', + 'rate_val': rate_val, + 'm_base': m_base, + 'lam_s': lam_s, + 'ber': float(ber), + 'fer': float(fer), + 'avg_iter': float(avg_iter), + } + + +def run_rate_sweep(lam_b=0.1, n_frames=500, max_iter=30): + """ + Compare FER across multiple code rates. + + Rates: 1/2, 1/3, 1/4, 1/6, 1/8. 
+ """ + rates = [(1, '1/2'), (2, '1/3'), (3, '1/4'), (5, '1/6'), (7, '1/8')] + lam_s_values = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 7.0, 10.0] + + print("=" * 80) + print("ANALYSIS 1: Rate Comparison (IRA staircase codes)") + print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}") + print("=" * 80) + + all_results = [] + + # Header + header = f"{'lam_s':>8s}" + for m, label in rates: + header += f" {label:>8s}" + print(header) + print("-" * len(header)) + + for lam_s in lam_s_values: + row = f"{lam_s:8.1f}" + for m_base, label in rates: + result = rate_sweep_single( + m_base=m_base, z=32, lam_s=lam_s, lam_b=lam_b, + n_frames=n_frames, max_iter=max_iter, q_bits=Q_BITS, + ) + all_results.append(result) + fer = result['fer'] + if fer == 0: + row += f" {'0.000':>8s}" + else: + row += f" {fer:8.3f}" + print(row) + + # Threshold summary + print("\nThreshold summary (lowest lam_s where FER < 10%):") + for m_base, label in rates: + threshold = None + for lam_s in lam_s_values: + matching = [r for r in all_results + if r['m_base'] == m_base and r['lam_s'] == lam_s] + if matching and matching[0]['fer'] < 0.10: + threshold = lam_s + break + if threshold is not None: + print(f" Rate {label}: lam_s >= {threshold}") + else: + print(f" Rate {label}: FER >= 10% at all tested lam_s") + + return all_results + + +# ============================================================================= +# Analysis 2: Base Matrix Comparison +# ============================================================================= + +def build_improved_staircase(z=32): + """ + Improved staircase: same base structure as original but adds extra + connections to low-degree columns. + + Key: col 7 gets extra connection (dv=1->dv=2), col 1 gets extra (dv=2->dv=3). + NOTE: This is NO LONGER purely lower-triangular. Must use peg_encode. 
+ """ + m_base = 7 + n_base = 8 + + # Start from original H_BASE + H_b = H_BASE.copy() + + # Add extra connection for col 7 (currently only connected from row 6) + # Connect col 7 to row 3 with shift 17 + H_b[3, 7] = 17 + + # Add extra connection for col 1 (currently rows 0,1) + # Connect col 1 to row 4 with shift 11 + H_b[4, 1] = 11 + + # Build full matrix + m_full = m_base * z + n_full = n_base * z + H_full = np.zeros((m_full, n_full), dtype=np.int8) + for r in range(m_base): + for c in range(n_base): + shift = int(H_b[r, c]) + if shift < 0: + continue + for zz in range(z): + col_idx = c * z + (zz + shift) % z + H_full[r * z + zz, col_idx] = 1 + + return H_b, H_full + + +def build_peg_matrix(z=32): + """ + PEG-like construction with ring structure (not strictly lower-triangular). + + More uniform degree distribution. All VN columns have degree >= 2. + """ + m_base = 7 + n_base = 8 + + # PEG ring: each column connects to 2-3 rows, more evenly distributed + # Design for good girth while maintaining encodability + H_b = -np.ones((m_base, n_base), dtype=np.int16) + + # Column 0 (info): high degree, connect to rows 0,1,2,3,4,5,6 + for r in range(m_base): + H_b[r, 0] = (r * z // m_base) % z + + # Parity columns with ring connections (degree 2-3 each) + # Col 1: rows 0, 1, 5 + H_b[0, 1] = 5 + H_b[1, 1] = 3 + H_b[5, 1] = 19 + + # Col 2: rows 1, 2, 6 + H_b[1, 2] = 0 + H_b[2, 2] = 7 + H_b[6, 2] = 23 + + # Col 3: rows 2, 3 + H_b[2, 3] = 0 + H_b[3, 3] = 13 + + # Col 4: rows 3, 4 + H_b[3, 4] = 0 + H_b[4, 4] = 19 + + # Col 5: rows 4, 5 + H_b[4, 5] = 0 + H_b[5, 5] = 25 + + # Col 6: rows 5, 6 + H_b[5, 6] = 0 + H_b[6, 6] = 31 + + # Col 7: rows 0, 6 + H_b[0, 7] = 11 + H_b[6, 7] = 0 + + # Build full matrix + m_full = m_base * z + n_full = n_base * z + H_full = np.zeros((m_full, n_full), dtype=np.int8) + for r in range(m_base): + for c in range(n_base): + shift = int(H_b[r, c]) + if shift < 0: + continue + for zz in range(z): + col_idx = c * z + (zz + shift) % z + H_full[r * z + zz, 
def peg_encode(info, H_base, H_full, z=32):
    """
    GF(2) Gaussian elimination encoder for non-staircase matrices.

    Solves H_parity * p = H_info * info (mod 2) where
    H_full = [H_info | H_parity]; works for any base matrix structure.

    Args:
        info: length-z info bit vector (0/1, int8).
        H_base: base matrix (used only for its dimensions here).
        H_full: expanded binary parity-check matrix.
        z: lifting factor; info length k == z.

    Returns:
        codeword = [info | parity] satisfying all checks.

    Raises:
        AssertionError: if len(info) != z, or the parity system is
            inconsistent so no zero-syndrome solution exists.
    """
    m_base_local = H_base.shape[0]
    n_base_local = H_base.shape[1]
    m_full = m_base_local * z
    n_full = n_base_local * z
    k = z

    assert len(info) == k

    # H_full = [H_info | H_parity]: info occupies the first z columns.
    H_info = H_full[:, :k]
    H_parity = H_full[:, k:]

    # RHS of the linear system, mod 2.
    rhs = (H_info @ info) % 2

    n_parity = n_full - k
    # Augmented matrix [H_parity | rhs]
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_parity
    aug[:, n_parity] = rhs

    # Forward elimination to reduced row-echelon form over GF(2).
    pivot_row = 0
    pivot_cols = []
    for col in range(n_parity):
        # Find a pivot at or below pivot_row.
        candidates = np.nonzero(aug[pivot_row:, col])[0]
        if candidates.size == 0:
            continue
        found = pivot_row + int(candidates[0])
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]
        # Clear the column from every other row in one vectorized XOR pass
        # (the original looped row-by-row in Python with (a+b) % 2).
        rows = np.nonzero(aug[:, col])[0]
        rows = rows[rows != pivot_row]
        aug[rows] ^= aug[pivot_row]
        pivot_cols.append(col)
        pivot_row += 1

    # Free (non-pivot) parity variables stay 0; pivots read off the RHS.
    parity = np.zeros(n_parity, dtype=np.int8)
    for i, col in enumerate(pivot_cols):
        parity[col] = aug[i, n_parity]

    codeword = np.concatenate([info, parity])

    # Verify the solution actually zeroes the syndrome.
    check = H_full @ codeword % 2
    assert np.all(check == 0), f"PEG encoding failed: syndrome weight = {check.sum()}"

    return codeword
+ """ + m_base_local = H_base.shape[0] + n_base_local = H_base.shape[1] + + # Build adjacency: VN -> list of CN, CN -> list of VN + # In QC graph, VN (c, zz) connects to CN (r, (zz + shift) % z) for each row r + # where H_base[r, c] >= 0 + n_vn = n_base_local * z + n_cn = m_base_local * z + + def vn_neighbors(vn_idx): + """CNs connected to this VN.""" + c = vn_idx // z + zz = vn_idx % z + neighbors = [] + for r in range(m_base_local): + shift = int(H_base[r, c]) + if shift < 0: + continue + cn_idx = r * z + (zz + shift) % z + neighbors.append(cn_idx) + return neighbors + + def cn_neighbors(cn_idx): + """VNs connected to this CN.""" + r = cn_idx // z + zz_cn = cn_idx % z + neighbors = [] + for c in range(n_base_local): + shift = int(H_base[r, c]) + if shift < 0: + continue + # CN (r, zz_cn) connects to VN (c, (zz_cn - shift) % z) + vn_zz = (zz_cn - shift) % z + vn_idx = c * z + vn_zz + neighbors.append(vn_idx) + return neighbors + + min_cycle = float('inf') + + # Sample VNs: pick a few from each base column + sample_vns = [] + for c in range(n_base_local): + for zz in [0, z // 4, z // 2]: + sample_vns.append(c * z + zz) + + for start_vn in sample_vns: + # BFS from start_vn through bipartite graph + # Track (node_type, node_idx, parent_type, parent_idx) + # node_type: 'vn' or 'cn' + dist = {} + queue = deque() + dist[('vn', start_vn)] = 0 + queue.append(('vn', start_vn, None, None)) + + while queue: + ntype, nidx, ptype, pidx = queue.popleft() + d = dist[(ntype, nidx)] + + if ntype == 'vn': + for cn in vn_neighbors(nidx): + if ('cn', cn) == (ptype, pidx): + continue # don't go back the way we came + if ('cn', cn) in dist: + cycle_len = d + 1 + dist[('cn', cn)] + min_cycle = min(min_cycle, cycle_len) + else: + dist[('cn', cn)] = d + 1 + queue.append(('cn', cn, 'vn', nidx)) + else: # cn + for vn in cn_neighbors(nidx): + if ('vn', vn) == (ptype, pidx): + continue + if ('vn', vn) in dist: + cycle_len = d + 1 + dist[('vn', vn)] + min_cycle = min(min_cycle, cycle_len) + 
def run_matrix_comparison(lam_b=0.1, n_frames=500, max_iter=30):
    """
    Compare three base matrix designs for rate-1/8 code.

    1. Original staircase (from ldpc_sim.py)
    2. Improved staircase (extra connections)
    3. PEG ring structure

    Args:
        lam_b: background photon rate per slot.
        n_frames: Monte Carlo frames per (matrix, lam_s) point.
        max_iter: decoder iteration cap.

    Returns:
        list of dicts with matrix, lam_s, fer, avg_iter.
    """
    print("=" * 80)
    print("ANALYSIS 2: Base Matrix Comparison (rate 1/8)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)

    # Build matrices. Each entry records which encoder applies: the original
    # staircase is lower-triangular (fast sequential encode); the other two
    # are not, so they go through the Gaussian-elimination encoder.
    matrices = {}

    # Original staircase
    H_orig_base = H_BASE.copy()
    H_orig_full = build_full_h_matrix()
    matrices['Original staircase'] = {
        'H_base': H_orig_base, 'H_full': H_orig_full,
        'encoder': 'staircase',
    }

    # Improved staircase
    H_imp_base, H_imp_full = build_improved_staircase(z=32)
    matrices['Improved staircase'] = {
        'H_base': H_imp_base, 'H_full': H_imp_full,
        'encoder': 'gaussian',
    }

    # PEG ring
    H_peg_base, H_peg_full = build_peg_matrix(z=32)
    matrices['PEG ring'] = {
        'H_base': H_peg_base, 'H_full': H_peg_full,
        'encoder': 'gaussian',
    }

    # Print degree distributions and girth
    print("\nDegree distributions and girth:")
    print(f" {'Matrix':<22s} {'VN degrees':<30s} {'Girth':>6s}")
    print(" " + "-" * 60)
    for name, minfo in matrices.items():
        H_b = minfo['H_base']
        # VN degrees: count non-(-1) entries per column
        vn_degrees = []
        for c in range(H_b.shape[1]):
            d = np.sum(H_b[:, c] >= 0)
            vn_degrees.append(d)
        girth = compute_girth(H_b, 32)
        girth_str = str(girth) if girth < float('inf') else 'inf'
        deg_str = str(vn_degrees)
        print(f" {name:<22s} {deg_str:<30s} {girth_str:>6s}")

    # Run simulations
    lam_s_values = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 7.0, 10.0]
    all_results = []

    # Header
    header = f"\n{'lam_s':>8s}"
    for name in matrices:
        header += f" {name[:12]:>12s}"
    print(header)
    print("-" * len(header))

    for lam_s in lam_s_values:
        row = f"{lam_s:8.1f}"
        for name, minfo in matrices.items():
            H_b = minfo['H_base']
            H_f = minfo['H_full']
            enc_type = minfo['encoder']

            bit_errors = 0
            frame_errors = 0
            total_iter = 0
            k = 32  # info bits per frame (z = 32)

            for _ in range(n_frames):
                info = np.random.randint(0, 2, k).astype(np.int8)

                # Encoder choice follows the matrix structure (see above).
                if enc_type == 'staircase':
                    codeword = ldpc_encode(info, H_f)
                else:
                    codeword = peg_encode(info, H_b, H_f, z=32)

                llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
                llr_q = quantize_llr(llr_float)

                decoded, converged, iters, sw = generic_decode(
                    llr_q, H_b, z=32, max_iter=max_iter, q_bits=Q_BITS
                )
                total_iter += iters

                errs = np.sum(decoded != info)
                bit_errors += errs
                if errs > 0:
                    frame_errors += 1

            fer = frame_errors / n_frames
            avg_iter = total_iter / n_frames
            all_results.append({
                'matrix': name, 'lam_s': lam_s,
                'fer': fer, 'avg_iter': avg_iter,
            })

            if fer == 0:
                row += f" {'0.000':>12s}"
            else:
                row += f" {fer:12.3f}"
        print(row)

    return all_results
+ """ + if lam_s is None: + lam_s_values = [2.0, 3.0, 5.0] + else: + lam_s_values = [lam_s] + + # Each entry: (label, actual_q_bits) + # "float" uses 16-bit quantization (effectively lossless for this range) + quant_configs = [ + ('4', 4), ('5', 5), ('6', 6), ('8', 8), ('10', 10), ('16', 16), + ('float', 16), + ] + + print("=" * 80) + print("ANALYSIS 3: Quantization Sweep (rate 1/8, original staircase)") + print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}") + print("=" * 80) + + H_full = build_full_h_matrix() + all_results = [] + + # Header + header = f"{'lam_s':>8s}" + for label, _ in quant_configs: + header += f" {label + '-bit':>8s}" + print(header) + print("-" * len(header)) + + for lam_s_val in lam_s_values: + row = f"{lam_s_val:8.1f}" + + for label, actual_q in quant_configs: + bit_errors = 0 + frame_errors = 0 + total_iter = 0 + + for _ in range(n_frames): + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H_full) + llr_float, _ = poisson_channel(codeword, lam_s_val, lam_b) + llr_q = quantize_llr(llr_float, q_bits=actual_q) + + decoded, converged, iters, sw = generic_decode( + llr_q, H_BASE, z=Z, max_iter=max_iter, q_bits=actual_q + ) + total_iter += iters + + errs = np.sum(decoded != info) + bit_errors += errs + if errs > 0: + frame_errors += 1 + + fer = frame_errors / n_frames + all_results.append({ + 'q_bits': label, 'lam_s': lam_s_val, + 'fer': float(fer), + }) + + if fer == 0: + row += f" {'0.000':>8s}" + else: + row += f" {fer:8.3f}" + + print(row) + + return all_results + + +# ============================================================================= +# Analysis 4: Shannon Gap +# ============================================================================= + +def poisson_channel_capacity(lam_s, lam_b, max_y=50): + """ + Binary-input Poisson channel capacity. + + Uses scipy.stats.poisson for PMFs. + Optimizes over input probability p (0 to 1 in 1% steps). 
def poisson_channel_capacity(lam_s, lam_b, max_y=None):
    """
    Binary-input Poisson channel capacity.

    Uses scipy.stats.poisson for PMFs.
    Optimizes over input probability p (0 to 1 in 1% steps):
    C = max_p [H(Y) - p0*H(Y|X=0) - p1*H(Y|X=1)]

    Args:
        lam_s: signal photon rate (X=1 adds lam_s to the mean).
        lam_b: background photon rate (present for both inputs).
        max_y: truncation point for the output alphabet. Default None
            chooses mean + 10*sigma of the brighter hypothesis (at least
            50), so the truncated tail mass is negligible. The previous
            fixed cutoff of 50 discarded roughly half the probability
            mass once lam_s + lam_b approached 50, corrupting the
            entropies; pass an explicit int to reproduce old behavior.

    Returns:
        Capacity in bits per channel use (float in [0, 1]).
    """
    from scipy.stats import poisson

    if max_y is None:
        mu = lam_s + lam_b
        max_y = max(50, int(np.ceil(mu + 10.0 * np.sqrt(mu))))

    def entropy(pmf):
        """Compute entropy of a discrete distribution (0*log0 := 0)."""
        pmf = pmf[pmf > 0]
        return -np.sum(pmf * np.log2(pmf))

    # PMFs for Y|X=0 and Y|X=1
    y_range = np.arange(0, max_y + 1)
    py_given_0 = poisson.pmf(y_range, lam_b)          # P(Y=y | X=0)
    py_given_1 = poisson.pmf(y_range, lam_s + lam_b)  # P(Y=y | X=1)

    # Conditional entropies are independent of p; compute once.
    h_y_given_0 = entropy(py_given_0)
    h_y_given_1 = entropy(py_given_1)

    best_capacity = 0.0

    # Search over input probability p1 = P(X=1) in 1% steps.
    for p_int in range(0, 101):
        p1 = p_int / 100.0
        p0 = 1.0 - p1

        # P(Y=y) = p0 * P(Y=y|X=0) + p1 * P(Y=y|X=1)
        py = p0 * py_given_0 + p1 * py_given_1

        # C = H(Y) - H(Y|X)
        capacity = entropy(py) - (p0 * h_y_given_0 + p1 * h_y_given_1)
        if capacity > best_capacity:
            best_capacity = capacity

    return best_capacity
+ """ + print("=" * 80) + print("ANALYSIS 4: Shannon Gap (Binary-Input Poisson Channel)") + print(f" lam_b={lam_b}") + print("=" * 80) + + rates = [ + (1, '1/2', 0.5), + (2, '1/3', 1.0 / 3), + (3, '1/4', 0.25), + (5, '1/6', 1.0 / 6), + (7, '1/8', 0.125), + ] + + results = [] + print(f"\n{'Rate':>8s} {'Shannon lam_s':>14s} {'Capacity':>10s}") + print("-" * 35) + + for m_base, label, rate_val in rates: + # Binary search for minimum lam_s where C >= rate_val + lo = 0.001 + hi = 50.0 + + # First check if capacity at hi is sufficient + c_hi = poisson_channel_capacity(hi, lam_b) + if c_hi < rate_val: + print(f"{label:>8s} {'> 50':>14s} {c_hi:10.4f}") + results.append({ + 'rate': label, 'rate_val': rate_val, + 'shannon_lam_s': None, 'capacity': c_hi, + }) + continue + + for _ in range(50): # binary search iterations + mid = (lo + hi) / 2 + c_mid = poisson_channel_capacity(mid, lam_b) + if c_mid >= rate_val: + hi = mid + else: + lo = mid + if hi - lo < 0.001: + break + + shannon_lam_s = hi + c_final = poisson_channel_capacity(shannon_lam_s, lam_b) + print(f"{label:>8s} {shannon_lam_s:14.3f} {c_final:10.4f}") + results.append({ + 'rate': label, 'rate_val': rate_val, + 'shannon_lam_s': float(shannon_lam_s), + 'capacity': float(c_final), + }) + + print("\nInterpretation:") + print(" Shannon limit = minimum lam_s where C(lam_s) >= R.") + print(" Practical LDPC codes operate ~0.5-2 dB above Shannon limit.") + print(" Lower rates need fewer photons/slot (lower lam_s threshold).") + print(" Gap to Shannon = 10*log10(lam_s_practical / lam_s_shannon) dB.") + + return results + + +# ============================================================================= +# CLI +# ============================================================================= + +def main(): + parser = argparse.ArgumentParser( + description='LDPC Code Analysis Tool', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python3 ldpc_analysis.py --rate-sweep + python3 ldpc_analysis.py 
def main():
    """CLI entry point: parse flags, run the selected analyses, save JSON."""
    parser = argparse.ArgumentParser(
        description='LDPC Code Analysis Tool',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 ldpc_analysis.py --rate-sweep
  python3 ldpc_analysis.py --matrix-compare
  python3 ldpc_analysis.py --quant-sweep
  python3 ldpc_analysis.py --shannon-gap
  python3 ldpc_analysis.py --all
  """
    )
    parser.add_argument('--rate-sweep', action='store_true',
                        help='Compare FER across code rates')
    parser.add_argument('--matrix-compare', action='store_true',
                        help='Compare base matrix designs')
    parser.add_argument('--quant-sweep', action='store_true',
                        help='Sweep quantization bit-width')
    parser.add_argument('--shannon-gap', action='store_true',
                        help='Compute Shannon limits')
    parser.add_argument('--all', action='store_true',
                        help='Run all analyses')
    parser.add_argument('--n-frames', type=int, default=500,
                        help='Frames per simulation point (default: 500)')
    parser.add_argument('--lam-b', type=float, default=0.1,
                        help='Background photon rate (default: 0.1)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    args = parser.parse_args()

    # One seed up front so the analysis sequence is reproducible end-to-end.
    np.random.seed(args.seed)

    all_results = {}
    ran_any = False

    if args.rate_sweep or args.all:
        ran_any = True
        results = run_rate_sweep(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['rate_sweep'] = results
        print()

    if args.matrix_compare or args.all:
        ran_any = True
        results = run_matrix_comparison(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['matrix_compare'] = results
        print()

    if args.quant_sweep or args.all:
        ran_any = True
        results = run_quant_sweep(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['quant_sweep'] = results
        print()

    if args.shannon_gap or args.all:
        ran_any = True
        results = run_shannon_gap(lam_b=args.lam_b)
        all_results['shannon_gap'] = results
        print()

    # No analysis flag given: show usage instead of silently doing nothing.
    if not ran_any:
        parser.print_help()
        return

    # Save results into <repo>/data, a sibling of this model/ directory.
    out_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'analysis_results.json')
    with open(out_path, 'w') as f:
        # default=str covers values json can't serialize (e.g. numpy scalars).
        json.dump(all_results, f, indent=2, default=str)
    print(f"Results saved to {out_path}")
class TestIRAStaircase:
    """Validate generic IRA staircase code construction for various rates."""

    def test_rate_1_2(self):
        """m_base=1 -> rate 1/2: (1,2) base matrix, (32,64) full matrix, rank=32."""
        base, full = build_ira_staircase(m_base=1, z=32)
        assert base.shape == (1, 2), f"Base shape: {base.shape}"
        assert full.shape == (32, 64), f"Full shape: {full.shape}"
        rank = np.linalg.matrix_rank(full.astype(float))
        assert rank == 32, f"Rank: {rank}, expected 32"

    def test_rate_1_4(self):
        """m_base=3 -> rate 1/4: (3,4) base matrix, (96,128) full matrix, rank=96."""
        base, full = build_ira_staircase(m_base=3, z=32)
        assert base.shape == (3, 4), f"Base shape: {base.shape}"
        assert full.shape == (96, 128), f"Full shape: {full.shape}"
        rank = np.linalg.matrix_rank(full.astype(float))
        assert rank == 96, f"Rank: {rank}, expected 96"

    def test_rate_1_8_matches_existing(self):
        """m_base=7 -> rate 1/8: (7,8) base matrix, (224,256) full matrix."""
        base, full = build_ira_staircase(m_base=7, z=32)
        assert base.shape == (7, 8), f"Base shape: {base.shape}"
        assert full.shape == (224, 256), f"Full shape: {full.shape}"

    def test_encode_decode_roundtrip(self):
        """For m_base in [1,2,3,5], encode random info, verify syndrome=0."""
        np.random.seed(42)
        z = 32
        for m_base in (1, 2, 3, 5):
            base, full = build_ira_staircase(m_base=m_base, z=z)
            bits = np.random.randint(0, 2, z).astype(np.int8)
            cw = ira_encode(bits, base, full, z=z)
            syndrome = full @ cw % 2
            assert not syndrome.any(), (
                f"m_base={m_base}: syndrome weight = {syndrome.sum()}"
            )
class TestRateSweep:
    """Validate rate sweep simulation."""

    def test_high_snr_all_rates_decode(self):
        """At lam_s=20, all rates should achieve FER=0 with 10 frames."""
        np.random.seed(42)
        for m_base in (1, 2, 3, 5, 7):
            outcome = rate_sweep_single(
                m_base=m_base, z=32, lam_s=20.0, lam_b=0.1,
                n_frames=10, max_iter=30, q_bits=6,
            )
            fer = outcome['fer']
            assert fer == 0.0, (
                f"m_base={m_base} (rate 1/{m_base+1}): FER={fer} at lam_s=20, "
                f"expected 0.0"
            )