#!/usr/bin/env python3
|
|
"""
|
|
Density Evolution Optimizer for QC-LDPC Codes
|
|
|
|
Monte Carlo density evolution to find optimal base matrix degree distributions
|
|
that lower the decoding threshold for photon-starved optical communication.
|
|
|
|
Usage:
|
|
python3 density_evolution.py threshold # Compute threshold for known matrices
|
|
python3 density_evolution.py optimize # Run optimizer, print top-10
|
|
python3 density_evolution.py construct DEGREES # Build matrix for given degrees
|
|
python3 density_evolution.py validate # FER comparison: optimized vs original
|
|
python3 density_evolution.py full # Full pipeline
|
|
"""
|
|
|
|
import numpy as np
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
|
|
sys.path.insert(0, os.path.dirname(__file__))
|
|
|
|
from ldpc_sim import (
|
|
Q_BITS, Q_MAX, Q_MIN, OFFSET, Z, N_BASE, M_BASE, H_BASE,
|
|
build_full_h_matrix,
|
|
)
|
|
|
|
# =============================================================================
# H_base profile representation for density evolution
# =============================================================================
# A "profile" is a dict with:
#   'n_base': int                - number of columns
#   'm_base': int                - number of rows
#   'connections': list of lists - connections[row] = list of connected column indices
#   'vn_degrees': list of int    - degree of each column
#
# For DE, we don't need circulant shifts - only the connectivity pattern matters.
|
def make_profile(H_base):
    """Build a density-evolution profile from a base matrix.

    An entry >= 0 in H_base marks a circulant connection; negative entries
    mark absent blocks.  Only the connectivity pattern is retained, since
    DE does not depend on the circulant shift values.
    """
    m_base, n_base = H_base.shape
    # Row adjacency: for each check row, the connected column indices.
    connections = [
        [c for c in range(n_base) if H_base[r, c] >= 0]
        for r in range(m_base)
    ]
    # Column degrees: number of check rows each variable column touches.
    vn_degrees = [
        sum(1 for r in range(m_base) if H_base[r, c] >= 0)
        for c in range(n_base)
    ]
    return {
        'n_base': n_base,
        'm_base': m_base,
        'connections': connections,
        'vn_degrees': vn_degrees,
    }
|
|
|
# Baseline profile of the staircase base matrix imported from ldpc_sim;
# candidate degree distributions are compared against this reference.
ORIGINAL_STAIRCASE_PROFILE = make_profile(H_BASE)
|
|
# =============================================================================
# Monte Carlo Density Evolution Engine
# =============================================================================
|
def de_channel_init(profile, z_pop, lam_s, lam_b):
    """
    Generate the initial LLR population for each VN column group.

    DE assumes the all-zeros codeword, so every slot observes background
    photons only: y ~ Poisson(lam_b).  The Poisson-channel LLR is
        LLR = lam_s - y * log((lam_s + lam_b) / lam_b)
    which is then quantized to the decoder's signed fixed-point range
    [Q_MIN, Q_MAX].

    Args:
        profile: DE profile dict (see make_profile).
        z_pop: population size per column group.
        lam_s: mean signal photons per slot.
        lam_b: mean background photons per slot.

    Returns:
        beliefs: ndarray (n_base, z_pop) - quantized LLR beliefs per column group
        msg_memory: dict mapping (row, col) -> ndarray(z_pop) of CN->VN
            messages, initialized to zero for every base-matrix edge.
    """
    n_base = profile['n_base']
    m_base = profile['m_base']

    beliefs = np.zeros((n_base, z_pop), dtype=np.float64)
    # Guard lam_b == 0: the true log-ratio diverges, so substitute a large
    # slope (any y > 0 then saturates to Q_MIN after clipping).
    log_ratio = np.log((lam_s + lam_b) / lam_b) if lam_b > 0 else 100.0
    # Quantization scale is loop-invariant (hoisted): |LLR| = 5.0 maps to Q_MAX.
    scale = Q_MAX / 5.0

    for c in range(n_base):
        # Channel observation: all-zero codeword, so photons ~ Poisson(lam_b)
        y = np.random.poisson(lam_b, size=z_pop)
        llr_float = lam_s - y * log_ratio

        # Quantize to the signed range [Q_MIN, Q_MAX]
        llr_q = np.round(llr_float * scale).astype(np.int32)
        llr_q = np.clip(llr_q, Q_MIN, Q_MAX)
        beliefs[c] = llr_q.astype(np.float64)

    # Initialize CN->VN message memory to zero
    msg_memory = {}
    for r in range(m_base):
        for c in profile['connections'][r]:
            msg_memory[(r, c)] = np.zeros(z_pop, dtype=np.float64)

    return beliefs, msg_memory
|
|
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
    """
    Vectorized offset min-sum CN update for a batch of z_pop messages.

    For output edge j the extrinsic magnitude is the minimum input magnitude
    over all *other* edges (min2 when edge j itself holds the overall minimum,
    min1 otherwise), reduced by `offset` and floored at zero; the extrinsic
    sign is the parity of the other edges' signs.

    Args:
        vn_msgs_list: list of ndarrays, each (z_pop,) - one per connected VN
        offset: min-sum offset subtracted from the extrinsic magnitude

    Returns:
        cn_out_list: list of ndarrays, each (z_pop,) - quantized CN->VN messages
    """
    dc = len(vn_msgs_list)
    if dc < 2:
        # A degree-1 check node carries no extrinsic information; degree-0
        # has no edges at all.
        return [np.zeros_like(vn_msgs_list[0])] if dc == 1 else []

    # Stack into (dc, z_pop) array
    msgs = np.array(vn_msgs_list, dtype=np.float64)

    # Signs (1 = negative, zero counts as positive) and magnitudes
    signs = (msgs < 0).astype(np.int8)  # (dc, z_pop)
    mags = np.abs(msgs)                 # (dc, z_pop)

    # Parity of negative inputs across all edges
    sign_xor = np.sum(signs, axis=0) % 2  # (z_pop,)

    # Two smallest magnitudes per population member
    sorted_mags = np.sort(mags, axis=0)  # (dc, z_pop)
    min1 = sorted_mags[0]  # (z_pop,)
    min2 = sorted_mags[1]  # (z_pop,)

    # Mark which input holds the minimum magnitude.  On ties only the first
    # occurrence is flagged (cumsum trick); the tied duplicate receives min1,
    # which equals min2 in that case, so the result is still correct.
    min1_mask = (mags == min1[np.newaxis, :])  # (dc, z_pop)
    first_min = min1_mask & (np.cumsum(min1_mask, axis=0) == 1)

    cn_out_list = []
    for j in range(dc):
        # Exclude edge j from the minimum: use min2 where j is the argmin
        mag = np.where(first_min[j], min2, min1)
        mag = np.maximum(0, mag - offset)

        # Extrinsic sign: overall parity with edge j's own sign removed
        ext_sign = sign_xor ^ signs[j]
        val = np.where(ext_sign, -mag, mag)

        # Quantize / saturate to the decoder's signed range
        val = np.clip(np.round(val), Q_MIN, Q_MAX)
        cn_out_list.append(val)

    return cn_out_list
|
|
def density_evolution_step(beliefs, msg_memory, profile, z_pop):
    """
    One full DE iteration: process all rows (layers) of the base matrix.

    For DE, we randomly permute within each column group to simulate
    the random interleaving effect (circulant shifts are random for DE).

    This is a layered schedule: each row's update immediately feeds the
    following rows within the same iteration, matching a layered min-sum
    decoder rather than a flooding one.

    Args:
        beliefs: ndarray (n_base, z_pop) - current VN beliefs
        msg_memory: dict (row, col) -> ndarray(z_pop) of old CN->VN messages
        profile: DE profile dict
        z_pop: population size

    Returns:
        beliefs: updated beliefs (modified in-place and returned)
    """
    m_base = profile['m_base']

    for row in range(m_base):
        connected_cols = profile['connections'][row]
        dc = len(connected_cols)
        # Rows of degree < 2 produce no extrinsic information; skip them.
        if dc < 2:
            continue

        # Step 1: Compute VN->CN messages (subtract old CN->VN)
        # Randomly permute within each column group for DE
        vn_to_cn = []
        permuted_indices = []
        for ci, col in enumerate(connected_cols):
            perm = np.random.permutation(z_pop)
            permuted_indices.append(perm)
            old_msg = msg_memory[(row, col)]
            # VN->CN = belief[col][permuted] - old_CN->VN
            vn_msg = beliefs[col][perm] - old_msg
            # Saturate
            vn_msg = np.clip(np.round(vn_msg), Q_MIN, Q_MAX)
            vn_to_cn.append(vn_msg)

        # Step 2: CN update (vectorized min-sum)
        cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET)

        # Step 3: Update beliefs and store new messages
        for ci, col in enumerate(connected_cols):
            perm = permuted_indices[ci]
            new_msg = cn_to_vn[ci]
            extrinsic = vn_to_cn[ci]

            # New belief = extrinsic + new CN->VN
            new_belief = extrinsic + new_msg
            new_belief = np.clip(np.round(new_belief), Q_MIN, Q_MAX)

            # Write back to beliefs at permuted positions; the same perm as
            # Step 1 is reused so each population member receives its own
            # updated belief.
            beliefs[col][perm] = new_belief

            # Store new CN->VN message for the next layer/iteration's
            # extrinsic subtraction.
            msg_memory[(row, col)] = new_msg

    return beliefs
|
|
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100, tol=1e-4):
    """
    Run full density evolution simulation.

    Args:
        profile: DE profile dict
        lam_s: signal photons per slot
        lam_b: background photons per slot
        z_pop: population size
        max_iter: maximum iterations
        tol: convergence threshold on the wrong-sign belief fraction
            (default 1e-4, matching the original hard-coded value)

    Returns:
        (converged: bool, avg_error_frac: float)
        Convergence means fraction of wrong-sign beliefs < tol.
    """
    beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)

    # Worst-case value so max_iter == 0 returns (False, 1.0) instead of
    # raising NameError on an unbound error_frac.
    error_frac = 1.0

    for _ in range(max_iter):
        beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop)

        # For the all-zeros codeword the correct belief sign is positive,
        # so the error metric is the fraction of negative beliefs.
        # Vectorized over the whole (n_base, z_pop) array.
        error_frac = np.count_nonzero(beliefs < 0) / beliefs.size

        if error_frac < tol:
            return True, error_frac

    return False, error_frac
|
|
|
# =============================================================================
# CLI placeholder (will be extended in later tasks)
# =============================================================================
|
if __name__ == '__main__':
    # Placeholder entry point: the subcommand dispatch arrives in a later task.
    banner = (
        "Density Evolution Optimizer",
        "Run with subcommands: threshold, optimize, construct, validate, full",
    )
    for line in banner:
        print(line)