feat: add Monte Carlo density evolution engine
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
262
model/density_evolution.py
Normal file
262
model/density_evolution.py
Normal file
@@ -0,0 +1,262 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Density Evolution Optimizer for QC-LDPC Codes
|
||||
|
||||
Monte Carlo density evolution to find optimal base matrix degree distributions
|
||||
that lower the decoding threshold for photon-starved optical communication.
|
||||
|
||||
Usage:
|
||||
python3 density_evolution.py threshold # Compute threshold for known matrices
|
||||
python3 density_evolution.py optimize # Run optimizer, print top-10
|
||||
python3 density_evolution.py construct DEGREES # Build matrix for given degrees
|
||||
python3 density_evolution.py validate # FER comparison: optimized vs original
|
||||
python3 density_evolution.py full # Full pipeline
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
from ldpc_sim import (
|
||||
Q_BITS, Q_MAX, Q_MIN, OFFSET, Z, N_BASE, M_BASE, H_BASE,
|
||||
build_full_h_matrix,
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# H_base profile representation for density evolution
|
||||
# =============================================================================
|
||||
# A "profile" is a dict with:
|
||||
# 'n_base': int - number of columns
|
||||
# 'm_base': int - number of rows
|
||||
# 'connections': list of lists - connections[row] = list of connected column indices
|
||||
# 'vn_degrees': list of int - degree of each column
|
||||
#
|
||||
# For DE, we don't need circulant shifts - only the connectivity pattern matters.
|
||||
|
||||
def make_profile(H_base):
|
||||
"""Convert a base matrix to a DE profile."""
|
||||
m_base = H_base.shape[0]
|
||||
n_base = H_base.shape[1]
|
||||
connections = []
|
||||
for r in range(m_base):
|
||||
row_cols = [c for c in range(n_base) if H_base[r, c] >= 0]
|
||||
connections.append(row_cols)
|
||||
vn_degrees = []
|
||||
for c in range(n_base):
|
||||
vn_degrees.append(sum(1 for r in range(m_base) if H_base[r, c] >= 0))
|
||||
return {
|
||||
'n_base': n_base,
|
||||
'm_base': m_base,
|
||||
'connections': connections,
|
||||
'vn_degrees': vn_degrees,
|
||||
}
|
||||
|
||||
|
||||
ORIGINAL_STAIRCASE_PROFILE = make_profile(H_BASE)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Monte Carlo Density Evolution Engine
|
||||
# =============================================================================
|
||||
|
||||
def de_channel_init(profile, z_pop, lam_s, lam_b):
|
||||
"""
|
||||
Generate initial LLR population for each VN column group.
|
||||
|
||||
Returns:
|
||||
beliefs: ndarray (n_base, z_pop) - LLR beliefs per column group
|
||||
msg_memory: dict mapping (row, col_idx_in_row) -> ndarray(z_pop) of CN->VN messages
|
||||
"""
|
||||
n_base = profile['n_base']
|
||||
m_base = profile['m_base']
|
||||
|
||||
# For each column, generate z_pop independent channel observations
|
||||
# All columns transmit bit=0 (all-zeros codeword) for DE
|
||||
# P(y|0): photon count ~ Poisson(lam_b)
|
||||
# LLR = lam_s - y * log((lam_s + lam_b) / lam_b)
|
||||
|
||||
beliefs = np.zeros((n_base, z_pop), dtype=np.float64)
|
||||
log_ratio = np.log((lam_s + lam_b) / lam_b) if lam_b > 0 else 100.0
|
||||
|
||||
for c in range(n_base):
|
||||
# Channel observation: all-zero codeword, so photons ~ Poisson(lam_b)
|
||||
y = np.random.poisson(lam_b, size=z_pop)
|
||||
llr_float = lam_s - y * log_ratio
|
||||
|
||||
# Quantize to 6-bit signed
|
||||
scale = Q_MAX / 5.0
|
||||
llr_q = np.round(llr_float * scale).astype(np.int32)
|
||||
llr_q = np.clip(llr_q, Q_MIN, Q_MAX)
|
||||
beliefs[c] = llr_q.astype(np.float64)
|
||||
|
||||
# Initialize CN->VN message memory to zero
|
||||
msg_memory = {}
|
||||
for r in range(m_base):
|
||||
for ci, c in enumerate(profile['connections'][r]):
|
||||
msg_memory[(r, c)] = np.zeros(z_pop, dtype=np.float64)
|
||||
|
||||
return beliefs, msg_memory
|
||||
|
||||
|
||||
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
||||
"""
|
||||
Vectorized offset min-sum CN update for a batch of z_pop messages.
|
||||
|
||||
Args:
|
||||
vn_msgs_list: list of ndarrays, each (z_pop,) - one per connected VN
|
||||
offset: min-sum offset
|
||||
|
||||
Returns:
|
||||
cn_out_list: list of ndarrays, each (z_pop,) - CN->VN messages
|
||||
"""
|
||||
dc = len(vn_msgs_list)
|
||||
if dc < 2:
|
||||
return [np.zeros_like(vn_msgs_list[0])] if dc == 1 else []
|
||||
|
||||
z_pop = len(vn_msgs_list[0])
|
||||
|
||||
# Stack into (dc, z_pop) array
|
||||
msgs = np.array(vn_msgs_list, dtype=np.float64)
|
||||
|
||||
# Signs and magnitudes
|
||||
signs = (msgs < 0).astype(np.int8) # (dc, z_pop)
|
||||
mags = np.abs(msgs) # (dc, z_pop)
|
||||
|
||||
# Total sign XOR across all inputs
|
||||
sign_xor = np.sum(signs, axis=0) % 2 # (z_pop,)
|
||||
|
||||
# Find min1 and min2 magnitudes
|
||||
# Sort magnitudes along dc axis, take smallest two
|
||||
sorted_mags = np.sort(mags, axis=0) # (dc, z_pop)
|
||||
min1 = sorted_mags[0] # (z_pop,)
|
||||
min2 = sorted_mags[1] # (z_pop,)
|
||||
|
||||
# For each output j: magnitude = min2 if j has min1, else min1
|
||||
# Find which input has the minimum magnitude
|
||||
min1_mask = (mags == min1[np.newaxis, :]) # (dc, z_pop)
|
||||
# Break ties: only first occurrence gets min2
|
||||
# Use cumsum trick to mark only the first True per column
|
||||
first_min = min1_mask & (np.cumsum(min1_mask, axis=0) == 1)
|
||||
|
||||
cn_out_list = []
|
||||
for j in range(dc):
|
||||
mag = np.where(first_min[j], min2, min1)
|
||||
mag = np.maximum(0, mag - offset)
|
||||
|
||||
# Extrinsic sign: total XOR minus this input's sign
|
||||
ext_sign = sign_xor ^ signs[j]
|
||||
val = np.where(ext_sign, -mag, mag)
|
||||
|
||||
# Quantize
|
||||
val = np.clip(np.round(val), Q_MIN, Q_MAX)
|
||||
cn_out_list.append(val)
|
||||
|
||||
return cn_out_list
|
||||
|
||||
|
||||
def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
||||
"""
|
||||
One full DE iteration: process all rows (layers) of the base matrix.
|
||||
|
||||
For DE, we randomly permute within each column group to simulate
|
||||
the random interleaving effect (circulant shifts are random for DE).
|
||||
|
||||
Args:
|
||||
beliefs: ndarray (n_base, z_pop) - current VN beliefs
|
||||
msg_memory: dict (row, col) -> ndarray(z_pop) of old CN->VN messages
|
||||
profile: DE profile dict
|
||||
z_pop: population size
|
||||
|
||||
Returns:
|
||||
beliefs: updated beliefs (modified in-place and returned)
|
||||
"""
|
||||
m_base = profile['m_base']
|
||||
|
||||
for row in range(m_base):
|
||||
connected_cols = profile['connections'][row]
|
||||
dc = len(connected_cols)
|
||||
if dc < 2:
|
||||
continue
|
||||
|
||||
# Step 1: Compute VN->CN messages (subtract old CN->VN)
|
||||
# Randomly permute within each column group for DE
|
||||
vn_to_cn = []
|
||||
permuted_indices = []
|
||||
for ci, col in enumerate(connected_cols):
|
||||
perm = np.random.permutation(z_pop)
|
||||
permuted_indices.append(perm)
|
||||
old_msg = msg_memory[(row, col)]
|
||||
# VN->CN = belief[col][permuted] - old_CN->VN
|
||||
vn_msg = beliefs[col][perm] - old_msg
|
||||
# Saturate
|
||||
vn_msg = np.clip(np.round(vn_msg), Q_MIN, Q_MAX)
|
||||
vn_to_cn.append(vn_msg)
|
||||
|
||||
# Step 2: CN update (vectorized min-sum)
|
||||
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET)
|
||||
|
||||
# Step 3: Update beliefs and store new messages
|
||||
for ci, col in enumerate(connected_cols):
|
||||
perm = permuted_indices[ci]
|
||||
new_msg = cn_to_vn[ci]
|
||||
extrinsic = vn_to_cn[ci]
|
||||
|
||||
# New belief = extrinsic + new CN->VN
|
||||
new_belief = extrinsic + new_msg
|
||||
new_belief = np.clip(np.round(new_belief), Q_MIN, Q_MAX)
|
||||
|
||||
# Write back to beliefs at permuted positions
|
||||
beliefs[col][perm] = new_belief
|
||||
|
||||
# Store new CN->VN message
|
||||
msg_memory[(row, col)] = new_msg
|
||||
|
||||
return beliefs
|
||||
|
||||
|
||||
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
||||
"""
|
||||
Run full density evolution simulation.
|
||||
|
||||
Args:
|
||||
profile: DE profile dict
|
||||
lam_s: signal photons per slot
|
||||
lam_b: background photons per slot
|
||||
z_pop: population size
|
||||
max_iter: maximum iterations
|
||||
|
||||
Returns:
|
||||
(converged: bool, avg_error_frac: float)
|
||||
Convergence means fraction of wrong-sign beliefs < 1e-4.
|
||||
"""
|
||||
beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)
|
||||
|
||||
for it in range(max_iter):
|
||||
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop)
|
||||
|
||||
# Check convergence: for all-zeros codeword, correct belief is positive
|
||||
# Error = fraction of beliefs with wrong sign (negative)
|
||||
total_wrong = 0
|
||||
total_count = 0
|
||||
for c in range(profile['n_base']):
|
||||
total_wrong += np.sum(beliefs[c] < 0)
|
||||
total_count += z_pop
|
||||
|
||||
error_frac = total_wrong / total_count
|
||||
if error_frac < 1e-4:
|
||||
return True, error_frac
|
||||
|
||||
return False, error_frac
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# CLI placeholder (will be extended in later tasks)
|
||||
# =============================================================================
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("Density Evolution Optimizer")
|
||||
print("Run with subcommands: threshold, optimize, construct, validate, full")
|
||||
52
model/test_density_evolution.py
Normal file
52
model/test_density_evolution.py
Normal file
@@ -0,0 +1,52 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tests for density evolution optimizer."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.dirname(__file__))
|
||||
|
||||
|
||||
class TestDensityEvolution:
|
||||
"""Tests for the Monte Carlo DE engine."""
|
||||
|
||||
def test_de_known_good_converges(self):
|
||||
"""DE with original staircase profile at lam_s=10 should converge easily."""
|
||||
from density_evolution import run_de, ORIGINAL_STAIRCASE_PROFILE
|
||||
np.random.seed(42)
|
||||
converged, error_frac = run_de(
|
||||
ORIGINAL_STAIRCASE_PROFILE, lam_s=10.0, lam_b=0.1,
|
||||
z_pop=10000, max_iter=50
|
||||
)
|
||||
assert converged, f"DE should converge at lam_s=10, error_frac={error_frac}"
|
||||
|
||||
def test_de_known_bad_fails(self):
|
||||
"""DE at very low lam_s=0.1 should not converge."""
|
||||
from density_evolution import run_de, ORIGINAL_STAIRCASE_PROFILE
|
||||
np.random.seed(42)
|
||||
converged, error_frac = run_de(
|
||||
ORIGINAL_STAIRCASE_PROFILE, lam_s=0.1, lam_b=0.1,
|
||||
z_pop=10000, max_iter=50
|
||||
)
|
||||
assert not converged, f"DE should NOT converge at lam_s=0.1, error_frac={error_frac}"
|
||||
|
||||
def test_de_population_shape(self):
|
||||
"""Verify belief arrays have correct shapes after one step."""
|
||||
from density_evolution import de_channel_init, density_evolution_step
|
||||
np.random.seed(42)
|
||||
n_base = 8
|
||||
m_base = 7
|
||||
z_pop = 1000
|
||||
|
||||
# Original staircase H_base profile
|
||||
from density_evolution import ORIGINAL_STAIRCASE_PROFILE
|
||||
beliefs, msg_memory = de_channel_init(ORIGINAL_STAIRCASE_PROFILE, z_pop, lam_s=5.0, lam_b=0.1)
|
||||
|
||||
# beliefs should be (n_base, z_pop)
|
||||
assert beliefs.shape == (n_base, z_pop), f"Expected ({n_base}, {z_pop}), got {beliefs.shape}"
|
||||
|
||||
# Run one step
|
||||
beliefs = density_evolution_step(beliefs, msg_memory, ORIGINAL_STAIRCASE_PROFILE, z_pop)
|
||||
assert beliefs.shape == (n_base, z_pop), f"Shape changed after step: {beliefs.shape}"
|
||||
Reference in New Issue
Block a user