#!/usr/bin/env python3
"""
LDPC Code Analysis Tool

Four analyses for the photon-starved optical LDPC decoder:
  1. Rate comparison        (--rate-sweep)
  2. Base matrix comparison (--matrix-compare)
  3. Quantization sweep     (--quant-sweep)
  4. Shannon gap analysis   (--shannon-gap)

Usage:
    python3 ldpc_analysis.py --rate-sweep
    python3 ldpc_analysis.py --matrix-compare
    python3 ldpc_analysis.py --quant-sweep
    python3 ldpc_analysis.py --shannon-gap
    python3 ldpc_analysis.py --all
"""

import numpy as np
import argparse
import json
import os
import sys
from collections import deque

from ldpc_sim import (
    poisson_channel, quantize_llr, decode_layered_min_sum,
    compute_syndrome_weight, sat_add_q, sat_sub_q, min_sum_cn_update,
    Q_BITS, Q_MAX, Q_MIN, OFFSET, N, K, M, Z, N_BASE, M_BASE, H_BASE,
    build_full_h_matrix, ldpc_encode,
)


# =============================================================================
# Core infrastructure
# =============================================================================

def _expand_base_matrix(H_b, z):
    """
    Expand a QC-LDPC base matrix into its full binary parity-check matrix.

    Every entry ``s >= 0`` of ``H_b`` becomes a z x z circulant block in which
    row ``zz`` has its single 1 in column ``(zz + s) % z``; every negative
    entry becomes an all-zero block.

    Args:
        H_b: base matrix (m_base x n_base), negative = no connection
        z: lifting factor

    Returns:
        int8 array of shape (m_base * z, n_base * z).
    """
    m_base, n_base = H_b.shape
    H_full = np.zeros((m_base * z, n_base * z), dtype=np.int8)
    lane = np.arange(z)
    for r in range(m_base):
        for c in range(n_base):
            shift = int(H_b[r, c])
            if shift < 0:
                continue
            H_full[r * z + lane, c * z + (lane + shift) % z] = 1
    return H_full


def build_ira_staircase(m_base, z=32):
    """
    Build generic IRA staircase QC-LDPC code for any rate.

    Rate = 1/(m_base+1), n_base = m_base+1.

    Col 0 (info) connects to ALL m_base rows with spread shifts.
    Lower-triangular staircase parity: row r connects to col r+1 (diagonal),
    row r>0 also connects to col r (sub-diagonal).

    Args:
        m_base: number of base-matrix rows
        z: lifting factor

    Returns:
        (H_base, H_full) where H_full is the expanded binary matrix.
    """
    n_base = m_base + 1
    H_b = -np.ones((m_base, n_base), dtype=np.int16)

    for r in range(m_base):
        # Column 0 (info) connects to every row, shifts spread over [0, z)
        H_b[r, 0] = (r * z // m_base) % z
        # Diagonal: row r connects to col r+1. Row 0 gets a non-zero shift,
        # all later rows use the identity circulant.
        H_b[r, r + 1] = (z // 4) % z if r == 0 else 0
        # Sub-diagonal: row r>0 connects to col r with a spread shift
        if r > 0:
            H_b[r, r] = (r * 7) % z

    return H_b, _expand_base_matrix(H_b, z)


def ira_encode(info, H_base, H_full, z=32):
    """
    Generic IRA encoder using sequential staircase solve.

    Row r is solved for parity block p[r+1]: all other connected blocks of
    that row are accumulated mod 2 and the result is passed through the
    inverse of the circulant at (r, r+1). Blocks that have not been solved
    yet contribute zeros, so the procedure is only valid for matrices whose
    row r references no parity column beyond r+1; the final syndrome check
    catches any other structure.

    Args:
        info: length-z array of information bits
        H_base: base matrix (m_base x n_base), negative = no connection
        H_full: expanded binary parity-check matrix matching H_base
        z: lifting factor

    Returns:
        Codeword array: info followed by parity blocks p[1]..p[n_base-1].

    Raises:
        AssertionError: if len(info) != z or the codeword syndrome is non-zero.
    """
    m_base = H_base.shape[0]
    n_base = H_base.shape[1]
    k = z

    assert len(info) == k

    def apply_p(x, s):
        """Apply circulant P_s: result[i] = x[(i+s)%Z]."""
        return np.roll(x, -int(s))

    def inv_p(y, s):
        """Apply P_s inverse."""
        return np.roll(y, int(s))

    # Parity blocks: p[c] for c = 1..n_base-1 (index 0 is unused)
    p = [np.zeros(z, dtype=np.int8) for _ in range(n_base)]

    for r in range(m_base):
        target_col = r + 1  # the new unknown column for this row
        accum = np.zeros(z, dtype=np.int8)
        for c in range(n_base):
            shift = int(H_base[r, c])
            if shift < 0 or c == target_col:
                continue
            block = info if c == 0 else p[c]
            accum = (accum + apply_p(block, shift)) % 2
        # p[target] = P_{H[r][target]}^-1 * accum
        p[target_col] = inv_p(accum, H_base[r, target_col])

    # Assemble codeword
    codeword = np.concatenate([info] + [p[c] for c in range(1, n_base)])

    # Verify
    check = H_full @ codeword % 2
    assert np.all(check == 0), f"IRA encoding failed: syndrome weight = {check.sum()}"

    return codeword


def generic_decode(llr_q, H_base, z=32, max_iter=30, early_term=True, q_bits=6,
                   cn_mode='offset', alpha=0.75):
    """
    Parameterized layered min-sum decoder for any QC-LDPC base matrix.

    Handles variable check node degree (different rows may have different
    numbers of connected columns).

    Args:
        llr_q: quantized channel LLRs
        H_base: base matrix (m_base x n_base), -1 = no connection
        z: lifting factor
        max_iter: maximum iterations
        early_term: stop when syndrome is zero
        q_bits: quantization bits
        cn_mode: 'offset' or 'normalized'
        alpha: scaling factor for normalized mode (default 0.75)

    Returns:
        (decoded_info_bits, converged, iterations, syndrome_weight)
    """
    m_base_local = H_base.shape[0]
    n_base_local = H_base.shape[1]
    k_local = z

    q_max = 2**(q_bits-1) - 1
    q_min = -(2**(q_bits-1))

    def sat_add(a, b):
        s = int(a) + int(b)
        return max(q_min, min(q_max, s))

    def sat_sub(a, b):
        return sat_add(a, -b)

    def cn_update(msgs_in, offset=1):
        dc = len(msgs_in)
        signs = [1 if m < 0 else 0 for m in msgs_in]
        mags = [abs(m) for m in msgs_in]
        sign_xor = sum(signs) % 2

        # Track the two smallest magnitudes and where the smallest sits
        min1 = q_max
        min2 = q_max
        min1_idx = 0
        for i in range(dc):
            if mags[i] < min1:
                min2 = min1
                min1 = mags[i]
                min1_idx = i
            elif mags[i] < min2:
                min2 = mags[i]

        msgs_out = []
        for j in range(dc):
            # Each output excludes its own input: the edge holding the
            # minimum receives the second minimum instead.
            mag = min2 if j == min1_idx else min1
            if cn_mode == 'normalized':
                mag = int(mag * alpha)
            else:
                mag = max(0, mag - offset)
            sgn = sign_xor ^ signs[j]
            msgs_out.append(-mag if sgn else mag)
        return msgs_out

    # Initialize beliefs from channel LLRs
    beliefs = [int(x) for x in llr_q]

    # Initialize CN->VN messages to zero
    msg = [[[0 for _ in range(z)] for _ in range(n_base_local)]
           for _ in range(m_base_local)]

    # The connectivity of each layer does not change between iterations, so
    # the (column, shift) pairs are collected once up front.
    layers = [
        [(c, int(H_base[row, c])) for c in range(n_base_local)
         if H_base[row, c] >= 0]
        for row in range(m_base_local)
    ]

    for iteration in range(max_iter):
        for row, edges in enumerate(layers):
            dc = len(edges)

            # VN->CN messages: belief minus this layer's previous contribution
            vn_to_cn = [[0] * z for _ in range(dc)]
            for ci, (col, shift) in enumerate(edges):
                for zz in range(z):
                    bit_idx = col * z + (zz + shift) % z
                    vn_to_cn[ci][zz] = sat_sub(beliefs[bit_idx], msg[row][col][zz])

            # CN update
            cn_to_vn = [[0] * z for _ in range(dc)]
            for zz in range(z):
                cn_inputs = [vn_to_cn[ci][zz] for ci in range(dc)]
                cn_outputs = cn_update(cn_inputs)
                for ci in range(dc):
                    cn_to_vn[ci][zz] = cn_outputs[ci]

            # Update beliefs and store new messages
            for ci, (col, shift) in enumerate(edges):
                for zz in range(z):
                    bit_idx = col * z + (zz + shift) % z
                    new_msg = cn_to_vn[ci][zz]
                    beliefs[bit_idx] = sat_add(vn_to_cn[ci][zz], new_msg)
                    msg[row][col][zz] = new_msg

        # Syndrome check
        hard = [1 if b < 0 else 0 for b in beliefs]
        sw = _compute_syndrome_weight(hard, H_base, z)
        if early_term and sw == 0:
            return np.array(hard[:k_local]), True, iteration + 1, 0

    # Recomputed here so the result is also defined when max_iter == 0
    hard = [1 if b < 0 else 0 for b in beliefs]
    sw = _compute_syndrome_weight(hard, H_base, z)
    return np.array(hard[:k_local]), sw == 0, max_iter, sw


def _compute_syndrome_weight(hard_bits, H_base, z):
    """
    Compute syndrome weight for a generic base matrix.

    Args:
        hard_bits: sequence of 0/1 hard decisions, one per code bit
        H_base: base matrix (m_base x n_base), negative = no connection
        z: lifting factor

    Returns:
        Number of unsatisfied parity checks.
    """
    m_base_local = H_base.shape[0]
    n_base_local = H_base.shape[1]
    weight = 0
    for r in range(m_base_local):
        row_edges = [(c, int(H_base[r, c])) for c in range(n_base_local)
                     if int(H_base[r, c]) >= 0]
        for zz in range(z):
            parity = 0
            for c, shift in row_edges:
                parity ^= hard_bits[c * z + (zz + shift) % z]
            if parity:
                weight += 1
    return weight


# =============================================================================
# Analysis 1: Rate Comparison
# =============================================================================

def rate_sweep_single(m_base, z, lam_s, lam_b, n_frames, max_iter, q_bits):
    """
    Run n_frames through encode/channel/decode for one rate and one lambda_s.

    Args:
        m_base: base-matrix row count (rate is 1/(m_base+1))
        z: lifting factor
        lam_s: signal value passed to poisson_channel
        lam_b: background value passed to poisson_channel
        n_frames: number of random frames to simulate
        max_iter: decoder iteration limit
        q_bits: LLR quantization width

    Returns:
        dict with rate, rate_val, m_base, lam_s, ber, fer, avg_iter.
    """
    H_base, H_full = build_ira_staircase(m_base=m_base, z=z)
    k = z
    rate_val = 1.0 / (m_base + 1)

    bit_errors = 0
    frame_errors = 0
    total_bits = 0
    total_iter = 0

    for _ in range(n_frames):
        info = np.random.randint(0, 2, k).astype(np.int8)
        codeword = ira_encode(info, H_base, H_full, z=z)
        llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
        llr_q = quantize_llr(llr_float, q_bits=q_bits)

        decoded, _converged, iters, _sw = generic_decode(
            llr_q, H_base, z=z, max_iter=max_iter, q_bits=q_bits
        )
        total_iter += iters

        errs = np.sum(decoded != info)
        bit_errors += errs
        total_bits += k
        if errs > 0:
            frame_errors += 1

    ber = bit_errors / total_bits if total_bits > 0 else 0
    fer = frame_errors / n_frames
    avg_iter = total_iter / n_frames

    return {
        'rate': f'1/{m_base+1}',
        'rate_val': rate_val,
        'm_base': m_base,
        'lam_s': lam_s,
        'ber': float(ber),
        'fer': float(fer),
        'avg_iter': float(avg_iter),
    }


def run_rate_sweep(lam_b=0.1, n_frames=500, max_iter=30):
    """
    Compare FER across multiple code rates.

    Rates: 1/2, 1/3, 1/4, 1/6, 1/8.

    Args:
        lam_b: background value passed to the channel model
        n_frames: frames simulated per (rate, lam_s) point
        max_iter: decoder iteration limit

    Returns:
        List of result dicts, one per (lam_s, rate) point, as produced by
        rate_sweep_single.
    """
    rates = [(1, '1/2'), (2, '1/3'), (3, '1/4'), (5, '1/6'), (7, '1/8')]
    lam_s_values = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 7.0, 10.0]

    print("=" * 80)
    print("ANALYSIS 1: Rate Comparison (IRA staircase codes)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)

    all_results = []

    # Header
    header = f"{'lam_s':>8s}"
    for _m, label in rates:
        header += f" {label:>8s}"
    print(header)
    print("-" * len(header))

    for lam_s in lam_s_values:
        row = f"{lam_s:8.1f}"
        for m_base, label in rates:
            result = rate_sweep_single(
                m_base=m_base, z=32, lam_s=lam_s, lam_b=lam_b,
                n_frames=n_frames, max_iter=max_iter, q_bits=Q_BITS,
            )
            all_results.append(result)
            fer = result['fer']
            if fer == 0:
                row += f" {'0.000':>8s}"
            else:
                row += f" {fer:8.3f}"
        print(row)

    # Threshold summary
    fer_at = {(r['m_base'], r['lam_s']): r['fer'] for r in all_results}
    print("\nThreshold summary (lowest lam_s where FER < 10%):")
    for m_base, label in rates:
        threshold = None
        for lam_s in lam_s_values:
            fer = fer_at.get((m_base, lam_s))
            if fer is not None and fer < 0.10:
                threshold = lam_s
                break
        if threshold is not None:
            print(f" Rate {label}: lam_s >= {threshold}")
        else:
            print(f" Rate {label}: FER >= 10% at all tested lam_s")

    return all_results


# =============================================================================
# Analysis 2: Base Matrix Comparison
# =============================================================================

def build_improved_staircase(z=32):
    """
    Improved staircase: same base structure as the imported H_BASE but with
    two extra connections on low-degree columns.

    Entry (0, 7) = 3 adds a connection to col 7 and entry (4, 1) = 15 adds a
    connection to col 1. Entry (0, 7) couples row 0 to the last parity
    column, so the sequential staircase solve no longer applies; the matrix
    comparison pairs this matrix with the Gaussian-elimination encoder.

    Args:
        z: lifting factor

    Returns:
        (H_base, H_full) where H_full is the expanded binary matrix.
    """
    # Start from original H_BASE
    H_b = H_BASE.copy()

    # Extra connection for col 7
    H_b[0, 7] = 3

    # Extra connection for col 1
    H_b[4, 1] = 15

    return H_b, _expand_base_matrix(H_b, z)


def build_peg_matrix(z=32):
    """
    PEG-like construction with more uniform degree distribution.

    Uses a staircase parity backbone plus cross-connections intended to
    improve girth and degree distribution.

    VN degrees: col0=7, cols1-3=3, cols4-7=2 (average 24/8 = 3.0).
    All parity columns have dv >= 2 (no degree-1 nodes).

    Args:
        z: lifting factor

    Returns:
        (H_base, H_full) where H_full is the expanded binary matrix.
    """
    m_base = 7
    n_base = 8

    H_b = -np.ones((m_base, n_base), dtype=np.int16)

    entries = [
        # Column 0 (info): connect to all 7 rows with spread shifts
        (0, 0, 0), (1, 0, 7), (2, 0, 13), (3, 0, 19),
        (4, 0, 25), (5, 0, 3), (6, 0, 9),
        # Staircase parity backbone
        (0, 1, 5),
        (1, 1, 3), (1, 2, 0),
        (2, 2, 11), (2, 3, 0),
        (3, 3, 17), (3, 4, 0),
        (4, 4, 23), (4, 5, 0),
        (5, 5, 29), (5, 6, 0),
        (6, 6, 5), (6, 7, 0),
        # Cross-connections for better degree distribution
        (3, 1, 21),  # col 1 dv=2->3
        (5, 2, 15),  # col 2 dv=2->3
        (6, 3, 27),  # col 3 dv=2->3
        (0, 7, 0),   # col 7 dv=1->2
    ]
    for r, c, shift in entries:
        H_b[r, c] = shift

    return H_b, _expand_base_matrix(H_b, z)


def peg_encode(info, H_base, H_full, z=32):
    """
    GF(2) Gaussian elimination encoder for non-staircase matrices.

    Solves H_parity * p = H_info * info mod 2, where H_info is the first z
    columns of H_full and H_parity the remainder. Free variables (columns
    without a pivot) are set to 0.

    Args:
        info: length-z array of information bits
        H_base: base matrix; only its shape is used here
        H_full: expanded binary parity-check matrix
        z: lifting factor

    Returns:
        Codeword array: info followed by the solved parity bits.

    Raises:
        AssertionError: if len(info) != z or the resulting codeword does not
            satisfy every parity check (e.g. an inconsistent system).
    """
    m_full = H_base.shape[0] * z
    n_full = H_base.shape[1] * z
    k = z

    assert len(info) == k

    # H_full = [H_info | H_parity]
    H_info = H_full[:, :k]
    H_parity = H_full[:, k:]

    # RHS: H_info * info mod 2
    rhs = (H_info @ info) % 2

    # Augmented matrix [H_parity | rhs]
    n_parity = n_full - k
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_parity.copy()
    aug[:, n_parity] = rhs

    # Gauss-Jordan elimination over GF(2); entries stay in {0, 1}, so XOR is
    # addition mod 2.
    pivot_row = 0
    pivot_cols = []

    for col in range(n_parity):
        # Find pivot: first row at or below pivot_row with a 1 in this column
        candidates = np.flatnonzero(aug[pivot_row:, col] == 1)
        if candidates.size == 0:
            continue
        found = pivot_row + int(candidates[0])

        # Swap rows
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]

        # Eliminate this column from every other row
        hit = np.flatnonzero(aug[:, col] == 1)
        hit = hit[hit != pivot_row]
        aug[hit] ^= aug[pivot_row]

        pivot_cols.append(col)
        pivot_row += 1

    # Back-substitute: after full reduction each pivot row gives its variable
    parity = np.zeros(n_parity, dtype=np.int8)
    for i, col in enumerate(pivot_cols):
        parity[col] = aug[i, n_parity]

    # Assemble codeword
    codeword = np.concatenate([info, parity])

    # Verify
    check = H_full @ codeword % 2
    assert np.all(check == 0), f"PEG encoding failed: syndrome weight = {check.sum()}"

    return codeword


def compute_girth(H_base, z):
    """
    BFS-based shortest cycle detection in Tanner graph.

    Runs a BFS from three VNs of every base column and records
    ``d(u) + 1 + d(v)`` for every non-tree edge (u, v) encountered; the
    minimum over all starts is returned.

    The neighbour functions use the shift with the opposite sign to the
    H_full expansion used elsewhere in this file. Negating every shift maps
    the graph onto an isomorphic one (relabel lane j as -j), so cycle lengths
    are unaffected.

    Args:
        H_base: base matrix (m_base x n_base), negative = no connection
        z: lifting factor

    Returns:
        The girth (length of shortest cycle) or infinity if no cycle found.
    """
    m_base_local = H_base.shape[0]
    n_base_local = H_base.shape[1]

    def vn_neighbors(vn_idx):
        """CNs connected to this VN."""
        c, zz = divmod(vn_idx, z)
        neighbors = []
        for r in range(m_base_local):
            shift = int(H_base[r, c])
            if shift < 0:
                continue
            neighbors.append(r * z + (zz + shift) % z)
        return neighbors

    def cn_neighbors(cn_idx):
        """VNs connected to this CN."""
        r, zz_cn = divmod(cn_idx, z)
        neighbors = []
        for c in range(n_base_local):
            shift = int(H_base[r, c])
            if shift < 0:
                continue
            # CN (r, zz_cn) connects to VN (c, (zz_cn - shift) % z)
            neighbors.append(c * z + (zz_cn - shift) % z)
        return neighbors

    min_cycle = float('inf')

    # Sample VNs: pick a few from each base column
    sample_vns = [c * z + zz
                  for c in range(n_base_local)
                  for zz in (0, z // 4, z // 2)]

    for start_vn in sample_vns:
        # BFS from start_vn through the bipartite graph. Queue items are
        # (node_type, node_idx, parent_type, parent_idx).
        dist = {('vn', start_vn): 0}
        queue = deque([('vn', start_vn, None, None)])

        while queue:
            ntype, nidx, ptype, pidx = queue.popleft()
            d = dist[(ntype, nidx)]

            if ntype == 'vn':
                next_type, next_nodes = 'cn', vn_neighbors(nidx)
            else:
                next_type, next_nodes = 'vn', cn_neighbors(nidx)

            for nxt in next_nodes:
                key = (next_type, nxt)
                if key == (ptype, pidx):
                    continue  # don't go back the way we came
                if key in dist:
                    min_cycle = min(min_cycle, d + 1 + dist[key])
                else:
                    dist[key] = d + 1
                    queue.append((next_type, nxt, ntype, nidx))

        # Early termination only at 4, the shortest cycle a simple bipartite
        # graph can have. Stopping at 6 could hide a 4-cycle that runs
        # through a column that has not been sampled yet.
        if min_cycle <= 4:
            return min_cycle

    return min_cycle


def run_matrix_comparison(lam_b=0.1, n_frames=500, max_iter=30):
    """
    Compare three base matrix designs for rate-1/8 code.

    1. Original staircase (from ldpc_sim.py)
    2. Improved staircase (extra connections)
    3. PEG ring structure

    Prints the VN degree distribution and girth of each matrix, then a FER
    table over a fixed list of lam_s values.

    Args:
        lam_b: background value passed to the channel model
        n_frames: frames simulated per (matrix, lam_s) point
        max_iter: decoder iteration limit

    Returns:
        List of dicts with keys matrix, lam_s, fer, avg_iter.
    """
    print("=" * 80)
    print("ANALYSIS 2: Base Matrix Comparison (rate 1/8)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)

    # Build matrices
    matrices = {}

    # Original staircase
    H_orig_base = H_BASE.copy()
    H_orig_full = build_full_h_matrix()
    matrices['Original staircase'] = {
        'H_base': H_orig_base,
        'H_full': H_orig_full,
        'encoder': 'staircase',
    }

    # Improved staircase
    H_imp_base, H_imp_full = build_improved_staircase(z=32)
    matrices['Improved staircase'] = {
        'H_base': H_imp_base,
        'H_full': H_imp_full,
        'encoder': 'gaussian',
    }

    # PEG ring
    H_peg_base, H_peg_full = build_peg_matrix(z=32)
    matrices['PEG ring'] = {
        'H_base': H_peg_base,
        'H_full': H_peg_full,
        'encoder': 'gaussian',
    }

    # Print degree distributions and girth
    print("\nDegree distributions and girth:")
    print(f" {'Matrix':<22s} {'VN degrees':<30s} {'Girth':>6s}")
    print(" " + "-" * 60)

    for name, minfo in matrices.items():
        H_b = minfo['H_base']
        # VN degrees: count non-(-1) entries per column
        vn_degrees = [np.sum(H_b[:, c] >= 0) for c in range(H_b.shape[1])]
        girth = compute_girth(H_b, 32)
        girth_str = str(girth) if girth < float('inf') else 'inf'
        deg_str = str([int(d) for d in vn_degrees])
        print(f" {name:<22s} {deg_str:<30s} {girth_str:>6s}")

    # Run simulations
    lam_s_values = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 7.0, 10.0]
    all_results = []

    # Header
    header = f"\n{'lam_s':>8s}"
    for name in matrices:
        header += f" {name[:12]:>12s}"
    print(header)
    print("-" * len(header))

    for lam_s in lam_s_values:
        row = f"{lam_s:8.1f}"
        for name, minfo in matrices.items():
            H_b = minfo['H_base']
            H_f = minfo['H_full']
            enc_type = minfo['encoder']

            bit_errors = 0
            frame_errors = 0
            total_iter = 0
            k = 32

            for _ in range(n_frames):
                info = np.random.randint(0, 2, k).astype(np.int8)
                if enc_type == 'staircase':
                    codeword = ldpc_encode(info, H_f)
                else:
                    codeword = peg_encode(info, H_b, H_f, z=32)

                llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
                llr_q = quantize_llr(llr_float)

                decoded, _converged, iters, _sw = generic_decode(
                    llr_q, H_b, z=32, max_iter=max_iter, q_bits=Q_BITS
                )
                total_iter += iters

                errs = np.sum(decoded != info)
                bit_errors += errs
                if errs > 0:
                    frame_errors += 1

            fer = frame_errors / n_frames
            avg_iter = total_iter / n_frames
            all_results.append({
                'matrix': name,
                'lam_s': lam_s,
                'fer': fer,
                'avg_iter': avg_iter,
            })

            if fer == 0:
                row += f" {'0.000':>12s}"
            else:
                row += f" {fer:12.3f}"
        print(row)

    return all_results


# =============================================================================
# Analysis 3: Quantization Sweep
# =============================================================================

def run_quant_sweep(lam_s=None, lam_b=0.1, n_frames=500, max_iter=30):
    """
    Sweep quantization bit-width and measure FER.

    Uses the original H_BASE from ldpc_sim for consistency.
    Includes 4,5,6,8,10,16 bits plus a "float" entry. The "float" entry is
    run with the same 16-bit setting as the '16' entry, so the two columns
    differ only through random sampling.

    Args:
        lam_s: single signal value to test, or None for [2.0, 3.0, 5.0]
        lam_b: background value passed to the channel model
        n_frames: frames simulated per (q_bits, lam_s) point
        max_iter: decoder iteration limit

    Returns:
        List of dicts with keys q_bits (label string), lam_s, fer.
    """
    if lam_s is None:
        lam_s_values = [2.0, 3.0, 5.0]
    else:
        lam_s_values = [lam_s]

    # Each entry: (label, actual_q_bits)
    quant_configs = [
        ('4', 4),
        ('5', 5),
        ('6', 6),
        ('8', 8),
        ('10', 10),
        ('16', 16),
        ('float', 16),
    ]

    print("=" * 80)
    print("ANALYSIS 3: Quantization Sweep (rate 1/8, original staircase)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)

    H_full = build_full_h_matrix()
    all_results = []

    # Header
    header = f"{'lam_s':>8s}"
    for label, _ in quant_configs:
        header += f" {label + '-bit':>8s}"
    print(header)
    print("-" * len(header))

    for lam_s_val in lam_s_values:
        row = f"{lam_s_val:8.1f}"
        for label, actual_q in quant_configs:
            bit_errors = 0
            frame_errors = 0
            total_iter = 0

            for _ in range(n_frames):
                info = np.random.randint(0, 2, K).astype(np.int8)
                codeword = ldpc_encode(info, H_full)
                llr_float, _ = poisson_channel(codeword, lam_s_val, lam_b)
                llr_q = quantize_llr(llr_float, q_bits=actual_q)

                decoded, _converged, iters, _sw = generic_decode(
                    llr_q, H_BASE, z=Z, max_iter=max_iter, q_bits=actual_q
                )
                total_iter += iters

                errs = np.sum(decoded != info)
                bit_errors += errs
                if errs > 0:
                    frame_errors += 1

            fer = frame_errors / n_frames
            all_results.append({
                'q_bits': label,
                'lam_s': lam_s_val,
                'fer': float(fer),
            })

            if fer == 0:
                row += f" {'0.000':>8s}"
            else:
                row += f" {fer:8.3f}"
        print(row)

    return all_results


# =============================================================================
# Analysis 4: Shannon Gap
# =============================================================================

def poisson_channel_capacity(lam_s, lam_b, max_y=50):
    """
    Binary-input Poisson channel capacity.

    Uses scipy.stats.poisson for PMFs. Optimizes over input probability p
    (0 to 1 in 1% steps).

    C = max_p [H(Y) - p0*H(Y|X=0) - p1*H(Y|X=1)]

    Args:
        lam_s: signal rate; Y|X=1 is Poisson(lam_s + lam_b)
        lam_b: background rate; Y|X=0 is Poisson(lam_b)
        max_y: minimum upper end of the count range over which the PMFs are
            evaluated. The range is extended automatically when the larger
            mean would otherwise leave noticeable probability mass outside
            it; a fixed cut at 50 discards about half of the distribution
            once lam_s approaches 50 and corrupts the entropies.

    Returns:
        Capacity in bits per channel use (0.0 if no input distribution gives
        a positive value).
    """
    from scipy.stats import poisson

    def entropy(pmf):
        """Compute entropy of a discrete distribution."""
        pmf = pmf[pmf > 0]
        return -np.sum(pmf * np.log2(pmf))

    # Cover the larger mean plus ten standard deviations (with a constant
    # margin for small means) so the truncated tail is negligible.
    peak_mean = max(lam_b, lam_s + lam_b, 0.0)
    tail_cut = int(np.ceil(peak_mean + 10.0 * np.sqrt(peak_mean) + 20.0))
    y_range = np.arange(0, max(max_y, tail_cut) + 1)

    # PMFs for Y|X=0 and Y|X=1
    py_given_0 = poisson.pmf(y_range, lam_b)          # P(Y=y | X=0)
    py_given_1 = poisson.pmf(y_range, lam_s + lam_b)  # P(Y=y | X=1)

    # Conditional entropies
    h_y_given_0 = entropy(py_given_0)
    h_y_given_1 = entropy(py_given_1)

    best_capacity = 0.0

    # Search over input probability p1 = P(X=1)
    for p_int in range(0, 101):
        p1 = p_int / 100.0
        p0 = 1.0 - p1

        # P(Y=y) = p0 * P(Y=y|X=0) + p1 * P(Y=y|X=1)
        py = p0 * py_given_0 + p1 * py_given_1

        # C = H(Y) - H(Y|X)
        capacity = entropy(py) - (p0 * h_y_given_0 + p1 * h_y_given_1)

        if capacity > best_capacity:
            best_capacity = capacity

    return best_capacity


def run_shannon_gap(lam_b=0.1):
    """
    For each rate (1/2 through 1/8): binary search for minimum lambda_s
    where channel capacity >= rate.

    The search interval is [0.001, 50.0]; a rate whose capacity at 50.0 is
    still too low is reported with shannon_lam_s = None.

    Args:
        lam_b: background rate used for the capacity computation

    Returns:
        List of dicts with keys rate, rate_val, shannon_lam_s, capacity.
    """
    print("=" * 80)
    print("ANALYSIS 4: Shannon Gap (Binary-Input Poisson Channel)")
    print(f" lam_b={lam_b}")
    print("=" * 80)

    rates = [
        (1, '1/2', 0.5),
        (2, '1/3', 1.0 / 3),
        (3, '1/4', 0.25),
        (5, '1/6', 1.0 / 6),
        (7, '1/8', 0.125),
    ]

    results = []
    print(f"\n{'Rate':>8s} {'Shannon lam_s':>14s} {'Capacity':>10s}")
    print("-" * 35)

    for _m_base, label, rate_val in rates:
        # Binary search for minimum lam_s where C >= rate_val
        lo = 0.001
        hi = 50.0

        # First check if capacity at hi is sufficient
        c_hi = poisson_channel_capacity(hi, lam_b)
        if c_hi < rate_val:
            print(f"{label:>8s} {'> 50':>14s} {c_hi:10.4f}")
            results.append({
                'rate': label,
                'rate_val': rate_val,
                'shannon_lam_s': None,
                'capacity': c_hi,
            })
            continue

        for _ in range(50):  # binary search iterations
            mid = (lo + hi) / 2
            c_mid = poisson_channel_capacity(mid, lam_b)
            if c_mid >= rate_val:
                hi = mid
            else:
                lo = mid
            if hi - lo < 0.001:
                break

        shannon_lam_s = hi
        c_final = poisson_channel_capacity(shannon_lam_s, lam_b)

        print(f"{label:>8s} {shannon_lam_s:14.3f} {c_final:10.4f}")
        results.append({
            'rate': label,
            'rate_val': rate_val,
            'shannon_lam_s': float(shannon_lam_s),
            'capacity': float(c_final),
        })

    print("\nInterpretation:")
    print(" Shannon limit = minimum lam_s where C(lam_s) >= R.")
    print(" Practical LDPC codes operate ~0.5-2 dB above Shannon limit.")
    print(" Lower rates need fewer photons/slot (lower lam_s threshold).")
    print(" Gap to Shannon = 10*log10(lam_s_practical / lam_s_shannon) dB.")

    return results


# =============================================================================
# CLI
# =============================================================================

def main():
    """
    Command-line entry point.

    Parses the flags, seeds NumPy's global RNG, runs every selected analysis
    and writes the collected results to ``../data/analysis_results.json``
    relative to this file. Prints the help text and returns without writing
    anything when no analysis flag is given.
    """
    parser = argparse.ArgumentParser(
        description='LDPC Code Analysis Tool',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 ldpc_analysis.py --rate-sweep
  python3 ldpc_analysis.py --matrix-compare
  python3 ldpc_analysis.py --quant-sweep
  python3 ldpc_analysis.py --shannon-gap
  python3 ldpc_analysis.py --all
"""
    )
    parser.add_argument('--rate-sweep', action='store_true',
                        help='Compare FER across code rates')
    parser.add_argument('--matrix-compare', action='store_true',
                        help='Compare base matrix designs')
    parser.add_argument('--quant-sweep', action='store_true',
                        help='Sweep quantization bit-width')
    parser.add_argument('--shannon-gap', action='store_true',
                        help='Compute Shannon limits')
    parser.add_argument('--all', action='store_true',
                        help='Run all analyses')
    parser.add_argument('--n-frames', type=int, default=500,
                        help='Frames per simulation point (default: 500)')
    parser.add_argument('--lam-b', type=float, default=0.1,
                        help='Background photon rate (default: 0.1)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')

    args = parser.parse_args()
    np.random.seed(args.seed)

    all_results = {}
    ran_any = False

    if args.rate_sweep or args.all:
        ran_any = True
        results = run_rate_sweep(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['rate_sweep'] = results
        print()

    if args.matrix_compare or args.all:
        ran_any = True
        results = run_matrix_comparison(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['matrix_compare'] = results
        print()

    if args.quant_sweep or args.all:
        ran_any = True
        results = run_quant_sweep(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['quant_sweep'] = results
        print()

    if args.shannon_gap or args.all:
        ran_any = True
        results = run_shannon_gap(lam_b=args.lam_b)
        all_results['shannon_gap'] = results
        print()

    if not ran_any:
        parser.print_help()
        return

    # Save results
    out_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'analysis_results.json')
    with open(out_path, 'w') as f:
        # default=str stringifies any value json cannot encode natively
        json.dump(all_results, f, indent=2, default=str)
    print(f"Results saved to {out_path}")


if __name__ == '__main__':
    main()