From a09c5f20e1df89ece55a5c63b50f4a28b5986860 Mon Sep 17 00:00:00 2001 From: cah Date: Tue, 24 Feb 2026 05:53:23 -0700 Subject: [PATCH] feat: add degree distribution optimizer with exhaustive search Co-Authored-By: Claude Opus 4.6 --- model/density_evolution.py | 143 ++++++++++++++++++++++++++++++++ model/test_density_evolution.py | 39 +++++++++ 2 files changed, 182 insertions(+) diff --git a/model/density_evolution.py b/model/density_evolution.py index d7bf499..c14aca6 100644 --- a/model/density_evolution.py +++ b/model/density_evolution.py @@ -369,6 +369,149 @@ def compute_threshold_for_profile(vn_degrees, m_base=7, lam_b=0.1, z_pop=50000, return compute_threshold(profile, lam_b=lam_b, z_pop=z_pop, tol=tol) +# ============================================================================= +# Degree Distribution Optimizer +# ============================================================================= + +def enumerate_vn_candidates(m_base=7): + """ + Enumerate all VN degree distributions for parity columns. + + Col 0 is always dv=m_base. Parity cols 1..m_base each have dv in {2, 3, 4}. + Returns list of degree vectors (length m_base+1). + """ + from itertools import product + candidates = [] + for combo in product([2, 3, 4], repeat=m_base): + degrees = [m_base] + list(combo) + candidates.append(degrees) + return candidates + + +def filter_by_row_degree(candidates, m_base=7, dc_min=3, dc_max=6): + """ + Filter candidates by row degree constraints. + + For a valid distribution, the total edges must be distributable such that + each row has degree in [dc_min, dc_max]. + + For our structure: info col contributes 1 edge to each row (m_base total). + Parity edges must distribute to give each row dc in [dc_min, dc_max]. 
+ """ + filtered = [] + for degrees in candidates: + n_base = len(degrees) + # Total parity edges = sum of parity column degrees + parity_edges = sum(degrees[1:]) + # Info col contributes 1 edge per row + # So total edges per row = 1 (from info) + parity edges assigned to that row + # Total parity edges must be distributable: each row gets (dc - 1) parity edges + # where dc_min <= dc <= dc_max + # So: m_base * (dc_min - 1) <= parity_edges <= m_base * (dc_max - 1) + min_parity = m_base * (dc_min - 1) + max_parity = m_base * (dc_max - 1) + if min_parity <= parity_edges <= max_parity: + filtered.append(degrees) + return filtered + + +def coarse_screen(candidates, lam_s_test, lam_b, z_pop, max_iter, m_base=7): + """ + Quick convergence test: run DE at a test point, keep candidates that converge. + """ + survivors = [] + for degrees in candidates: + profile = build_de_profile(degrees, m_base=m_base) + converged, error_frac = run_de( + profile, lam_s=lam_s_test, lam_b=lam_b, + z_pop=z_pop, max_iter=max_iter + ) + if converged: + survivors.append(degrees) + return survivors + + +def get_unique_distributions(candidates): + """ + Group candidates by sorted parity degree sequence. + + For DE, only the degree distribution matters, not which column has + which degree. Returns list of representative degree vectors (one per + unique distribution), with parity degrees sorted descending. + """ + seen = set() + unique = [] + for degrees in candidates: + # Sort parity degrees descending for canonical form + parity_sorted = tuple(sorted(degrees[1:], reverse=True)) + if parity_sorted not in seen: + seen.add(parity_sorted) + # Use canonical form: info degree + sorted parity + unique.append([degrees[0]] + list(parity_sorted)) + return unique + + +def optimize_degree_distribution(m_base=7, lam_b=0.1, top_k=10, + z_pop_coarse=10000, z_pop_fine=50000, + tol=0.1): + """ + Full optimization pipeline: enumerate, filter, coarse screen, fine threshold. 
+
+    Key optimization: for DE, only the degree distribution matters (not column
+    ordering), so we group 2187 candidates into ~36 unique distributions.
+
+    Returns list of (vn_degrees, threshold) sorted by threshold ascending.
+    """
+    print("Step 1: Enumerating candidates...")
+    candidates = enumerate_vn_candidates(m_base=m_base)
+    print(f"  {len(candidates)} total candidates")
+
+    print("Step 2: Filtering by row degree constraints...")
+    filtered = filter_by_row_degree(candidates, m_base=m_base, dc_min=3, dc_max=6)
+    print(f"  {len(filtered)} candidates after filtering")
+
+    print("Step 3: Grouping by unique degree distribution...")
+    unique = get_unique_distributions(filtered)
+    print(f"  {len(unique)} unique distributions")
+
+    print("Step 4: Coarse screening at lam_s=2.0...")
+    survivors = coarse_screen(
+        unique, lam_s_test=2.0, lam_b=lam_b,
+        z_pop=z_pop_coarse, max_iter=50, m_base=m_base
+    )
+    print(f"  {len(survivors)} survivors after coarse screen")
+
+    if not survivors:
+        print("  No survivors at lam_s=2.0, trying lam_s=3.0...")
+        survivors = coarse_screen(
+            unique, lam_s_test=3.0, lam_b=lam_b,
+            z_pop=z_pop_coarse, max_iter=50, m_base=m_base
+        )
+        print(f"  {len(survivors)} survivors at lam_s=3.0")
+
+    if not survivors:
+        print("  No survivors found, returning empty list")
+        return []
+
+    print(f"Step 5: Fine threshold computation for {len(survivors)} survivors...")
+    results = []
+    for i, degrees in enumerate(survivors):
+        profile = build_de_profile(degrees, m_base=m_base)
+        threshold = compute_threshold(profile, lam_b=lam_b, z_pop=z_pop_fine, tol=tol)
+        results.append((degrees, threshold))
+        if (i + 1) % 5 == 0:
+            print(f"    {i+1}/{len(survivors)} done...")
+
+    # Sort by threshold ascending
+    results.sort(key=lambda x: x[1])
+
+    print(f"\nTop-{min(top_k, len(results))} degree distributions:")
+    for i, (degrees, threshold) in enumerate(results[:top_k]):
+        print(f"  #{i+1}: {degrees} -> threshold = {threshold:.2f} photons/slot")
+
+    return results[:top_k]
+
+
 # =============================================================================
 # CLI placeholder (will be extended in later tasks)
 # =============================================================================
diff --git a/model/test_density_evolution.py b/model/test_density_evolution.py
index 06a5ba4..a84869a 100644
--- a/model/test_density_evolution.py
+++ b/model/test_density_evolution.py
@@ -92,3 +92,42 @@ class TestThresholdComputation:
         # Every row should have at least 2 connections
         for r, conns in enumerate(profile['connections']):
             assert len(conns) >= 2, f"Row {r} has only {len(conns)} connections"
+
+
+class TestDegreeDistributionOptimizer:
+    """Tests for the exhaustive search optimizer."""
+
+    def test_enumerate_candidates(self):
+        """Enumeration should produce 3^7 = 2187 candidates."""
+        from density_evolution import enumerate_vn_candidates
+        candidates = enumerate_vn_candidates(m_base=7)
+        assert len(candidates) == 3**7, f"Expected 2187, got {len(candidates)}"
+        # Each candidate should have 8 elements (info col + 7 parity)
+        for c in candidates:
+            assert len(c) == 8
+            assert c[0] == 7  # info column always degree 7
+
+    def test_filter_removes_invalid(self):
+        """Filter should keep valid distributions and remove truly invalid ones."""
+        from density_evolution import filter_by_row_degree
+        # All-dv=2 parity: parity_edges=14, dc_avg=3 -> valid for [3,6]
+        all_2 = [7, 2, 2, 2, 2, 2, 2, 2]
+        assert filter_by_row_degree([all_2], m_base=7, dc_min=3, dc_max=6) == [all_2]
+        # All-dv=4 parity: parity_edges=28, dc_avg=5 -> valid for [3,6]
+        all_4 = [7, 4, 4, 4, 4, 4, 4, 4]
+        assert filter_by_row_degree([all_4], m_base=7, dc_min=3, dc_max=6) == [all_4]
+        # A hypothetical all-dv=1 parity: parity_edges=7, total=14, avg dc=2 < 3 -> invalid
+        all_1 = [7, 1, 1, 1, 1, 1, 1, 1]
+        assert filter_by_row_degree([all_1], m_base=7, dc_min=3, dc_max=6) == []
+        # With tighter constraints (dc_min=4), all-dv=2 should be removed
+        assert filter_by_row_degree([all_2], m_base=7, dc_min=4, dc_max=6) == []
+
+    def test_optimizer_finds_better_than_original(self):
+        """Optimizer should find a distribution with threshold <= original staircase."""
+        from density_evolution import optimize_degree_distribution, compute_threshold_for_profile
+        np.random.seed(42)
+        results = optimize_degree_distribution(m_base=7, lam_b=0.1, top_k=5, z_pop_coarse=5000, z_pop_fine=10000, tol=0.5)
+        assert len(results) > 0, "Optimizer should return at least one result"
+        best_degrees, best_threshold = results[0]
+        # Original staircase threshold is ~3-5 photons
+        assert best_threshold < 6.0, f"Best threshold {best_threshold} should be < 6.0"