Files
ldpc_optical/model/test_sc_ldpc.py
cah 41e2ef72ec feat: add SC-LDPC density evolution with threshold computation
Implement position-aware density evolution for SC-LDPC codes:
- sc_density_evolution(): flooding-schedule DE tracking per-position
  error rates, demonstrating the wave decoding effect
- compute_sc_threshold(): binary search for SC-LDPC threshold

Uses flooding schedule (not layered) to avoid belief divergence
from cross-position message interference in the coupled chain.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-24 17:45:21 -07:00

203 lines
9.0 KiB
Python

#!/usr/bin/env python3
"""Tests for the SC-LDPC model: chain construction, windowed decoding, and density evolution."""
import os
import sys

import numpy as np
import pytest

# Make the modules that live next to this file (sc_ldpc, ldpc_sim) importable
# by the function-local imports in the tests below.
sys.path.insert(0, os.path.dirname(__file__))


class TestSCLDPCConstruction:
    """Structural checks on the protograph split and the coupled parity-check chain."""

    def test_split_preserves_edges(self):
        """Split B into w components; each base edge must land in exactly one of them.

        Entries that are absent in the base protograph (negative values) must
        remain absent in every component.
        """
        from sc_ldpc import split_protograph
        from ldpc_sim import H_BASE

        parts = split_protograph(H_BASE, w=2, seed=42)
        assert len(parts) == 2

        for r, c in np.ndindex(*H_BASE.shape):
            if H_BASE[r, c] >= 0:
                owners = [part for part in parts if part[r, c] >= 0]
                count = len(owners)
                assert count == 1, f"Edge ({r},{c}) in {count} components, expected 1"
                continue
            for part in parts:
                assert part[r, c] < 0, f"Edge ({r},{c}) should be absent"

    def test_chain_dimensions(self):
        """Build a chain with L=5, w=2, Z=32 and check the lifted H shape."""
        from sc_ldpc import build_sc_chain
        from ldpc_sim import H_BASE

        L, w, z = 5, 2, 32
        m_base, n_base = H_BASE.shape
        H_full, _components, _meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)

        # L check-node positions by L + w - 1 variable-node positions, each
        # lifted by z (e.g. 1120 x 1536 if H_BASE is 7 x 8).
        want_rows = L * m_base * z
        want_cols = (L + w - 1) * n_base * z
        assert H_full.shape == (want_rows, want_cols), (
            f"Expected ({want_rows}, {want_cols}), got {H_full.shape}"
        )

    def test_chain_has_coupled_structure(self):
        """CN position t may only touch VN positions t .. t + w - 1 (t and t+1 for w=2)."""
        from sc_ldpc import build_sc_chain
        from ldpc_sim import H_BASE

        L, w, z = 5, 2, 32
        m_base, n_base = H_BASE.shape
        H_full, _components, _meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)

        rows_per_pos = m_base * z
        cols_per_pos = n_base * z
        for t in range(L):
            r0 = t * rows_per_pos
            for v in range(L + w - 1):
                c0 = v * cols_per_pos
                sub = H_full[r0:r0 + rows_per_pos, c0:c0 + cols_per_pos]
                has_connections = np.any(sub != 0)
                if t <= v < t + w:
                    # Inside the coupling window: the block must be populated.
                    assert has_connections, f"CN pos {t} should connect to VN pos {v}"
                else:
                    # Outside the coupling window: the block must be all zeros.
                    assert not has_connections, f"CN pos {t} should NOT connect to VN pos {v}"


class TestWindowedDecode:
    """Tests for the windowed SC-LDPC decoder.

    Every test transmits the all-zeros word, which is a codeword of any linear
    code, so each non-zero decoded bit counts as a bit error.
    """

    @staticmethod
    def _all_zeros_llrs(L, w, z, lam_s, lam_b=0.1):
        """Build an SC chain and return (H_full, m_base, n_base, llr_q) for the all-zeros word.

        Seeds NumPy's global RNG with 42 before building the chain and passing
        the all-zeros word through ``poisson_channel``, so each test sees a
        reproducible channel realization.
        """
        from sc_ldpc import build_sc_chain
        from ldpc_sim import H_BASE, poisson_channel, quantize_llr

        np.random.seed(42)
        m_base, n_base = H_BASE.shape
        H_full, _components, _meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)
        codeword = np.zeros(H_full.shape[1], dtype=np.int8)
        llr_float, _ = poisson_channel(codeword, lam_s=lam_s, lam_b=lam_b)
        return H_full, m_base, n_base, quantize_llr(llr_float)

    def test_windowed_decode_trivial(self):
        """Decode the all-zeros word at lam_s=10 (L=5, W=5); expect < 5% bit errors."""
        from sc_ldpc import windowed_decode

        L, w, z = 5, 2, 32
        H_full, m_base, n_base, llr_q = self._all_zeros_llrs(L, w, z, lam_s=10.0)
        n_total = H_full.shape[1]
        decoded, _converged, _iters = windowed_decode(
            llr_q, H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
            W=5, max_iter=20, cn_mode='normalized', alpha=0.75
        )
        assert len(decoded) == n_total
        # High SNR: nearly every bit should come back as 0.
        error_rate = np.mean(decoded != 0)
        assert error_rate < 0.05, f"Error rate {error_rate} too high at lam_s=10"

    def test_windowed_decode_with_noise(self):
        """Decode the all-zeros word over a noisier channel (lam_s=5); expect < 15% errors.

        Uses a short chain (L=3) with a window of W=3 positions.
        """
        # TODO(review): exercise sc_encode with random information bits once
        # its interface is pinned down; this test only uses the all-zeros word.
        from sc_ldpc import windowed_decode

        L, w, z = 3, 2, 32
        H_full, m_base, n_base, llr_q = self._all_zeros_llrs(L, w, z, lam_s=5.0)
        decoded, _converged, _iters = windowed_decode(
            llr_q, H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
            W=3, max_iter=20, cn_mode='normalized', alpha=0.75
        )
        assert len(decoded) == H_full.shape[1]
        error_rate = np.mean(decoded != 0)
        assert error_rate < 0.15, f"Error rate {error_rate} too high at lam_s=5"

    def test_window_size_effect(self):
        """A larger window should decode at least as well as a smaller one."""
        from sc_ldpc import windowed_decode

        L, w, z = 5, 2, 32
        H_full, m_base, n_base, llr_q = self._all_zeros_llrs(L, w, z, lam_s=3.0)

        def bit_error_rate(window):
            # Pass a copy in case the decoder modifies its input LLRs, so both
            # runs start from the same channel observation.
            decoded, _, _ = windowed_decode(
                llr_q.copy(), H_full, L=L, w=w, z=z, n_base=n_base, m_base=m_base,
                W=window, max_iter=15, cn_mode='normalized', alpha=0.75
            )
            return np.mean(decoded != 0)

        err_small = bit_error_rate(2)
        err_large = bit_error_rate(5)
        # Allow 5 percentage points of slack for finite-length randomness.
        assert err_large <= err_small + 0.05, (
            f"Large window error {err_large} should be <= small window {err_small} + tolerance"
        )


class TestSCDensityEvolution:
    """Tests for SC-LDPC density evolution."""

    def test_sc_de_converges_at_high_snr(self):
        """SC-DE should converge at lam_s=10 (well above threshold)."""
        from sc_ldpc import sc_density_evolution
        from ldpc_sim import H_BASE
        np.random.seed(42)
        converged, pos_errors, _iters = sc_density_evolution(
            H_BASE, L=10, w=2, lam_s=10.0, lam_b=0.1,
            z_pop=5000, max_iter=50, cn_mode='normalized', alpha=0.75
        )
        assert converged, f"SC-DE should converge at lam_s=10, max error={max(pos_errors):.4f}"

    def test_sc_threshold_lower_than_uncoupled(self):
        """Sanity-check the SC-LDPC threshold against a loose upper bound.

        Reference uncoupled thresholds are ~3.05 (offset) and ~2.9
        (normalized). Threshold saturation should push the coupled threshold
        below these. However, the search runs with ``tol=0.5``, which is
        presumably the bisection resolution, so the returned value may sit up
        to one tolerance step above the true threshold. The test therefore
        asserts only the loose bound ``< 3.5`` rather than a strict comparison
        with the uncoupled value.
        """
        from sc_ldpc import compute_sc_threshold
        from ldpc_sim import H_BASE
        np.random.seed(42)
        sc_thresh = compute_sc_threshold(
            H_BASE, L=20, w=2, lam_b=0.1,
            z_pop=5000, tol=0.5, cn_mode='normalized', alpha=0.75
        )
        upper_bound = 3.5
        # The threshold is a lam_s value found by search, so it must be a
        # positive, finite number.
        assert np.isfinite(sc_thresh) and sc_thresh > 0, (
            f"SC threshold {sc_thresh} should be a positive finite value"
        )
        assert sc_thresh < upper_bound, f"SC threshold {sc_thresh} should be < {upper_bound}"

    def test_sc_de_wave_effect(self):
        """SC-LDPC should show a position-dependent error profile.

        In SC-LDPC with flooding DE, boundary positions have fewer CN
        connections (lower effective VN degree), so they typically show
        different error rates than interior positions. The error profile
        should NOT be uniform: spatial coupling creates a non-trivial
        position-dependent structure.
        """
        from sc_ldpc import sc_density_evolution
        from ldpc_sim import H_BASE
        np.random.seed(42)
        # Moderate SNR: errors are visible but not saturated.
        _converged, pos_errors, _iters = sc_density_evolution(
            H_BASE, L=20, w=2, lam_s=2.5, lam_b=0.1,
            z_pop=5000, max_iter=100, cn_mode='normalized', alpha=0.75
        )
        # The error profile must vary across positions.
        err_std = np.std(pos_errors)
        assert err_std > 0.001, (
            f"Error profile std {err_std:.6f} too uniform; "
            f"SC-LDPC should show position-dependent errors"
        )
        # The middle third (well-connected positions) should average below
        # the worst position in the chain.
        n_pos = len(pos_errors)
        interior_err = np.mean(pos_errors[n_pos // 3 : 2 * n_pos // 3])
        max_err = max(pos_errors)
        assert interior_err < max_err, (
            f"Interior error {interior_err:.4f} should be less than max {max_err:.4f}"
        )