feat: add normalized min-sum to density evolution engine
Thread cn_mode and alpha parameters through the entire DE pipeline: de_cn_update_vectorized(), density_evolution_step(), run_de(), compute_threshold(), and compute_threshold_for_profile(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -102,13 +102,15 @@ def de_channel_init(profile, z_pop, lam_s, lam_b):
|
||||
return beliefs, msg_memory
|
||||
|
||||
|
||||
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
||||
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET, cn_mode='offset', alpha=0.75):
|
||||
"""
|
||||
Vectorized offset min-sum CN update for a batch of z_pop messages.
|
||||
Vectorized min-sum CN update for a batch of z_pop messages.
|
||||
|
||||
Args:
|
||||
vn_msgs_list: list of ndarrays, each (z_pop,) - one per connected VN
|
||||
offset: min-sum offset
|
||||
offset: min-sum offset (used in offset mode)
|
||||
cn_mode: 'offset' or 'normalized'
|
||||
alpha: scaling factor for normalized mode
|
||||
|
||||
Returns:
|
||||
cn_out_list: list of ndarrays, each (z_pop,) - CN->VN messages
|
||||
@@ -145,7 +147,10 @@ def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
||||
cn_out_list = []
|
||||
for j in range(dc):
|
||||
mag = np.where(first_min[j], min2, min1)
|
||||
mag = np.maximum(0, mag - offset)
|
||||
if cn_mode == 'normalized':
|
||||
mag = np.floor(mag * alpha).astype(mag.dtype)
|
||||
else:
|
||||
mag = np.maximum(0, mag - offset)
|
||||
|
||||
# Extrinsic sign: total XOR minus this input's sign
|
||||
ext_sign = sign_xor ^ signs[j]
|
||||
@@ -158,7 +163,8 @@ def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
||||
return cn_out_list
|
||||
|
||||
|
||||
def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
||||
def density_evolution_step(beliefs, msg_memory, profile, z_pop,
|
||||
cn_mode='offset', alpha=0.75):
|
||||
"""
|
||||
One full DE iteration: process all rows (layers) of the base matrix.
|
||||
|
||||
@@ -170,6 +176,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
||||
msg_memory: dict (row, col) -> ndarray(z_pop) of old CN->VN messages
|
||||
profile: DE profile dict
|
||||
z_pop: population size
|
||||
cn_mode: 'offset' or 'normalized'
|
||||
alpha: scaling factor for normalized mode
|
||||
|
||||
Returns:
|
||||
beliefs: updated beliefs (modified in-place and returned)
|
||||
@@ -197,7 +205,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
||||
vn_to_cn.append(vn_msg)
|
||||
|
||||
# Step 2: CN update (vectorized min-sum)
|
||||
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET)
|
||||
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET,
|
||||
cn_mode=cn_mode, alpha=alpha)
|
||||
|
||||
# Step 3: Update beliefs and store new messages
|
||||
for ci, col in enumerate(connected_cols):
|
||||
@@ -218,7 +227,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
||||
return beliefs
|
||||
|
||||
|
||||
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
||||
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100,
|
||||
cn_mode='offset', alpha=0.75):
|
||||
"""
|
||||
Run full density evolution simulation.
|
||||
|
||||
@@ -228,6 +238,8 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
||||
lam_b: background photons per slot
|
||||
z_pop: population size
|
||||
max_iter: maximum iterations
|
||||
cn_mode: 'offset' or 'normalized'
|
||||
alpha: scaling factor for normalized mode
|
||||
|
||||
Returns:
|
||||
(converged: bool, avg_error_frac: float)
|
||||
@@ -236,7 +248,8 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
||||
beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)
|
||||
|
||||
for it in range(max_iter):
|
||||
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop)
|
||||
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop,
|
||||
cn_mode=cn_mode, alpha=alpha)
|
||||
|
||||
# Check convergence: for all-zeros codeword, correct belief is positive
|
||||
# Error = fraction of beliefs with wrong sign (negative)
|
||||
@@ -331,7 +344,8 @@ def build_de_profile(vn_degrees, m_base=7):
|
||||
}
|
||||
|
||||
|
||||
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
||||
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1,
|
||||
cn_mode='offset', alpha=0.75):
|
||||
"""
|
||||
Binary search for minimum lam_s where DE converges.
|
||||
|
||||
@@ -341,18 +355,21 @@ def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
||||
hi = 20.0
|
||||
|
||||
# Verify hi converges
|
||||
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
||||
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop,
|
||||
max_iter=100, cn_mode=cn_mode, alpha=alpha)
|
||||
if not converged:
|
||||
return hi # Doesn't converge even at hi
|
||||
|
||||
# Verify lo doesn't converge
|
||||
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
||||
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop,
|
||||
max_iter=100, cn_mode=cn_mode, alpha=alpha)
|
||||
if converged:
|
||||
return lo # Converges even at lo
|
||||
|
||||
while hi - lo > tol:
|
||||
mid = (lo + hi) / 2
|
||||
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
||||
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop,
|
||||
max_iter=100, cn_mode=cn_mode, alpha=alpha)
|
||||
if converged:
|
||||
hi = mid
|
||||
else:
|
||||
@@ -361,12 +378,14 @@ def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
||||
return hi
|
||||
|
||||
|
||||
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1):
|
||||
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1,
|
||||
cn_mode='offset', alpha=0.75):
|
||||
"""
|
||||
Convenience wrapper: compute DE threshold from a VN degree list.
|
||||
"""
|
||||
profile = build_de_profile(vn_degrees, m_base=m_base)
|
||||
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol)
|
||||
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol,
|
||||
cn_mode=cn_mode, alpha=alpha)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
|
||||
@@ -249,3 +249,35 @@ class TestNormalizedMinSum:
|
||||
assert result_default == result_explicit, (
|
||||
f"Default and explicit offset should match: {result_default} vs {result_explicit}"
|
||||
)
|
||||
|
||||
|
||||
class TestNormalizedMinSumDE:
|
||||
"""Tests for normalized min-sum in the density evolution engine."""
|
||||
|
||||
def test_de_normalized_converges(self):
|
||||
"""DE at lam_s=10 with normalized mode should converge."""
|
||||
from density_evolution import run_de, ORIGINAL_STAIRCASE_PROFILE
|
||||
np.random.seed(42)
|
||||
converged, error_frac = run_de(
|
||||
ORIGINAL_STAIRCASE_PROFILE, lam_s=10.0, lam_b=0.1,
|
||||
z_pop=10000, max_iter=50,
|
||||
cn_mode='normalized', alpha=0.75
|
||||
)
|
||||
assert converged, f"DE normalized should converge at lam_s=10, error_frac={error_frac}"
|
||||
|
||||
def test_de_normalized_threshold_different(self):
|
||||
"""Normalized and offset modes should produce different thresholds."""
|
||||
from density_evolution import compute_threshold, ORIGINAL_STAIRCASE_PROFILE
|
||||
np.random.seed(42)
|
||||
thresh_offset = compute_threshold(
|
||||
ORIGINAL_STAIRCASE_PROFILE, lam_b=0.1, z_pop=10000, tol=0.5,
|
||||
cn_mode='offset'
|
||||
)
|
||||
np.random.seed(42)
|
||||
thresh_norm = compute_threshold(
|
||||
ORIGINAL_STAIRCASE_PROFILE, lam_b=0.1, z_pop=10000, tol=0.5,
|
||||
cn_mode='normalized', alpha=0.75
|
||||
)
|
||||
assert thresh_norm != thresh_offset, (
|
||||
f"Thresholds should differ: offset={thresh_offset}, normalized={thresh_norm}"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user