feat: add normalized min-sum to density evolution engine

Thread cn_mode and alpha parameters through the entire DE pipeline:
de_cn_update_vectorized(), density_evolution_step(), run_de(),
compute_threshold(), and compute_threshold_for_profile().

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
cah
2026-02-24 16:37:57 -07:00
parent b04813fa7c
commit e657e9baf1
2 changed files with 65 additions and 14 deletions

View File

@@ -102,13 +102,15 @@ def de_channel_init(profile, z_pop, lam_s, lam_b):
return beliefs, msg_memory
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET, cn_mode='offset', alpha=0.75):
"""
Vectorized offset min-sum CN update for a batch of z_pop messages.
Vectorized min-sum CN update for a batch of z_pop messages.
Args:
vn_msgs_list: list of ndarrays, each (z_pop,) - one per connected VN
offset: min-sum offset
offset: min-sum offset (used in offset mode)
cn_mode: 'offset' or 'normalized'
alpha: scaling factor for normalized mode
Returns:
cn_out_list: list of ndarrays, each (z_pop,) - CN->VN messages
@@ -145,7 +147,10 @@ def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
cn_out_list = []
for j in range(dc):
mag = np.where(first_min[j], min2, min1)
mag = np.maximum(0, mag - offset)
if cn_mode == 'normalized':
mag = np.floor(mag * alpha).astype(mag.dtype)
else:
mag = np.maximum(0, mag - offset)
# Extrinsic sign: total XOR minus this input's sign
ext_sign = sign_xor ^ signs[j]
@@ -158,7 +163,8 @@ def de_cn_update_vectorized(vn_msgs_list, offset=OFFSET):
return cn_out_list
def density_evolution_step(beliefs, msg_memory, profile, z_pop):
def density_evolution_step(beliefs, msg_memory, profile, z_pop,
cn_mode='offset', alpha=0.75):
"""
One full DE iteration: process all rows (layers) of the base matrix.
@@ -170,6 +176,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
msg_memory: dict (row, col) -> ndarray(z_pop) of old CN->VN messages
profile: DE profile dict
z_pop: population size
cn_mode: 'offset' or 'normalized'
alpha: scaling factor for normalized mode
Returns:
beliefs: updated beliefs (modified in-place and returned)
@@ -197,7 +205,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
vn_to_cn.append(vn_msg)
# Step 2: CN update (vectorized min-sum)
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET)
cn_to_vn = de_cn_update_vectorized(vn_to_cn, offset=OFFSET,
cn_mode=cn_mode, alpha=alpha)
# Step 3: Update beliefs and store new messages
for ci, col in enumerate(connected_cols):
@@ -218,7 +227,8 @@ def density_evolution_step(beliefs, msg_memory, profile, z_pop):
return beliefs
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100,
cn_mode='offset', alpha=0.75):
"""
Run full density evolution simulation.
@@ -228,6 +238,8 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
lam_b: background photons per slot
z_pop: population size
max_iter: maximum iterations
cn_mode: 'offset' or 'normalized'
alpha: scaling factor for normalized mode
Returns:
(converged: bool, avg_error_frac: float)
@@ -236,7 +248,8 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
beliefs, msg_memory = de_channel_init(profile, z_pop, lam_s, lam_b)
for it in range(max_iter):
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop)
beliefs = density_evolution_step(beliefs, msg_memory, profile, z_pop,
cn_mode=cn_mode, alpha=alpha)
# Check convergence: for all-zeros codeword, correct belief is positive
# Error = fraction of beliefs with wrong sign (negative)
@@ -331,7 +344,8 @@ def build_de_profile(vn_degrees, m_base=7):
}
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1,
cn_mode='offset', alpha=0.75):
"""
Binary search for minimum lam_s where DE converges.
@@ -341,18 +355,21 @@ def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
hi = 20.0
# Verify hi converges
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop, max_iter=100)
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop,
max_iter=100, cn_mode=cn_mode, alpha=alpha)
if not converged:
return hi # Doesn't converge even at hi
# Verify lo doesn't converge
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop, max_iter=100)
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop,
max_iter=100, cn_mode=cn_mode, alpha=alpha)
if converged:
return lo # Converges even at lo
while hi - lo > tol:
mid = (lo + hi) / 2
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop, max_iter=100)
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop,
max_iter=100, cn_mode=cn_mode, alpha=alpha)
if converged:
hi = mid
else:
@@ -361,12 +378,14 @@ def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
return hi
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1):
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1,
cn_mode='offset', alpha=0.75):
"""
Convenience wrapper: compute DE threshold from a VN degree list.
"""
profile = build_de_profile(vn_degrees, m_base=m_base)
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol)
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol,
cn_mode=cn_mode, alpha=alpha)
# =============================================================================