feat: add threshold computation via binary search
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -253,6 +253,122 @@ def run_de(profile, lam_s, lam_b, z_pop=100000, max_iter=100):
|
|||||||
return False, error_frac
|
return False, error_frac
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Threshold Computation via Binary Search
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def build_de_profile(vn_degrees, m_base=7):
|
||||||
|
"""
|
||||||
|
Build a DE profile from a VN degree list.
|
||||||
|
|
||||||
|
Col 0 is always info (connects to all m_base rows).
|
||||||
|
Cols 1..n_base-1 are parity columns with staircase backbone
|
||||||
|
plus extra connections to reach target degrees.
|
||||||
|
|
||||||
|
Extra connections are placed below-diagonal to preserve
|
||||||
|
lower-triangular parity structure.
|
||||||
|
"""
|
||||||
|
n_base = len(vn_degrees)
|
||||||
|
assert n_base == m_base + 1, f"Expected {m_base+1} columns, got {n_base}"
|
||||||
|
assert vn_degrees[0] == m_base, f"Info column must have degree {m_base}"
|
||||||
|
|
||||||
|
# Start with staircase backbone connections
|
||||||
|
# connections[row] = set of connected columns
|
||||||
|
connections = [set() for _ in range(m_base)]
|
||||||
|
|
||||||
|
# Col 0 (info) connects to all rows
|
||||||
|
for r in range(m_base):
|
||||||
|
connections[r].add(0)
|
||||||
|
|
||||||
|
# Staircase backbone: row r connects to col r+1 (diagonal)
|
||||||
|
# row r>0 also connects to col r (sub-diagonal)
|
||||||
|
for r in range(m_base):
|
||||||
|
connections[r].add(r + 1) # diagonal
|
||||||
|
if r > 0:
|
||||||
|
connections[r].add(r) # sub-diagonal
|
||||||
|
|
||||||
|
# Check current degrees
|
||||||
|
current_degrees = [0] * n_base
|
||||||
|
for r in range(m_base):
|
||||||
|
for c in connections[r]:
|
||||||
|
current_degrees[c] += 1
|
||||||
|
|
||||||
|
# Add extra connections for parity columns that need higher degree
|
||||||
|
for c in range(1, n_base):
|
||||||
|
needed = vn_degrees[c] - current_degrees[c]
|
||||||
|
if needed <= 0:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Find rows we can add this column to (not already connected)
|
||||||
|
# Prefer below-diagonal rows to preserve lower-triangular structure
|
||||||
|
available_rows = []
|
||||||
|
for r in range(m_base):
|
||||||
|
if c not in connections[r]:
|
||||||
|
available_rows.append(r)
|
||||||
|
|
||||||
|
# Sort by distance below diagonal (prefer far below for structure)
|
||||||
|
# Diagonal for col c is row c-1
|
||||||
|
available_rows.sort(key=lambda r: -(r - (c - 1)) % m_base)
|
||||||
|
|
||||||
|
for r in available_rows[:needed]:
|
||||||
|
connections[r].add(c)
|
||||||
|
current_degrees[c] += 1
|
||||||
|
|
||||||
|
# Convert sets to sorted lists
|
||||||
|
connections_list = [sorted(conns) for conns in connections]
|
||||||
|
|
||||||
|
# Verify degrees
|
||||||
|
final_degrees = [0] * n_base
|
||||||
|
for r in range(m_base):
|
||||||
|
for c in connections_list[r]:
|
||||||
|
final_degrees[c] += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
'n_base': n_base,
|
||||||
|
'm_base': m_base,
|
||||||
|
'connections': connections_list,
|
||||||
|
'vn_degrees': final_degrees,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def compute_threshold(profile, lam_b=0.1, z_pop=50000, tol=0.1):
|
||||||
|
"""
|
||||||
|
Binary search for minimum lam_s where DE converges.
|
||||||
|
|
||||||
|
Returns threshold lam_s*.
|
||||||
|
"""
|
||||||
|
lo = 0.1
|
||||||
|
hi = 20.0
|
||||||
|
|
||||||
|
# Verify hi converges
|
||||||
|
converged, _ = run_de(profile, lam_s=hi, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
||||||
|
if not converged:
|
||||||
|
return hi # Doesn't converge even at hi
|
||||||
|
|
||||||
|
# Verify lo doesn't converge
|
||||||
|
converged, _ = run_de(profile, lam_s=lo, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
||||||
|
if converged:
|
||||||
|
return lo # Converges even at lo
|
||||||
|
|
||||||
|
while hi - lo > tol:
|
||||||
|
mid = (lo + hi) / 2
|
||||||
|
converged, _ = run_de(profile, lam_s=mid, lam_b=lam_b, z_pop=z_pop, max_iter=100)
|
||||||
|
if converged:
|
||||||
|
hi = mid
|
||||||
|
else:
|
||||||
|
lo = mid
|
||||||
|
|
||||||
|
return hi
|
||||||
|
|
||||||
|
|
||||||
|
def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, tol=0.1):
|
||||||
|
"""
|
||||||
|
Convenience wrapper: compute DE threshold from a VN degree list.
|
||||||
|
"""
|
||||||
|
profile = build_de_profile(vn_degrees, m_base=m_base)
|
||||||
|
return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol)
|
||||||
|
|
||||||
|
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
# CLI placeholder (will be extended in later tasks)
|
# CLI placeholder (will be extended in later tasks)
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|||||||
@@ -50,3 +50,45 @@ class TestDensityEvolution:
|
|||||||
# Run one step
|
# Run one step
|
||||||
beliefs = density_evolution_step(beliefs, msg_memory, ORIGINAL_STAIRCASE_PROFILE, z_pop)
|
beliefs = density_evolution_step(beliefs, msg_memory, ORIGINAL_STAIRCASE_PROFILE, z_pop)
|
||||||
assert beliefs.shape == (n_base, z_pop), f"Shape changed after step: {beliefs.shape}"
|
assert beliefs.shape == (n_base, z_pop), f"Shape changed after step: {beliefs.shape}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestThresholdComputation:
|
||||||
|
"""Tests for threshold binary search."""
|
||||||
|
|
||||||
|
def test_threshold_original_staircase(self):
|
||||||
|
"""Threshold for original staircase [7,2,2,2,2,2,2,1] should be ~3-6 photons."""
|
||||||
|
from density_evolution import compute_threshold_for_profile
|
||||||
|
np.random.seed(42)
|
||||||
|
threshold = compute_threshold_for_profile(
|
||||||
|
[7, 2, 2, 2, 2, 2, 2, 1], m_base=7, lam_b=0.1,
|
||||||
|
z_pop=10000, tol=0.5
|
||||||
|
)
|
||||||
|
assert 2.0 < threshold < 8.0, f"Expected threshold ~3-6, got {threshold}"
|
||||||
|
|
||||||
|
def test_threshold_peg_ring(self):
|
||||||
|
"""PEG ring [7,3,3,3,2,2,2,2] should have lower or equal threshold than original."""
|
||||||
|
from density_evolution import compute_threshold_for_profile
|
||||||
|
np.random.seed(42)
|
||||||
|
thresh_orig = compute_threshold_for_profile(
|
||||||
|
[7, 2, 2, 2, 2, 2, 2, 1], m_base=7, lam_b=0.1,
|
||||||
|
z_pop=15000, tol=0.25
|
||||||
|
)
|
||||||
|
np.random.seed(123)
|
||||||
|
thresh_peg = compute_threshold_for_profile(
|
||||||
|
[7, 3, 3, 3, 2, 2, 2, 2], m_base=7, lam_b=0.1,
|
||||||
|
z_pop=15000, tol=0.25
|
||||||
|
)
|
||||||
|
assert thresh_peg <= thresh_orig, (
|
||||||
|
f"PEG threshold {thresh_peg} should be <= original {thresh_orig}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_profile_to_hbase(self):
|
||||||
|
"""build_de_profile should produce valid profile with correct column degrees."""
|
||||||
|
from density_evolution import build_de_profile
|
||||||
|
profile = build_de_profile([7, 3, 2, 2, 2, 2, 2, 2], m_base=7)
|
||||||
|
assert profile['n_base'] == 8
|
||||||
|
assert profile['m_base'] == 7
|
||||||
|
assert profile['vn_degrees'] == [7, 3, 2, 2, 2, 2, 2, 2]
|
||||||
|
# Every row should have at least 2 connections
|
||||||
|
for r, conns in enumerate(profile['connections']):
|
||||||
|
assert len(conns) >= 2, f"Row {r} has only {len(conns)} connections"
|
||||||
|
|||||||
Reference in New Issue
Block a user