feat: add normalized min-sum to density evolution engine
Thread cn_mode and alpha parameters through the entire DE pipeline: de_cn_update_vectorized(), density_evolution_step(), run_de(), compute_threshold(), and compute_threshold_for_profile(). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -102,13 +102,15 @@ def de_channel_init(profile, z_pop, lam_s, lam_b):
|
|||||||
return beliefs, msg_memory
|
return beliefs, msg_memory
|
||||||
|
|
||||||
|
|
||||||
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET, cn_mode='offset', alpha=0.75):
|
||||||
"""
|
"""
|
||||||
Vectorized offset min-sum CN update for a batch of z_pop messages.
|
Vectorized min-sum CN update for a batch of z_pop messages.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vn_msgs_list: list of ndarrays, each (z_pop,) - one per connected VN
|
vn_msgs_list: list of ndarrays, each (z_pop,) - one per connected VN
|
||||||
offset: min-sum offset
|
offset: min-sum offset (used in offset mode)
|
||||||
|
cn_mode: 'offset' or 'normalized'
|
||||||
|
alpha: scaling factor for normalized mode
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
cn_out_list: list of ndarrays, each (z_pop,) - CN->VN messages
|
cn_out_list: list of ndarrays, each (z_pop,) - CN->VN messages
|
||||||
@@ -145,6 +147,9 @@ def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
|||||||
cn_out_list = []
|
cn_out_list = []
|
||||||
for j in range(dc):
|
for j in range(dc):
|
||||||
mag = np.where(first_min[j], min2, min1)
|
mag = np.where(first_min[j], min2, min1)
|
||||||
|
if cn_mode == 'normalized':
|
||||||
|
mag = np.floor(mag * alpha).astype(mag.dtype)
|
||||||
|
else:
|
||||||
mag = np.maximum(0, mag - offset)
|
mag = np.maximum(0, mag - offset)
|
||||||
|
|
||||||
# Extrinsic sign: total XOR minus this input's sign
|
# Extrinsic sign: total XOR minus this input's sign
|
||||||
@@ -158,7 +163,8 @@ def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
|
|||||||
return cn_out_list
|
return cn_out_list
|
||||||
|
|
||||||
|
|
||||||
def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
def density_evolution_step(beliefs, msg_memory, profile, z_pop,
|
||||||
|
cn_mode='offset', alpha=0.75):
|
||||||
"""
|
"""
|
||||||
One full DE iteration: process all rows (layers) of the base matrix.
|
One full DE iteration: process all rows (layers) of the base matrix.
|
||||||
|
|
||||||
@@ -170,6 +176,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
|||||||
msg_memory: dict (row, col) -> ndarray(z_pop) of old CN->VN messages
|
msg_memory: dict (row, col) -> ndarray(z_pop) of old CN->VN messages
|
||||||
profile: DE profile dict
|
profile: DE profile dict
|
||||||
z_pop: population size
|
z_pop: population size
|
||||||
|
cn_mode: 'offset' or 'normalized'
|
||||||
|
alpha: scaling factor for normalized mode
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
beliefs: updated beliefs (modified in-place and returned)
|
beliefs: updated beliefs (modified in-place and returned)
|
||||||
@@ -197,7 +205,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
|||||||
vn_to_cn.append(vn_msg)
|
vn_to_cn.append(vn_msg)
|
||||||
|
|
||||||
# Step 2: CN update (vectorized min-sum)
|
# Step 2: CN update (vectorized min-sum)
|
||||||
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET)
|
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET,
|
||||||
|
cn_mode=cn_mode, alpha=alpha)
|
||||||
|
|
||||||
# Step 3: Update beliefs and store new messages
|
# Step 3: Update beliefs and store new messages
|
||||||
for ci, col in enumerate(connected_cols):
|
for ci, col in enumerate(connected_cols):
|
||||||
@@ -218,7 +227,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
|
|||||||
return beliefs
|
return beliefs
|
||||||
|
|
||||||
|
|
||||||
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100,
|
||||||
|
cn_mode='offset', alpha=0.75):
|
||||||
"""
|
"""
|
||||||
Run full density evolution simulation.
|
Run full density evolution simulation.
|
||||||
|
|
||||||
@@ -228,6 +238,8 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
|||||||
lam_b: background photons per slot
|
lam_b: background photons per slot
|
||||||
z_pop: population size
|
z_pop: population size
|
||||||
max_iter: maximum iterations
|
max_iter: maximum iterations
|
||||||
|
cn_mode: 'offset' or 'normalized'
|
||||||
|
alpha: scaling factor for normalized mode
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(converged: bool, avg_error_frac: float)
|
(converged: bool, avg_error_frac: float)
|
||||||
@@ -236,7 +248,8 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
|||||||
beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)
|
beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)
|
||||||
|
|
||||||
for it in range(max_iter):
|
for it in range(max_iter):
|
||||||
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop)
|
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop,
|
||||||
|
cn_mode=cn_mode, alpha=alpha)
|
||||||
|
|
||||||
# Check convergence: for all-zeros codeword, correct belief is positive
|
# Check convergence: for all-zeros codeword, correct belief is positive
|
||||||
# Error = fraction of beliefs with wrong sign (negative)
|
# Error = fraction of beliefs with wrong sign (negative)
|
||||||
@@ -331,7 +344,8 @@ def build_de_profile(vn_degrees, m_base=7):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1,
|
||||||
|
cn_mode='offset', alpha=0.75):
|
||||||
"""
|
"""
|
||||||
Binary search for minimum lam_s where DE converges.
|
Binary search for minimum lam_s where DE converges.
|
||||||
|
|
||||||
@@ -341,18 +355,21 @@ def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
|||||||
hi = 20.0
|
hi = 20.0
|
||||||
|
|
||||||
# Verify hi converges
|
# Verify hi converges
|
||||||
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop,
|
||||||
|
max_iter=100, cn_mode=cn_mode, alpha=alpha)
|
||||||
if not converged:
|
if not converged:
|
||||||
return hi # Doesn't converge even at hi
|
return hi # Doesn't converge even at hi
|
||||||
|
|
||||||
# Verify lo doesn't converge
|
# Verify lo doesn't converge
|
||||||
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop,
|
||||||
|
max_iter=100, cn_mode=cn_mode, alpha=alpha)
|
||||||
if converged:
|
if converged:
|
||||||
return lo # Converges even at lo
|
return lo # Converges even at lo
|
||||||
|
|
||||||
while hi - lo > tol:
|
while hi - lo > tol:
|
||||||
mid = (lo + hi) / 2
|
mid = (lo + hi) / 2
|
||||||
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop,
|
||||||
|
max_iter=100, cn_mode=cn_mode, alpha=alpha)
|
||||||
if converged:
|
if converged:
|
||||||
hi = mid
|
hi = mid
|
||||||
else:
|
else:
|
||||||
@@ -361,12 +378,14 @@ def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
|||||||
return hi
|
return hi
|
||||||
|
|
||||||
|
|
||||||
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1):
|
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1,
|
||||||
|
cn_mode='offset', alpha=0.75):
|
||||||
"""
|
"""
|
||||||
Convenience wrapper: compute DE threshold from a VN degree list.
|
Convenience wrapper: compute DE threshold from a VN degree list.
|
||||||
"""
|
"""
|
||||||
profile = build_de_profile(vn_degrees, m_base=m_base)
|
profile = build_de_profile(vn_degrees, m_base=m_base)
|
||||||
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol)
|
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol,
|
||||||
|
cn_mode=cn_mode, alpha=alpha)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -249,3 +249,35 @@ class TestNormalizedMinSum:
|
|||||||
assert result_default == result_explicit, (
|
assert result_default == result_explicit, (
|
||||||
f"Default and explicit offset should match: {result_default} vs {result_explicit}"
|
f"Default and explicit offset should match: {result_default} vs {result_explicit}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizedMinSumDE:
|
||||||
|
"""Tests for normalized min-sum in the density evolution engine."""
|
||||||
|
|
||||||
|
def test_de_normalized_converges(self):
|
||||||
|
"""DE at lam_s=10 with normalized mode should converge."""
|
||||||
|
from density_evolution import run_de, ORIGINAL_STAIRCASE_PROFILE
|
||||||
|
np.random.seed(42)
|
||||||
|
converged, error_frac = run_de(
|
||||||
|
ORIGINAL_STAIRCASE_PROFILE, lam_s=10.0, lam_b=0.1,
|
||||||
|
z_pop=10000, max_iter=50,
|
||||||
|
cn_mode='normalized', alpha=0.75
|
||||||
|
)
|
||||||
|
assert converged, f"DE normalized should converge at lam_s=10, error_frac={error_frac}"
|
||||||
|
|
||||||
|
def test_de_normalized_threshold_different(self):
|
||||||
|
"""Normalized and offset modes should produce different thresholds."""
|
||||||
|
from density_evolution import compute_threshold, ORIGINAL_STAIRCASE_PROFILE
|
||||||
|
np.random.seed(42)
|
||||||
|
thresh_offset = compute_threshold(
|
||||||
|
ORIGINAL_STAIRCASE_PROFILE, lam_b=0.1, z_pop=10000, tol=0.5,
|
||||||
|
cn_mode='offset'
|
||||||
|
)
|
||||||
|
np.random.seed(42)
|
||||||
|
thresh_norm = compute_threshold(
|
||||||
|
ORIGINAL_STAIRCASE_PROFILE, lam_b=0.1, z_pop=10000, tol=0.5,
|
||||||
|
cn_mode='normalized', alpha=0.75
|
||||||
|
)
|
||||||
|
assert thresh_norm != thresh_offset, (
|
||||||
|
f"Thresholds should differ: offset={thresh_offset}, normalized={thresh_norm}"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user