Files
ldpc_optical/model/ldpc_analysis.py
cah 1967ae90e4 Fix matrix rank issues and run all code analyses
- Fixed improved staircase: below-diagonal connections preserve full
  parity rank (col7->row0 s3, col1->row4 s15)
- Fixed PEG matrix: staircase backbone with cross-connections,
  all parity cols dv>=2, VN degrees [7,3,3,3,2,2,2,2]
- Clean up VN degree display (remove np.int64 wrapper)
- Ran all four analyses with 200 frames per point

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 04:56:37 -07:00

1040 lines
32 KiB
Python

#!/usr/bin/env python3
"""
LDPC Code Analysis Tool
Four analyses for the photon-starved optical LDPC decoder:
1. Rate comparison (--rate-sweep)
2. Base matrix comparison (--matrix-compare)
3. Quantization sweep (--quant-sweep)
4. Shannon gap analysis (--shannon-gap)
Usage:
python3 ldpc_analysis.py --rate-sweep
python3 ldpc_analysis.py --matrix-compare
python3 ldpc_analysis.py --quant-sweep
python3 ldpc_analysis.py --shannon-gap
python3 ldpc_analysis.py --all
"""
import numpy as np
import argparse
import json
import os
import sys
from collections import deque
from ldpc_sim import (
poisson_channel, quantize_llr, decode_layered_min_sum,
compute_syndrome_weight, sat_add_q, sat_sub_q, min_sum_cn_update,
Q_BITS, Q_MAX, Q_MIN, OFFSET,
N, K, M, Z, N_BASE, M_BASE, H_BASE,
build_full_h_matrix, ldpc_encode,
)
# =============================================================================
# Core infrastructure
# =============================================================================
def build_ira_staircase(m_base, z=32):
    """
    Construct a generic IRA staircase QC-LDPC code of rate 1/(m_base+1).

    Column 0 carries the information block and touches every check row
    with evenly spread circulant shifts.  The remaining columns form a
    lower-triangular staircase of parity blocks: row r owns the diagonal
    entry in column r+1, and every row after the first also connects to
    column r (the sub-diagonal), which enables sequential encoding.

    Returns:
        (H_base, H_full): the (m_base x (m_base+1)) shift matrix
        (-1 = no circulant) and its binary expansion of size
        (m_base*z) x ((m_base+1)*z).
    """
    n_base = m_base + 1
    base = np.full((m_base, n_base), -1, dtype=np.int16)

    for row in range(m_base):
        # Info column: shifts spread evenly over all rows.
        base[row, 0] = (row * z // m_base) % z
        # Staircase diagonal (row r -> col r+1): a spread shift on row 0,
        # identity circulants on all later rows.
        base[row, row + 1] = (z // 4) % z if row == 0 else 0
        # Sub-diagonal (row r -> col r) for every row but the first.
        if row > 0:
            base[row, row] = (row * 7) % z

    # Expand each non-negative shift into a z x z circulant permutation.
    expanded = np.zeros((m_base * z, n_base * z), dtype=np.int8)
    for row in range(m_base):
        for col in range(n_base):
            s = int(base[row, col])
            if s < 0:
                continue
            for off in range(z):
                expanded[row * z + off, col * z + (off + s) % z] = 1
    return base, expanded
def ira_encode(info, H_base, H_full, z=32):
    """
    Encode one information block with the sequential IRA staircase solver.

    Row 0 is solved for parity block p[1]; each subsequent row r yields
    p[r+1] from the blocks already known.  The assembled codeword is
    verified against H_full; an AssertionError is raised if the syndrome
    is non-zero.
    """
    m_base = H_base.shape[0]
    n_base = H_base.shape[1]
    assert len(info) == z

    def forward(vec, s):
        """Circulant P_s: result[i] = vec[(i + s) % z]."""
        return np.roll(vec, -int(s))

    def backward(vec, s):
        """Inverse of P_s."""
        return np.roll(vec, int(s))

    # Parity blocks p[1]..p[n_base-1]; p[0] is a never-used placeholder.
    parity = [np.zeros(z, dtype=np.int8) for _ in range(n_base)]

    # Solve each row for its single unknown parity block (col r+1).
    for row in range(m_base):
        unknown = row + 1
        acc = np.zeros(z, dtype=np.int8)
        for col in range(n_base):
            s = int(H_base[row, col])
            if s < 0 or col == unknown:
                continue
            block = info if col == 0 else parity[col]
            acc = (acc + forward(block, s)) % 2
        # p[unknown] = P_shift^-1 * acc
        parity[unknown] = backward(acc, H_base[row, unknown])

    codeword = np.concatenate([info] + parity[1:])
    check = H_full @ codeword % 2
    assert np.all(check == 0), f"IRA encoding failed: syndrome weight = {check.sum()}"
    return codeword
def generic_decode(llr_q, H_base, z=32, max_iter=30, early_term=True, q_bits=6):
    """
    Parameterized layered min-sum decoder for any QC-LDPC base matrix.

    Handles variable check-node degree (different rows may connect to
    different numbers of columns).  Uses offset min-sum (offset 1) with
    saturating fixed-point arithmetic at the requested bit width.

    Args:
        llr_q: quantized channel LLRs, one per code bit.
        H_base: base matrix (m_base x n_base); -1 marks no connection.
        z: lifting factor.
        max_iter: maximum number of layered iterations.
        early_term: stop as soon as the syndrome is all-zero.
        q_bits: quantization bit width.

    Returns:
        (decoded_info_bits, converged, iterations, syndrome_weight)
        where decoded_info_bits are the first z hard decisions.
    """
    m_base_local = H_base.shape[0]
    n_base_local = H_base.shape[1]
    k_local = z
    q_max = 2 ** (q_bits - 1) - 1
    q_min = -(2 ** (q_bits - 1))

    def sat_add(a, b):
        # Saturating fixed-point addition.
        s = int(a) + int(b)
        return max(q_min, min(q_max, s))

    def sat_sub(a, b):
        return sat_add(a, -b)

    def cn_update(msgs_in, offset=1):
        # Offset min-sum check-node update: each output magnitude is the
        # minimum over the OTHER inputs (min2 at the position holding
        # min1), reduced by the offset; signs follow overall parity.
        dc = len(msgs_in)
        signs = [1 if m < 0 else 0 for m in msgs_in]
        mags = [abs(m) for m in msgs_in]
        sign_xor = sum(signs) % 2
        min1 = q_max
        min2 = q_max
        min1_idx = 0
        for i in range(dc):
            if mags[i] < min1:
                min2 = min1
                min1 = mags[i]
                min1_idx = i
            elif mags[i] < min2:
                min2 = mags[i]
        msgs_out = []
        for j in range(dc):
            mag = min2 if j == min1_idx else min1
            mag = max(0, mag - offset)
            sgn = sign_xor ^ signs[j]
            msgs_out.append(-mag if sgn else mag)
        return msgs_out

    def syndrome_weight(hard_bits):
        # Number of unsatisfied parity checks for the hard decisions.
        weight = 0
        for r in range(m_base_local):
            for zz in range(z):
                parity = 0
                for c in range(n_base_local):
                    shift = int(H_base[r, c])
                    if shift < 0:
                        continue
                    parity ^= hard_bits[c * z + (zz + shift) % z]
                weight += parity
        return weight

    # Beliefs start from the channel LLRs; CN->VN messages start at zero.
    beliefs = [int(x) for x in llr_q]
    msg = [[[0 for _ in range(z)] for _ in range(n_base_local)]
           for _ in range(m_base_local)]

    for iteration in range(max_iter):
        for row in range(m_base_local):
            connected_cols = [c for c in range(n_base_local) if H_base[row, c] >= 0]
            dc = len(connected_cols)
            # VN->CN: remove the old extrinsic message from each belief.
            vn_to_cn = [[0] * z for _ in range(dc)]
            for ci, col in enumerate(connected_cols):
                shift = int(H_base[row, col])
                for zz in range(z):
                    bit_idx = col * z + (zz + shift) % z
                    vn_to_cn[ci][zz] = sat_sub(beliefs[bit_idx], msg[row][col][zz])
            # CN update, one z-slice of the circulant at a time.
            cn_to_vn = [[0] * z for _ in range(dc)]
            for zz in range(z):
                outs = cn_update([vn_to_cn[ci][zz] for ci in range(dc)])
                for ci in range(dc):
                    cn_to_vn[ci][zz] = outs[ci]
            # Fold the new messages back into the beliefs.
            for ci, col in enumerate(connected_cols):
                shift = int(H_base[row, col])
                for zz in range(z):
                    bit_idx = col * z + (zz + shift) % z
                    new_msg = cn_to_vn[ci][zz]
                    beliefs[bit_idx] = sat_add(vn_to_cn[ci][zz], new_msg)
                    msg[row][col][zz] = new_msg
        # Early termination on a zero syndrome.  The (previously
        # unconditional) per-iteration syndrome check is skipped entirely
        # when early termination is disabled.
        if early_term:
            hard = [1 if b < 0 else 0 for b in beliefs]
            if syndrome_weight(hard) == 0:
                return np.array(hard[:k_local]), True, iteration + 1, 0

    # Final hard decision; also covers max_iter == 0 and early_term=False.
    hard = [1 if b < 0 else 0 for b in beliefs]
    sw = syndrome_weight(hard)
    return np.array(hard[:k_local]), sw == 0, max_iter, sw
def _compute_syndrome_weight(hard_bits, H_base, z):
"""Compute syndrome weight for a generic base matrix."""
m_base_local = H_base.shape[0]
n_base_local = H_base.shape[1]
weight = 0
for r in range(m_base_local):
for zz in range(z):
parity = 0
for c in range(n_base_local):
shift = int(H_base[r, c])
if shift < 0:
continue
shifted_z = (zz + shift) % z
bit_idx = c * z + shifted_z
parity ^= hard_bits[bit_idx]
if parity:
weight += 1
return weight
# =============================================================================
# Analysis 1: Rate Comparison
# =============================================================================
def rate_sweep_single(m_base, z, lam_s, lam_b, n_frames, max_iter, q_bits):
    """
    Simulate n_frames of encode/channel/decode at one (rate, lam_s) point.

    Returns:
        dict with rate label, rate value, m_base, lam_s, BER, FER and
        average decoder iteration count.
    """
    H_base, H_full = build_ira_staircase(m_base=m_base, z=z)
    k = z
    bit_errors = 0
    frame_errors = 0
    total_bits = 0
    total_iter = 0
    for _ in range(n_frames):
        info = np.random.randint(0, 2, k).astype(np.int8)
        codeword = ira_encode(info, H_base, H_full, z=z)
        llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
        llr_q = quantize_llr(llr_float, q_bits=q_bits)
        decoded, converged, iters, sw = generic_decode(
            llr_q, H_base, z=z, max_iter=max_iter, q_bits=q_bits
        )
        total_iter += iters
        frame_bit_errs = np.sum(decoded != info)
        bit_errors += frame_bit_errs
        total_bits += k
        if frame_bit_errs > 0:
            frame_errors += 1
    return {
        'rate': f'1/{m_base+1}',
        'rate_val': 1.0 / (m_base + 1),
        'm_base': m_base,
        'lam_s': lam_s,
        'ber': float(bit_errors / total_bits) if total_bits > 0 else 0.0,
        'fer': float(frame_errors / n_frames),
        'avg_iter': float(total_iter / n_frames),
    }
def run_rate_sweep(lam_b=0.1, n_frames=500, max_iter=30):
    """
    Compare FER across multiple code rates.
    Rates: 1/2, 1/3, 1/4, 1/6, 1/8.

    Prints a FER table (one row per lam_s, one column per rate) followed
    by a per-rate threshold summary, and returns the list of per-point
    result dicts from rate_sweep_single.
    """
    # (m_base, label): rate is 1/(m_base+1), so m_base=1 -> 1/2, etc.
    rates = [(1, '1/2'), (2, '1/3'), (3, '1/4'), (5, '1/6'), (7, '1/8')]
    lam_s_values = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 4.0, 5.0, 7.0, 10.0]
    print("=" * 80)
    print("ANALYSIS 1: Rate Comparison (IRA staircase codes)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)
    all_results = []
    # Header
    header = f"{'lam_s':>8s}"
    for m, label in rates:
        header += f" {label:>8s}"
    print(header)
    print("-" * len(header))
    for lam_s in lam_s_values:
        row = f"{lam_s:8.1f}"
        for m_base, label in rates:
            result = rate_sweep_single(
                m_base=m_base, z=32, lam_s=lam_s, lam_b=lam_b,
                n_frames=n_frames, max_iter=max_iter, q_bits=Q_BITS,
            )
            all_results.append(result)
            fer = result['fer']
            # Show exact zero distinctly from a rounded-to-zero value.
            if fer == 0:
                row += f" {'0.000':>8s}"
            else:
                row += f" {fer:8.3f}"
        print(row)
    # Threshold summary: first (lowest) lam_s where FER drops below 10%.
    print("\nThreshold summary (lowest lam_s where FER < 10%):")
    for m_base, label in rates:
        threshold = None
        for lam_s in lam_s_values:
            matching = [r for r in all_results
                        if r['m_base'] == m_base and r['lam_s'] == lam_s]
            if matching and matching[0]['fer'] < 0.10:
                threshold = lam_s
                break
        if threshold is not None:
            print(f" Rate {label}: lam_s >= {threshold}")
        else:
            print(f" Rate {label}: FER >= 10% at all tested lam_s")
    return all_results
# =============================================================================
# Analysis 2: Base Matrix Comparison
# =============================================================================
def build_improved_staircase(z=32):
    """
    Improved staircase: the original H_BASE plus two extra connections
    that raise the degree of its weakest columns.

    Col 7 gains a second connection (dv 1 -> 2) and col 1 a third
    (dv 2 -> 3).  Both additions sit below the staircase diagonal so the
    parity part stays lower-triangular and keeps full rank.
    """
    m_base, n_base = 7, 8
    H_b = H_BASE.copy()
    # Col 7 previously touched only row 6; row 0 is below its diagonal.
    H_b[0, 7] = 3
    # Col 1 previously touched rows 0 and 1; row 4 is below its diagonal.
    H_b[4, 1] = 15
    # Expand every non-negative shift into a z x z circulant.
    H_full = np.zeros((m_base * z, n_base * z), dtype=np.int8)
    for r in range(m_base):
        for c in range(n_base):
            s = int(H_b[r, c])
            if s < 0:
                continue
            for zz in range(z):
                H_full[r * z + zz, c * z + (zz + s) % z] = 1
    return H_b, H_full
def build_peg_matrix(z=32):
    """
    PEG-like rate-1/8 construction with a more uniform degree profile.

    A staircase parity backbone guarantees full parity rank and
    sequential encodability; below-diagonal cross-connections improve
    the girth and degree distribution.  VN degrees come out as
    [7, 3, 3, 3, 2, 2, 2, 2] (average 2.75 vs 2.5 for the original);
    every parity column has dv >= 2 (no weak degree-1 nodes).
    """
    m_base, n_base = 7, 8
    H_b = np.full((m_base, n_base), -1, dtype=np.int16)
    # Info column (0): all seven rows with spread shifts.
    for r, s in enumerate((0, 7, 13, 19, 25, 3, 9)):
        H_b[r, 0] = s
    # Staircase parity backbone (lower-triangular => full rank).
    H_b[0, 1] = 5
    for r, diag in enumerate((3, 11, 17, 23, 29, 5), start=1):
        H_b[r, r] = diag
        H_b[r, r + 1] = 0
    # Below-diagonal cross-connections for a better degree distribution.
    H_b[3, 1] = 21  # col 1 dv=2->3
    H_b[5, 2] = 15  # col 2 dv=2->3
    H_b[6, 3] = 27  # col 3 dv=2->3
    H_b[0, 7] = 0   # col 7 dv=1->2
    # Expand shifts into z x z circulants.
    H_full = np.zeros((m_base * z, n_base * z), dtype=np.int8)
    for r in range(m_base):
        for c in range(n_base):
            s = int(H_b[r, c])
            if s < 0:
                continue
            for zz in range(z):
                H_full[r * z + zz, c * z + (zz + s) % z] = 1
    return H_b, H_full
def peg_encode(info, H_base, H_full, z=32):
    """
    Encoder for arbitrary (non-staircase) base matrices.

    Solves H_parity * p = H_info * info over GF(2) by Gauss-Jordan
    elimination, then verifies the assembled codeword has a zero
    syndrome (AssertionError otherwise).  Free variables, if any,
    are left at zero.
    """
    m_full = H_base.shape[0] * z
    n_full = H_base.shape[1] * z
    k = z
    assert len(info) == k
    # H_full = [H_info | H_parity]; RHS is H_info * info mod 2.
    rhs = (H_full[:, :k] @ info) % 2
    n_parity = n_full - k
    # Augmented system [H_parity | rhs].
    aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
    aug[:, :n_parity] = H_full[:, k:]
    aug[:, n_parity] = rhs
    # Gauss-Jordan elimination over GF(2): eliminate above AND below the
    # pivot so the solution can be read off directly afterwards.
    pivot_row = 0
    pivot_cols = []
    for col in range(n_parity):
        candidates = np.nonzero(aug[pivot_row:, col])[0]
        if candidates.size == 0:
            continue  # no pivot in this column
        found = pivot_row + int(candidates[0])
        if found != pivot_row:
            aug[[pivot_row, found]] = aug[[found, pivot_row]]
        for r in range(m_full):
            if r != pivot_row and aug[r, col] == 1:
                aug[r] = (aug[r] + aug[pivot_row]) % 2
        pivot_cols.append(col)
        pivot_row += 1
    # Read off the pivot variables; non-pivot (free) parities stay 0.
    parity = np.zeros(n_parity, dtype=np.int8)
    for i, col in enumerate(pivot_cols):
        parity[col] = aug[i, n_parity]
    codeword = np.concatenate([info, parity])
    check = H_full @ codeword % 2
    assert np.all(check == 0), f"PEG encoding failed: syndrome weight = {check.sum()}"
    return codeword
def compute_girth(H_base, z):
    """
    BFS-based shortest cycle detection in Tanner graph.
    Samples a subset of VNs for speed. Returns the girth (length of
    shortest cycle) or infinity if no cycle found.

    NOTE(review): this is a sampling heuristic, not an exact girth
    computation — only three VNs per base column are used as BFS roots,
    and the cross-edge length estimate below can overestimate a cycle
    that does not pass through the root.  Adequate for comparing matrix
    designs; do not rely on it for an exact girth.
    """
    m_base_local = H_base.shape[0]
    n_base_local = H_base.shape[1]
    # Build adjacency: VN -> list of CN, CN -> list of VN
    # In QC graph, VN (c, zz) connects to CN (r, (zz + shift) % z) for each row r
    # where H_base[r, c] >= 0
    n_vn = n_base_local * z
    n_cn = m_base_local * z
    def vn_neighbors(vn_idx):
        """CNs connected to this VN."""
        c = vn_idx // z
        zz = vn_idx % z
        neighbors = []
        for r in range(m_base_local):
            shift = int(H_base[r, c])
            if shift < 0:
                continue
            cn_idx = r * z + (zz + shift) % z
            neighbors.append(cn_idx)
        return neighbors
    def cn_neighbors(cn_idx):
        """VNs connected to this CN."""
        r = cn_idx // z
        zz_cn = cn_idx % z
        neighbors = []
        for c in range(n_base_local):
            shift = int(H_base[r, c])
            if shift < 0:
                continue
            # CN (r, zz_cn) connects to VN (c, (zz_cn - shift) % z)
            vn_zz = (zz_cn - shift) % z
            vn_idx = c * z + vn_zz
            neighbors.append(vn_idx)
        return neighbors
    min_cycle = float('inf')
    # Sample VNs: pick a few from each base column
    sample_vns = []
    for c in range(n_base_local):
        for zz in [0, z // 4, z // 2]:
            sample_vns.append(c * z + zz)
    for start_vn in sample_vns:
        # BFS from start_vn through bipartite graph
        # Track (node_type, node_idx, parent_type, parent_idx)
        # node_type: 'vn' or 'cn'
        dist = {}
        queue = deque()
        dist[('vn', start_vn)] = 0
        queue.append(('vn', start_vn, None, None))
        while queue:
            ntype, nidx, ptype, pidx = queue.popleft()
            d = dist[(ntype, nidx)]
            if ntype == 'vn':
                for cn in vn_neighbors(nidx):
                    if ('cn', cn) == (ptype, pidx):
                        continue  # don't go back the way we came
                    if ('cn', cn) in dist:
                        # Cross edge: the two BFS paths plus this edge
                        # close a cycle of (at most) this length.
                        cycle_len = d + 1 + dist[('cn', cn)]
                        min_cycle = min(min_cycle, cycle_len)
                    else:
                        dist[('cn', cn)] = d + 1
                        queue.append(('cn', cn, 'vn', nidx))
            else:  # cn
                for vn in cn_neighbors(nidx):
                    if ('vn', vn) == (ptype, pidx):
                        continue
                    if ('vn', vn) in dist:
                        cycle_len = d + 1 + dist[('vn', vn)]
                        min_cycle = min(min_cycle, cycle_len)
                    else:
                        dist[('vn', vn)] = d + 1
                        queue.append(('vn', vn, 'cn', nidx))
        # Early termination if we already found a short cycle
        # (4 and 6 are the shortest possible cycles in a bipartite graph,
        # so nothing smaller can be found by continuing).
        if min_cycle <= 6:
            return min_cycle
    return min_cycle
def run_matrix_comparison(lam_b=0.1, n_frames=500, max_iter=30):
    """
    Compare three base matrix designs for rate-1/8 code.
    1. Original staircase (from ldpc_sim.py)
    2. Improved staircase (extra connections)
    3. PEG ring structure

    Prints the VN degree distribution and (sampled) girth for each
    design, then a FER table over lam_s, and returns the list of
    per-point result dicts.
    """
    print("=" * 80)
    print("ANALYSIS 2: Base Matrix Comparison (rate 1/8)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)
    # Build matrices; each entry records its base/full matrix and which
    # encoder it needs ('staircase' = sequential solve, 'gaussian' = GF(2)
    # elimination via peg_encode).
    matrices = {}
    # Original staircase
    H_orig_base = H_BASE.copy()
    H_orig_full = build_full_h_matrix()
    matrices['Original staircase'] = {
        'H_base': H_orig_base, 'H_full': H_orig_full,
        'encoder': 'staircase',
    }
    # Improved staircase
    H_imp_base, H_imp_full = build_improved_staircase(z=32)
    matrices['Improved staircase'] = {
        'H_base': H_imp_base, 'H_full': H_imp_full,
        'encoder': 'gaussian',
    }
    # PEG ring
    H_peg_base, H_peg_full = build_peg_matrix(z=32)
    matrices['PEG ring'] = {
        'H_base': H_peg_base, 'H_full': H_peg_full,
        'encoder': 'gaussian',
    }
    # Print degree distributions and girth
    print("\nDegree distributions and girth:")
    print(f" {'Matrix':<22s} {'VN degrees':<30s} {'Girth':>6s}")
    print(" " + "-" * 60)
    for name, minfo in matrices.items():
        H_b = minfo['H_base']
        # VN degrees: count non-(-1) entries per column
        vn_degrees = []
        for c in range(H_b.shape[1]):
            d = np.sum(H_b[:, c] >= 0)
            vn_degrees.append(d)
        girth = compute_girth(H_b, 32)
        girth_str = str(girth) if girth < float('inf') else 'inf'
        # int() strips the np.int64 wrapper from the display.
        deg_str = str([int(d) for d in vn_degrees])
        print(f" {name:<22s} {deg_str:<30s} {girth_str:>6s}")
    # Run simulations
    lam_s_values = [0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 5.0, 7.0, 10.0]
    all_results = []
    # Header
    header = f"\n{'lam_s':>8s}"
    for name in matrices:
        header += f" {name[:12]:>12s}"
    print(header)
    print("-" * len(header))
    for lam_s in lam_s_values:
        row = f"{lam_s:8.1f}"
        for name, minfo in matrices.items():
            H_b = minfo['H_base']
            H_f = minfo['H_full']
            enc_type = minfo['encoder']
            bit_errors = 0
            frame_errors = 0
            total_iter = 0
            k = 32
            for _ in range(n_frames):
                info = np.random.randint(0, 2, k).astype(np.int8)
                if enc_type == 'staircase':
                    codeword = ldpc_encode(info, H_f)
                else:
                    codeword = peg_encode(info, H_b, H_f, z=32)
                llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
                llr_q = quantize_llr(llr_float)
                decoded, converged, iters, sw = generic_decode(
                    llr_q, H_b, z=32, max_iter=max_iter, q_bits=Q_BITS
                )
                total_iter += iters
                errs = np.sum(decoded != info)
                bit_errors += errs
                if errs > 0:
                    frame_errors += 1
            fer = frame_errors / n_frames
            avg_iter = total_iter / n_frames
            all_results.append({
                'matrix': name, 'lam_s': lam_s,
                'fer': fer, 'avg_iter': avg_iter,
            })
            if fer == 0:
                row += f" {'0.000':>12s}"
            else:
                row += f" {fer:12.3f}"
        print(row)
    return all_results
# =============================================================================
# Analysis 3: Quantization Sweep
# =============================================================================
def run_quant_sweep(lam_s=None, lam_b=0.1, n_frames=500, max_iter=30):
    """
    Sweep quantization bit-width and measure FER.
    Uses the original H_BASE from ldpc_sim for consistency.
    Includes 4,5,6,8,10,16 bits plus a "float" mode (16-bit with high scale).

    Args:
        lam_s: single signal rate to test, or None for the default sweep
            [2.0, 3.0, 5.0].
        lam_b: background photon rate.
        n_frames: frames simulated per (lam_s, q_bits) point.
        max_iter: decoder iteration cap.
    """
    if lam_s is None:
        lam_s_values = [2.0, 3.0, 5.0]
    else:
        lam_s_values = [lam_s]
    # Each entry: (label, actual_q_bits)
    # "float" uses 16-bit quantization (effectively lossless for this range)
    quant_configs = [
        ('4', 4), ('5', 5), ('6', 6), ('8', 8), ('10', 10), ('16', 16),
        ('float', 16),
    ]
    print("=" * 80)
    print("ANALYSIS 3: Quantization Sweep (rate 1/8, original staircase)")
    print(f" n_frames={n_frames}, lam_b={lam_b}, max_iter={max_iter}")
    print("=" * 80)
    H_full = build_full_h_matrix()
    all_results = []
    # Header (the 'float' column prints as "float-bit").
    header = f"{'lam_s':>8s}"
    for label, _ in quant_configs:
        header += f" {label + '-bit':>8s}"
    print(header)
    print("-" * len(header))
    for lam_s_val in lam_s_values:
        row = f"{lam_s_val:8.1f}"
        for label, actual_q in quant_configs:
            bit_errors = 0
            frame_errors = 0
            total_iter = 0
            for _ in range(n_frames):
                info = np.random.randint(0, 2, K).astype(np.int8)
                codeword = ldpc_encode(info, H_full)
                llr_float, _ = poisson_channel(codeword, lam_s_val, lam_b)
                # Quantize and decode at the SAME bit width under test.
                llr_q = quantize_llr(llr_float, q_bits=actual_q)
                decoded, converged, iters, sw = generic_decode(
                    llr_q, H_BASE, z=Z, max_iter=max_iter, q_bits=actual_q
                )
                total_iter += iters
                errs = np.sum(decoded != info)
                bit_errors += errs
                if errs > 0:
                    frame_errors += 1
            fer = frame_errors / n_frames
            all_results.append({
                'q_bits': label, 'lam_s': lam_s_val,
                'fer': float(fer),
            })
            if fer == 0:
                row += f" {'0.000':>8s}"
            else:
                row += f" {fer:8.3f}"
        print(row)
    return all_results
# =============================================================================
# Analysis 4: Shannon Gap
# =============================================================================
def poisson_channel_capacity(lam_s, lam_b, max_y=50):
    """
    Capacity (bits/use) of the binary-input Poisson channel.

    Y | X=0 ~ Poisson(lam_b), Y | X=1 ~ Poisson(lam_s + lam_b), with the
    output truncated to [0, max_y].  The input distribution p1 = P(X=1)
    is optimized by brute force in 1% steps:
        C = max_p1 [H(Y) - p0*H(Y|X=0) - p1*H(Y|X=1)]

    Args:
        lam_s: mean signal photon count per slot.
        lam_b: mean background photon count per slot.
        max_y: truncation point for the output alphabet.

    Returns:
        The capacity in bits per channel use (float, >= 0).
    """
    from scipy.stats import poisson

    def entropy(pmf):
        """Entropy (bits) of a discrete distribution; zero-mass terms skipped."""
        pmf = pmf[pmf > 0]
        return -np.sum(pmf * np.log2(pmf))

    # Conditional output PMFs and their entropies (independent of p1,
    # so computed once outside the search loop).
    y_range = np.arange(0, max_y + 1)
    py_given_0 = poisson.pmf(y_range, lam_b)          # P(Y=y | X=0)
    py_given_1 = poisson.pmf(y_range, lam_s + lam_b)  # P(Y=y | X=1)
    h_y_given_0 = entropy(py_given_0)
    h_y_given_1 = entropy(py_given_1)

    best_capacity = 0.0
    # Brute-force search over the input probability p1 = P(X=1).
    # (The previous version also tracked the arg-max p1 in an unused
    # local; only the capacity value is needed.)
    for p_int in range(0, 101):
        p1 = p_int / 100.0
        p0 = 1.0 - p1
        # Output distribution: P(Y=y) = p0*P(Y=y|X=0) + p1*P(Y=y|X=1)
        py = p0 * py_given_0 + p1 * py_given_1
        # I(X;Y) = H(Y) - H(Y|X)
        capacity = entropy(py) - (p0 * h_y_given_0 + p1 * h_y_given_1)
        if capacity > best_capacity:
            best_capacity = capacity
    return best_capacity
def run_shannon_gap(lam_b=0.1):
    """
    For each rate (1/2 through 1/8): binary search for minimum lambda_s
    where channel capacity >= rate.

    Prints a table of Shannon-limit lam_s values and returns a list of
    per-rate dicts (shannon_lam_s is None when even lam_s=50 is not
    enough to reach the rate).
    """
    print("=" * 80)
    print("ANALYSIS 4: Shannon Gap (Binary-Input Poisson Channel)")
    print(f" lam_b={lam_b}")
    print("=" * 80)
    # (m_base, label, rate value) — matches the codes in the rate sweep.
    rates = [
        (1, '1/2', 0.5),
        (2, '1/3', 1.0 / 3),
        (3, '1/4', 0.25),
        (5, '1/6', 1.0 / 6),
        (7, '1/8', 0.125),
    ]
    results = []
    print(f"\n{'Rate':>8s} {'Shannon lam_s':>14s} {'Capacity':>10s}")
    print("-" * 35)
    for m_base, label, rate_val in rates:
        # Binary search for minimum lam_s where C >= rate_val.
        # Capacity is monotonically increasing in lam_s, so bisection on
        # [lo, hi] is valid.
        lo = 0.001
        hi = 50.0
        # First check if capacity at hi is sufficient
        c_hi = poisson_channel_capacity(hi, lam_b)
        if c_hi < rate_val:
            print(f"{label:>8s} {'> 50':>14s} {c_hi:10.4f}")
            results.append({
                'rate': label, 'rate_val': rate_val,
                'shannon_lam_s': None, 'capacity': c_hi,
            })
            continue
        for _ in range(50):  # binary search iterations
            mid = (lo + hi) / 2
            c_mid = poisson_channel_capacity(mid, lam_b)
            if c_mid >= rate_val:
                hi = mid
            else:
                lo = mid
            if hi - lo < 0.001:
                break
        # hi is the smallest bracketed lam_s that still achieves the rate.
        shannon_lam_s = hi
        c_final = poisson_channel_capacity(shannon_lam_s, lam_b)
        print(f"{label:>8s} {shannon_lam_s:14.3f} {c_final:10.4f}")
        results.append({
            'rate': label, 'rate_val': rate_val,
            'shannon_lam_s': float(shannon_lam_s),
            'capacity': float(c_final),
        })
    print("\nInterpretation:")
    print(" Shannon limit = minimum lam_s where C(lam_s) >= R.")
    print(" Practical LDPC codes operate ~0.5-2 dB above Shannon limit.")
    print(" Lower rates need fewer photons/slot (lower lam_s threshold).")
    print(" Gap to Shannon = 10*log10(lam_s_practical / lam_s_shannon) dB.")
    return results
# =============================================================================
# CLI
# =============================================================================
def main():
    """Command-line entry point: parse flags, run the selected analyses,
    and save the combined results as JSON under ../data/."""
    parser = argparse.ArgumentParser(
        description='LDPC Code Analysis Tool',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python3 ldpc_analysis.py --rate-sweep
python3 ldpc_analysis.py --matrix-compare
python3 ldpc_analysis.py --quant-sweep
python3 ldpc_analysis.py --shannon-gap
python3 ldpc_analysis.py --all
"""
    )
    parser.add_argument('--rate-sweep', action='store_true',
                        help='Compare FER across code rates')
    parser.add_argument('--matrix-compare', action='store_true',
                        help='Compare base matrix designs')
    parser.add_argument('--quant-sweep', action='store_true',
                        help='Sweep quantization bit-width')
    parser.add_argument('--shannon-gap', action='store_true',
                        help='Compute Shannon limits')
    parser.add_argument('--all', action='store_true',
                        help='Run all analyses')
    parser.add_argument('--n-frames', type=int, default=500,
                        help='Frames per simulation point (default: 500)')
    parser.add_argument('--lam-b', type=float, default=0.1,
                        help='Background photon rate (default: 0.1)')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed (default: 42)')
    args = parser.parse_args()
    # Seed once up front so combined runs (--all) are reproducible.
    np.random.seed(args.seed)
    all_results = {}
    ran_any = False
    if args.rate_sweep or args.all:
        ran_any = True
        results = run_rate_sweep(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['rate_sweep'] = results
        print()
    if args.matrix_compare or args.all:
        ran_any = True
        results = run_matrix_comparison(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['matrix_compare'] = results
        print()
    if args.quant_sweep or args.all:
        ran_any = True
        results = run_quant_sweep(
            lam_b=args.lam_b, n_frames=args.n_frames, max_iter=30
        )
        all_results['quant_sweep'] = results
        print()
    if args.shannon_gap or args.all:
        ran_any = True
        results = run_shannon_gap(lam_b=args.lam_b)
        all_results['shannon_gap'] = results
        print()
    if not ran_any:
        # No analysis flag given: show usage instead of writing an empty file.
        parser.print_help()
        return
    # Save results next to the package: <this file's dir>/../data/.
    out_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', 'data')
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, 'analysis_results.json')
    with open(out_path, 'w') as f:
        # default=str stringifies anything json can't serialize (e.g. numpy scalars).
        json.dump(all_results, f, indent=2, default=str)
    print(f"Results saved to {out_path}")
# Script entry point.
if __name__ == '__main__':
    main()