Files
ldpc_optical/model/density_evolution.py
2026-02-24 05:53:23 -07:00

522 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Density Evolution Optimizer for QC-LDPC Codes
Monte Carlo density evolution to find optimal base matrix degree distributions
that lower the decoding threshold for photon-starved optical communication.
Usage:
python3 density_evolution.py threshold # Compute threshold for known matrices
python3 density_evolution.py optimize # Run optimizer, print top-10
python3 density_evolution.py construct DEGREES # Build matrix for given degrees
python3 density_evolution.py validate # FER comparison: optimized vs original
python3 density_evolution.py full # Full pipeline
"""
import numpy as np
import argparse
import json
import os
import sys
sys.path.insert(0, os.path.dirname(__file__))
from ldpc_sim import (
Q_BITS, Q_MAX, Q_MIN, OFFSET, Z, N_BASE, M_BASE, H_BASE,
build_full_h_matrix,
)
# =============================================================================
# H_base profile representation for density evolution
# =============================================================================
# A "profile" is a dict with:
# 'n_base': int - number of columns
# 'm_base': int - number of rows
# 'connections': list of lists - connections[row] = list of connected column indices
# 'vn_degrees': list of int - degree of each column
#
# For DE, we don't need circulant shifts - only the connectivity pattern matters.
def make_profile(H_base):
    """Build a DE profile dict from a base matrix.

    An entry >= 0 in H_base marks a connection (circulant present);
    negative entries mark absent blocks. Only connectivity matters for DE.
    """
    num_rows = H_base.shape[0]
    num_cols = H_base.shape[1]
    # Per-row list of connected column indices
    row_connections = [
        [col for col in range(num_cols) if H_base[row, col] >= 0]
        for row in range(num_rows)
    ]
    # Column (variable-node) degrees: count of rows touching each column
    col_degrees = [
        sum(1 for row in range(num_rows) if H_base[row, col] >= 0)
        for col in range(num_cols)
    ]
    return {
        'n_base': num_cols,
        'm_base': num_rows,
        'connections': row_connections,
        'vn_degrees': col_degrees,
    }
# DE profile of the original staircase base matrix from ldpc_sim; serves as
# the baseline the optimizer's candidates are compared against.
ORIGINAL_STAIRCASE_PROFILE = make_profile(H_BASE)
# =============================================================================
# Monte Carlo Density Evolution Engine
# =============================================================================
def de_channel_init(profile, z_pop, lam_s, lam_b):
    """
    Generate the initial LLR population for each VN column group.

    Simulates the Poisson photon-counting channel under the all-zeros
    codeword assumption (standard for DE): every column transmits bit=0,
    so photon counts are Poisson(lam_b). LLRs are quantized to the
    decoder's 6-bit signed range [Q_MIN, Q_MAX].

    Args:
        profile: DE profile dict (see module header)
        z_pop: population size (independent samples per column group)
        lam_s: mean signal photons per slot
        lam_b: mean background photons per slot
    Returns:
        beliefs: ndarray (n_base, z_pop) - quantized LLR beliefs per column
        msg_memory: dict mapping (row, col) -> ndarray(z_pop) of CN->VN
            messages, initialized to zero
    """
    n_base = profile['n_base']
    m_base = profile['m_base']
    # LLR(y) = lam_s - y * log((lam_s + lam_b) / lam_b); with lam_b == 0 a
    # large constant stands in for the (infinite) log-ratio.
    beliefs = np.zeros((n_base, z_pop), dtype=np.float64)
    log_ratio = np.log((lam_s + lam_b) / lam_b) if lam_b > 0 else 100.0
    # Quantization scale is loop-invariant: hoisted out of the column loop.
    scale = Q_MAX / 5.0
    for c in range(n_base):
        # Channel observation: all-zero codeword, so photons ~ Poisson(lam_b)
        y = np.random.poisson(lam_b, size=z_pop)
        llr_float = lam_s - y * log_ratio
        # Quantize to 6-bit signed range
        llr_q = np.round(llr_float * scale).astype(np.int32)
        llr_q = np.clip(llr_q, Q_MIN, Q_MAX)
        beliefs[c] = llr_q.astype(np.float64)
    # CN->VN message memory starts at zero for every edge of the base graph
    msg_memory = {}
    for r in range(m_base):
        for c in profile['connections'][r]:
            msg_memory[(r, c)] = np.zeros(z_pop, dtype=np.float64)
    return beliefs, msg_memory
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
    """
    Offset min-sum check-node update, vectorized over a message population.

    Args:
        vn_msgs_list: list of (z_pop,) arrays, one per connected VN
        offset: min-sum correction offset
    Returns:
        List of (z_pop,) arrays holding the extrinsic CN->VN messages,
        quantized to [Q_MIN, Q_MAX].
    """
    dc = len(vn_msgs_list)
    if dc < 2:
        # A degree-1 check carries no extrinsic information; degree-0 none at all
        return [np.zeros_like(vn_msgs_list[0])] if dc == 1 else []
    stacked = np.array(vn_msgs_list, dtype=np.float64)   # (dc, z_pop)
    neg = (stacked < 0).astype(np.int8)                  # 1 where message is negative
    mags = np.abs(stacked)
    # Overall sign parity (XOR of all input signs) per population sample
    parity = np.sum(neg, axis=0) % 2                     # (z_pop,)
    # Smallest and second-smallest magnitude per sample
    ordered = np.sort(mags, axis=0)
    min1 = ordered[0]
    min2 = ordered[1]
    # Index of the FIRST input achieving min1 (argmin breaks ties toward
    # the lowest index, matching a first-occurrence mask)
    min_owner = np.argmin(mags, axis=0)                  # (z_pop,)
    outputs = []
    for j in range(dc):
        # The minimum's owner receives min2; every other output gets min1
        mag = np.where(min_owner == j, min2, min1)
        mag = np.maximum(0, mag - offset)
        # Extrinsic sign excludes this input's own sign from the parity
        sign = parity ^ neg[j]
        out = np.where(sign, -mag, mag)
        # Quantize to the decoder's signed range
        outputs.append(np.clip(np.round(out), Q_MIN, Q_MAX))
    return outputs
def density_evolution_step(beliefs, msg_memory, profile, z_pop):
    """
    Run one full DE iteration over every row (layer) of the base matrix.

    Within each column group the population is freshly permuted before the
    message exchange; this models the random interleaving of circulant
    shifts without committing to concrete shift values.

    Args:
        beliefs: (n_base, z_pop) array of current VN beliefs (updated in place)
        msg_memory: dict (row, col) -> (z_pop,) array of prior CN->VN messages
        profile: DE profile dict
        z_pop: population size
    Returns:
        The (in-place updated) beliefs array.
    """
    for row in range(profile['m_base']):
        cols = profile['connections'][row]
        if len(cols) < 2:
            continue  # a check of degree < 2 contributes nothing extrinsic
        # --- VN->CN half: subtract the stale CN->VN message (layered decoding) ---
        perms = []
        vn_msgs = []
        for col in cols:
            shuffle = np.random.permutation(z_pop)
            perms.append(shuffle)
            raw = beliefs[col][shuffle] - msg_memory[(row, col)]
            vn_msgs.append(np.clip(np.round(raw), Q_MIN, Q_MAX))
        # --- CN half: vectorized offset min-sum ---
        cn_msgs = de_cn_update_vectorized(vn_msgs, offset=OFFSET)
        # --- Belief refresh: extrinsic + fresh CN->VN, saturated, scattered back ---
        for col, shuffle, vn_msg, cn_msg in zip(cols, perms, vn_msgs, cn_msgs):
            refreshed = np.clip(np.round(vn_msg + cn_msg), Q_MIN, Q_MAX)
            # Write through the same permutation used to read
            beliefs[col][shuffle] = refreshed
            msg_memory[(row, col)] = cn_msg
    return beliefs
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
    """
    Run the full Monte Carlo density evolution simulation.

    Args:
        profile: DE profile dict
        lam_s: signal photons per slot
        lam_b: background photons per slot
        z_pop: population size
        max_iter: maximum DE iterations
    Returns:
        (converged, error_frac) - converged is True once the fraction of
        wrong-sign beliefs drops below 1e-4; error_frac is the last
        measured fraction.
    """
    beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)
    for _ in range(max_iter):
        beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop)
        # All-zeros codeword: a correct belief is non-negative, so the error
        # fraction is simply the share of strictly negative entries.
        error_frac = np.count_nonzero(beliefs < 0) / beliefs.size
        if error_frac < 1e-4:
            return True, error_frac
    return False, error_frac
# =============================================================================
# Threshold Computation via Binary Search
# =============================================================================
def build_de_profile(vn_degrees, m_base=7):
    """
    Build a DE profile from a VN degree list.

    Col 0 is always info (connects to all m_base rows).
    Cols 1..n_base-1 are parity columns with a staircase backbone plus
    extra connections to reach the target degrees. Extra connections go
    to rows strictly below the diagonal (farthest first) so the parity
    part stays lower-triangular; rows above the diagonal are used only
    when no below-diagonal row is free.

    Args:
        vn_degrees: target degree per column; vn_degrees[0] must be m_base
        m_base: number of base-matrix rows (n_base = m_base + 1)
    Returns:
        DE profile dict (n_base, m_base, connections, vn_degrees).
    Raises:
        AssertionError: if the degree list has the wrong length or the
            info-column degree is not m_base.
    """
    n_base = len(vn_degrees)
    assert n_base == m_base + 1, f"Expected {m_base+1} columns, got {n_base}"
    assert vn_degrees[0] == m_base, f"Info column must have degree {m_base}"
    # connections[row] = set of connected columns
    connections = [set() for _ in range(m_base)]
    # Col 0 (info) connects to every row
    for r in range(m_base):
        connections[r].add(0)
    # Staircase backbone: row r -> col r+1 (diagonal); row r>0 -> col r (sub-diagonal)
    for r in range(m_base):
        connections[r].add(r + 1)
        if r > 0:
            connections[r].add(r)
    # Degrees implied by the backbone alone
    current_degrees = [0] * n_base
    for r in range(m_base):
        for c in connections[r]:
            current_degrees[c] += 1
    # Add extra connections for parity columns that need a higher degree
    for c in range(1, n_base):
        needed = vn_degrees[c] - current_degrees[c]
        if needed <= 0:
            continue
        available_rows = [r for r in range(m_base) if c not in connections[r]]
        # Col c's diagonal entry sits at row c-1. BUGFIX: the previous key,
        # -(r - (c - 1)) % m_base, wrapped modularly and could rank
        # above-diagonal rows ahead of genuinely below-diagonal ones,
        # breaking the lower-triangular structure it was meant to protect.
        # Two-part key: below-diagonal rows first (farthest below first),
        # then above-diagonal rows (nearest the diagonal first).
        available_rows.sort(key=lambda r: (r <= c - 1, -(r - (c - 1))))
        for r in available_rows[:needed]:
            connections[r].add(c)
            current_degrees[c] += 1
    # Freeze sets into sorted lists and recompute the achieved degrees
    connections_list = [sorted(conns) for conns in connections]
    final_degrees = [0] * n_base
    for r in range(m_base):
        for c in connections_list[r]:
            final_degrees[c] += 1
    return {
        'n_base': n_base,
        'm_base': m_base,
        'connections': connections_list,
        'vn_degrees': final_degrees,
    }
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
    """
    Binary-search the minimum lam_s at which DE converges.

    Args:
        profile: DE profile dict
        lam_b: background photons per slot
        z_pop: DE population size
        tol: bisection stopping width on lam_s
    Returns:
        Threshold lam_s*, clamped to the search bracket [0.1, 20.0] when
        the bracket fails to straddle the threshold.
    """
    lo, hi = 0.1, 20.0
    # The upper end of the bracket must converge ...
    ok, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop, max_iter=100)
    if not ok:
        return hi
    # ... and the lower end must fail, otherwise the bracket is degenerate.
    ok, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop, max_iter=100)
    if ok:
        return lo
    # Bisection: keep hi on the converging side, lo on the failing side
    while hi - lo > tol:
        mid = (lo + hi) / 2
        ok, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop, max_iter=100)
        if ok:
            hi = mid
        else:
            lo = mid
    return hi
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1):
    """Convenience wrapper: DE threshold straight from a VN degree list."""
    return compute_threshold(
        build_de_profile(vn_degrees, m_base=m_base),
        lam_b=lam_b, z_pop=z_pop, tol=tol,
    )
# =============================================================================
# Degree Distribution Optimizer
# =============================================================================
def enumerate_vn_candidates(m_base=7):
    """
    Enumerate every VN degree distribution for the parity columns.

    Col 0 is always dv=m_base; each of the m_base parity columns takes a
    degree from {2, 3, 4}. Returns a list of degree vectors of length
    m_base + 1, in itertools.product order.
    """
    from itertools import product
    return [[m_base, *combo] for combo in product([2, 3, 4], repeat=m_base)]
def filter_by_row_degree(candidates, m_base=7, dc_min=3, dc_max=6):
    """
    Filter candidates by feasibility of the row-degree constraints.

    The info column contributes exactly one edge to every row, so a row of
    degree dc needs dc - 1 parity edges. Summed over m_base rows, the
    total parity-edge count must therefore lie in
    [m_base*(dc_min-1), m_base*(dc_max-1)] for some valid row assignment
    to exist (a necessary condition; row-level placement is not checked).

    Args:
        candidates: list of VN degree vectors ([info_deg, parity_degs...])
        m_base: number of base-matrix rows
        dc_min: minimum allowed check-node (row) degree
        dc_max: maximum allowed check-node (row) degree
    Returns:
        The sublist of candidates whose total parity edges are distributable.
    """
    # Bounds are identical for every candidate: hoist out of the loop.
    # (Also drops an unused `n_base` local from the original.)
    min_parity = m_base * (dc_min - 1)
    max_parity = m_base * (dc_max - 1)
    # degrees[1:] are the parity columns; their sum is the parity edge total
    return [
        degrees for degrees in candidates
        if min_parity <= sum(degrees[1:]) <= max_parity
    ]
def coarse_screen(candidates, lam_s_test, lam_b, z_pop, max_iter, m_base=7):
    """
    Fast pre-filter: run DE once per candidate at a single operating point
    and keep only the candidates that converge there.
    """
    keep = []
    for degrees in candidates:
        ok, _ = run_de(
            build_de_profile(degrees, m_base=m_base),
            lam_s=lam_s_test, lam_b=lam_b,
            z_pop=z_pop, max_iter=max_iter,
        )
        if ok:
            keep.append(degrees)
    return keep
def get_unique_distributions(candidates):
    """
    Collapse candidates that share the same multiset of parity degrees.

    DE depends only on the degree distribution, not on which parity column
    carries which degree, so each unique distribution is kept once, in
    first-appearance order, canonicalized as [info_degree] followed by the
    parity degrees sorted descending.
    """
    seen = set()
    unique = []
    for degrees in candidates:
        canonical = tuple(sorted(degrees[1:], reverse=True))
        if canonical in seen:
            continue
        seen.add(canonical)
        unique.append([degrees[0], *canonical])
    return unique
def optimize_degree_distribution(m_base=7, lam_b=0.1, top_k=10,
                                 z_pop_coarse=10000, z_pop_fine=50000,
                                 tol=0.1):
    """
    End-to-end optimizer: enumerate, filter, dedupe, coarse-screen, rank.

    Because DE only sees the degree distribution (not column ordering),
    the 2187 raw candidates collapse into a few dozen unique
    distributions before any simulation runs.

    Returns:
        Up to top_k (vn_degrees, threshold) pairs, threshold ascending.
    """
    print("Step 1: Enumerating candidates...")
    candidates = enumerate_vn_candidates(m_base=m_base)
    print(f" {len(candidates)} total candidates")
    print("Step 2: Filtering by row degree constraints...")
    feasible = filter_by_row_degree(candidates, m_base=m_base, dc_min=3, dc_max=6)
    print(f" {len(feasible)} candidates after filtering")
    print("Step 3: Grouping by unique degree distribution...")
    distributions = get_unique_distributions(feasible)
    print(f" {len(distributions)} unique distributions")
    print("Step 4: Coarse screening at lam_s=2.0...")
    survivors = coarse_screen(
        distributions, lam_s_test=2.0, lam_b=lam_b,
        z_pop=z_pop_coarse, max_iter=50, m_base=m_base
    )
    print(f" {len(survivors)} survivors after coarse screen")
    if not survivors:
        # Retry once at a more generous operating point before giving up
        print(" No survivors at lam_s=2.0, trying lam_s=3.0...")
        survivors = coarse_screen(
            distributions, lam_s_test=3.0, lam_b=lam_b,
            z_pop=z_pop_coarse, max_iter=50, m_base=m_base
        )
        print(f" {len(survivors)} survivors at lam_s=3.0")
        if not survivors:
            print(" No survivors found, returning empty list")
            return []
    print(f"Step 5: Fine threshold computation for {len(survivors)} survivors...")
    scored = []
    for idx, degrees in enumerate(survivors, start=1):
        profile = build_de_profile(degrees, m_base=m_base)
        thr = compute_threshold(profile, lam_b=lam_b, z_pop=z_pop_fine, tol=tol)
        scored.append((degrees, thr))
        if idx % 5 == 0:
            print(f" {idx}/{len(survivors)} done...")
    # Lowest threshold first
    scored.sort(key=lambda item: item[1])
    print(f"\nTop-{min(top_k, len(scored))} degree distributions:")
    for rank, (degrees, thr) in enumerate(scored[:top_k], start=1):
        print(f" #{rank}: {degrees} -> threshold = {thr:.2f} photons/slot")
    return scored[:top_k]
# =============================================================================
# CLI placeholder (will be extended in later tasks)
# =============================================================================
if __name__ == '__main__':
    # CLI subcommands (threshold/optimize/construct/validate/full) are
    # planned per the module docstring; this entry point currently only
    # prints usage hints.
    print("Density Evolution Optimizer")
    print("Run with subcommands: threshold, optimize, construct, validate, full")