Files
ldpc_optical/model/test_ldpc.py
2026-02-24 04:32:16 -07:00

334 lines
13 KiB
Python

#!/usr/bin/env python3
"""
Validation tests for the LDPC behavioral model (ldpc_sim.py).
Tests cover code parameters, H-matrix structure, encoding, saturating arithmetic,
min-sum CN update, decoding, and syndrome computation.
Run:
python3 -m pytest model/test_ldpc.py -v
"""
import numpy as np
import pytest
from ldpc_sim import (
N, K, M, Z, N_BASE, M_BASE, H_BASE,
Q_BITS, Q_MAX, Q_MIN, OFFSET,
build_full_h_matrix,
ldpc_encode,
poisson_channel,
quantize_llr,
decode_layered_min_sum,
compute_syndrome_weight,
cyclic_shift,
sat_add_q,
sat_sub_q,
min_sum_cn_update,
)
# =============================================================================
# Fixtures
# =============================================================================
@pytest.fixture(scope="module")
def H():
    """Provide the expanded parity-check matrix to every test in this module.

    Because the fixture is module-scoped, build_full_h_matrix() is invoked a
    single time and all tests receive that same array object, so tests should
    treat it as read-only.
    """
    full_h = build_full_h_matrix()
    return full_h
# =============================================================================
# TestCodeParameters
# =============================================================================
class TestCodeParameters:
    """Pin the exported code constants to the values the design calls for.

    Covers the codeword length N, information length K, number of parity
    checks M, lifting factor Z, the base-matrix dimensions, and the
    relationships that tie them together.
    """

    def test_codeword_length(self):
        expected = 256
        assert N == expected

    def test_info_length(self):
        expected = 32
        assert K == expected

    def test_parity_checks(self):
        expected = 224
        assert M == expected

    def test_lifting_factor(self):
        expected = 32
        assert Z == expected

    def test_base_matrix_columns(self):
        expected = 8
        assert N_BASE == expected

    def test_base_matrix_rows(self):
        expected = 7
        assert M_BASE == expected

    def test_h_base_shape(self):
        base_shape = H_BASE.shape
        assert base_shape == (M_BASE, N_BASE)
        assert base_shape == (7, 8)

    def test_derived_relationships(self):
        # Expanding every base entry into a Z x Z block scales both axes by Z;
        # a single base column (Z bits) holds the information bits.
        assert N == Z * N_BASE
        assert M == Z * M_BASE
        assert K == Z
# =============================================================================
# TestHMatrix
# =============================================================================
class TestHMatrix:
    """Validate H-matrix structural properties."""

    @staticmethod
    def _gf2_rank(matrix):
        """Return the rank of an integer matrix computed over GF(2).

        Performs Gauss-Jordan elimination with XOR row operations on a private
        uint8 copy (entries reduced mod 2), so the caller's array is never
        modified.

        Args:
            matrix: 2-D array-like of integers.

        Returns:
            int: number of linearly independent rows over GF(2).
        """
        rows = (np.asarray(matrix) % 2).astype(np.uint8)
        n_rows, n_cols = rows.shape
        rank = 0
        for col in range(n_cols):
            if rank == n_rows:
                break
            candidates = np.flatnonzero(rows[rank:, col])
            if candidates.size == 0:
                continue  # no pivot available in this column
            pivot = rank + int(candidates[0])
            if pivot != rank:
                rows[[rank, pivot]] = rows[[pivot, rank]]
            # Clear this column from every other row; XOR is addition mod 2.
            hits = np.flatnonzero(rows[:, col])
            hits = hits[hits != rank]
            rows[hits] ^= rows[rank]
            rank += 1
        return rank

    def test_h_base_col0_connects_all_rows(self):
        """Column 0 (information bits) should connect to all 7 check rows."""
        col0 = H_BASE[:, 0]
        for r in range(M_BASE):
            assert col0[r] >= 0, f"Row {r} of column 0 is not connected (value={col0[r]})"

    def test_h_base_staircase_parity(self):
        """Parity columns 1-7 have a lower-triangular staircase structure.

        For the staircase, row r connects to column r+1 (the diagonal) and
        possibly columns 1..r below it. Entries to the right of the diagonal
        must be -1 (no connection). Both the diagonal connection and the
        empty region above it are asserted.
        """
        for r in range(M_BASE):
            assert H_BASE[r, r + 1] >= 0, (
                f"H_BASE[{r},{r + 1}]={H_BASE[r, r + 1]}, expected a connection (diagonal)"
            )
            # Columns strictly right of the diagonal (c > r + 1) must be empty.
            for c in range(r + 2, N_BASE):
                assert H_BASE[r, c] == -1, (
                    f"H_BASE[{r},{c}]={H_BASE[r, c]}, expected -1 (above staircase)"
                )

    def test_full_h_dimensions(self, H):
        """Full H matrix should be (M, N) = (224, 256)."""
        assert H.shape == (M, N)
        assert H.shape == (224, 256)

    def test_full_h_binary(self, H):
        """Full H matrix should contain only 0s and 1s."""
        unique_vals = np.unique(H)
        assert set(unique_vals).issubset({0, 1}), f"Non-binary values found: {unique_vals}"

    def test_qc_block_row_weight(self, H):
        """Each Z-block row within a base row should have the same weight (QC property).

        For a QC-LDPC code, each Z x Z sub-block is either all-zero or a
        circulant permutation matrix. So within each base row, every
        expanded row has the same Hamming weight.
        """
        for r in range(M_BASE):
            row_start = r * Z
            weights = [int(H[row_start + z, :].sum()) for z in range(Z)]
            assert len(set(weights)) == 1, (
                f"Base row {r}: inconsistent row weights {set(weights)}"
            )

    def test_full_h_rank(self, H):
        """H should have full row rank over GF(2) (rank == M == 224).

        Parity checks are linear equations mod 2, so rank must be measured
        over GF(2). np.linalg.matrix_rank works over the reals, where rank can
        be larger: [[1,1,0],[0,1,1],[1,0,1]] has real rank 3 but GF(2) rank 2.
        A real-rank check could therefore accept a rank-deficient H.
        """
        rank = self._gf2_rank(H)
        assert rank == M, f"H rank over GF(2) = {rank}, expected {M}"
# =============================================================================
# TestEncoder
# =============================================================================
class TestEncoder:
    """Validate the LDPC encoder.

    Random info bits come from a per-test ``np.random.Generator`` rather than
    the global NumPy RNG. That way these tests neither depend on nor disturb
    the global random state seen by other tests. Encoding must succeed for any
    info bits, so the particular draw does not matter.
    """

    @staticmethod
    def _random_info(rng):
        """Draw K random info bits (0/1) as an int8 array from ``rng``."""
        return rng.integers(0, 2, K, dtype=np.int8)

    @staticmethod
    def _syndrome(H, codeword):
        """Return (H @ codeword) mod 2.

        Both operands are widened to int64 first, so the accumulation cannot
        overflow a narrow input dtype.
        """
        return (H.astype(np.int64) @ np.asarray(codeword, dtype=np.int64)) % 2

    def test_all_zero_info(self, H):
        """All-zero info bits should encode to an all-zero codeword of length N."""
        info = np.zeros(K, dtype=np.int8)
        codeword = ldpc_encode(info, H)
        assert codeword.shape == (N,), f"Codeword shape {codeword.shape}, expected {(N,)}"
        assert np.all(codeword == 0), "All-zero info did not produce all-zero codeword"

    def test_random_info_valid_codeword(self, H):
        """Random info bits should encode to a valid codeword (syndrome = 0)."""
        rng = np.random.default_rng(12345)
        info = self._random_info(rng)
        codeword = ldpc_encode(info, H)
        assert codeword.shape == (N,), f"Codeword shape {codeword.shape}, expected {(N,)}"
        syndrome = self._syndrome(H, codeword)
        assert np.all(syndrome == 0), f"Syndrome weight = {syndrome.sum()}"

    def test_encoder_preserves_info_bits(self, H):
        """First K positions of codeword should be the original info bits."""
        rng = np.random.default_rng(99)
        info = self._random_info(rng)
        codeword = ldpc_encode(info, H)
        assert np.array_equal(codeword[:K], info), "Info bits not preserved in codeword"

    def test_twenty_random_messages(self, H):
        """20 random messages should all produce valid codewords."""
        rng = np.random.default_rng(2024)
        for i in range(20):
            info = self._random_info(rng)
            codeword = ldpc_encode(info, H)
            syndrome = self._syndrome(H, codeword)
            assert np.all(syndrome == 0), f"Message {i}: syndrome weight = {syndrome.sum()}"
# =============================================================================
# TestSaturatingArithmetic
# =============================================================================
class TestSaturatingArithmetic:
    """Validate saturating add/sub in Q-bit signed arithmetic.

    Overflow operands are derived from Q_MAX / Q_MIN instead of fixed
    literals. The tests then keep exercising saturation for any Q_BITS; a
    literal such as 20 + 20 only overflows while Q_MAX < 40.
    """

    def test_normal_add(self):
        assert sat_add_q(5, 3) == 8
        assert sat_add_q(-5, 3) == -2
        assert sat_add_q(0, 0) == 0

    def test_normal_sub(self):
        assert sat_sub_q(5, 3) == 2
        assert sat_sub_q(3, 5) == -2
        assert sat_sub_q(0, 0) == 0

    def test_exact_boundaries_do_not_saturate(self):
        """Results landing exactly on Q_MAX / Q_MIN are representable as-is."""
        assert sat_add_q(Q_MAX - 1, 1) == Q_MAX
        assert sat_add_q(Q_MIN + 1, -1) == Q_MIN
        assert sat_add_q(Q_MAX, 0) == Q_MAX
        assert sat_add_q(Q_MIN, 0) == Q_MIN

    def test_positive_overflow_saturation(self):
        """Addition overflowing Q_MAX should saturate to Q_MAX."""
        result = sat_add_q(Q_MAX, 1)
        assert result == Q_MAX, f"Expected {Q_MAX}, got {result}"
        # Two in-range operands whose exact sum exceeds Q_MAX.
        half_up = Q_MAX // 2 + 1
        result = sat_add_q(half_up, half_up)
        assert result == Q_MAX, f"Expected {Q_MAX}, got {result}"
        result = sat_add_q(Q_MAX, Q_MAX)
        assert result == Q_MAX, f"Expected {Q_MAX}, got {result}"

    def test_negative_overflow_saturation(self):
        """Addition underflowing Q_MIN should saturate to Q_MIN."""
        result = sat_add_q(Q_MIN, -1)
        assert result == Q_MIN, f"Expected {Q_MIN}, got {result}"
        # Two in-range operands whose exact sum falls below Q_MIN.
        half_down = Q_MIN // 2 - 1
        result = sat_add_q(half_down, half_down)
        assert result == Q_MIN, f"Expected {Q_MIN}, got {result}"
        result = sat_add_q(Q_MIN, Q_MIN)
        assert result == Q_MIN, f"Expected {Q_MIN}, got {result}"

    def test_sub_positive_overflow(self):
        """Subtracting a large negative from a positive should saturate."""
        result = sat_sub_q(Q_MAX, Q_MIN)
        assert result == Q_MAX

    def test_sub_negative_overflow(self):
        """Subtracting a large positive from a negative should saturate."""
        result = sat_sub_q(Q_MIN, Q_MAX)
        assert result == Q_MIN
# =============================================================================
# TestMinSumCN
# =============================================================================
class TestMinSumCN:
    """Validate offset min-sum check node update."""

    def test_all_positive_inputs(self):
        """All positive inputs should produce one non-negative output per input."""
        msgs_in = [10, 5, 8]
        msgs_out = min_sum_cn_update(msgs_in)
        assert len(msgs_out) == len(msgs_in)
        for j, val in enumerate(msgs_out):
            assert val >= 0, f"Output {j} = {val}, expected non-negative"

    def test_one_negative_input_sign_flip(self):
        """One negative input among positives should flip signs of other outputs.

        The sign of output j is the XOR of all OTHER input signs.
        With one negative at index i, output j (j != i) gets sign=1 (negative),
        and output i gets sign=0 (positive, since XOR of all-positive others = 0).
        """
        msgs_in = [10, -5, 8]
        msgs_out = min_sum_cn_update(msgs_in)
        # Output at index 1 (the negative input) should be positive (sign of others XOR = 0)
        assert msgs_out[1] >= 0, f"Output[1] = {msgs_out[1]}, expected non-negative"
        # Outputs at indices 0 and 2 see exactly one negative among their others.
        # "<= 0" tolerates a magnitude clamped to zero by the offset.
        assert msgs_out[0] <= 0, f"Output[0] = {msgs_out[0]}, expected non-positive"
        assert msgs_out[2] <= 0, f"Output[2] = {msgs_out[2]}, expected non-positive"

    def test_magnitude_is_min_of_others_minus_offset(self):
        """Output magnitude = min of OTHER magnitudes - OFFSET, clamped to 0."""
        msgs_in = [10, 5, 8]
        msgs_out = min_sum_cn_update(msgs_in, offset=OFFSET)
        # Index 0: others are (5, 8) -> min 5.
        assert abs(msgs_out[0]) == max(0, 5 - OFFSET)
        # Index 1 holds the overall minimum, so it sees min(10, 8) = 8.
        assert abs(msgs_out[1]) == max(0, 8 - OFFSET)
        # Index 2: others are (10, 5) -> min 5.
        assert abs(msgs_out[2]) == max(0, 5 - OFFSET)

    def test_offset_clamps_to_zero(self):
        """When min magnitude equals offset, output magnitude should be 0."""
        # offset=1 is passed explicitly so it exactly matches the minimum magnitude (1).
        msgs_in = [10, 1, 8]
        msgs_out = min_sum_cn_update(msgs_in, offset=1)
        # Indices 0 and 2 include the magnitude-1 input among their others: 1 - 1 = 0.
        assert abs(msgs_out[0]) == 0
        assert abs(msgs_out[2]) == 0
        # Index 1 excludes itself, so its minimum is min(10, 8) = 8 and no clamping occurs.
        assert abs(msgs_out[1]) == 8 - 1
# =============================================================================
# TestDecoder
# =============================================================================
class TestDecoder:
    """End-to-end checks of decode_layered_min_sum on freshly encoded codewords."""

    @staticmethod
    def _seeded_codeword(H, seed):
        """Seed the global NumPy RNG, draw K info bits, and encode them.

        Returns:
            tuple: (info bits as int8, encoded codeword).

        The global RNG is seeded on purpose rather than using a local
        Generator. Calls made later in a test (e.g. the channel model) may
        also draw from it, and the fixed seed keeps that whole sequence
        reproducible.
        """
        np.random.seed(seed)
        info_bits = np.random.randint(0, 2, K).astype(np.int8)
        return info_bits, ldpc_encode(info_bits, H)

    def test_noiseless_channel(self, H):
        """With saturated +-20 LLRs the decoder must finish in one iteration."""
        info, codeword = self._seeded_codeword(H, 42)
        # Hard-decision LLRs: positive => bit 0 more likely, negative => bit 1.
        llr_q = np.where(codeword == 0, 20, -20).astype(np.int8)
        decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30)
        assert converged, "Decoder did not converge on noiseless channel"
        assert iters == 1, f"Expected 1 iteration, got {iters}"
        assert np.array_equal(decoded, info), "Decoded bits do not match info bits"
        assert syn_wt == 0

    def test_high_snr_decoding(self, H):
        """At lambda_s=10, lambda_b=0.1 the decoder must recover the sent info bits."""
        info, codeword = self._seeded_codeword(H, 7777)
        signal_rate, background_rate = 10.0, 0.1
        llr_float, _unused = poisson_channel(codeword, signal_rate, background_rate)
        llr_q = quantize_llr(llr_float)
        decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30)
        assert converged, f"Decoder did not converge (syn_wt={syn_wt}, iters={iters})"
        assert np.array_equal(decoded, info), (
            f"Decoded bits do not match info bits ({np.sum(decoded != info)} errors)"
        )
# =============================================================================
# TestSyndrome
# =============================================================================
class TestSyndrome:
    """Check compute_syndrome_weight on a clean codeword and a corrupted one."""

    @staticmethod
    def _reference_codeword(H):
        """Encode the K info bits drawn right after seeding the global RNG with 555."""
        np.random.seed(555)
        bits = np.random.randint(0, 2, K).astype(np.int8)
        return ldpc_encode(bits, H)

    def test_valid_codeword_zero_syndrome(self, H):
        """An uncorrupted codeword satisfies every check, so its syndrome weight is 0."""
        weight = compute_syndrome_weight(self._reference_codeword(H).tolist())
        assert weight == 0, f"Valid codeword has syndrome weight {weight}, expected 0"

    def test_flipped_bit_nonzero_syndrome(self, H):
        """Inverting a single bit of a codeword must violate at least one check."""
        corrupted = self._reference_codeword(H).copy()
        corrupted[0] = 1 - corrupted[0]  # invert bit 0
        weight = compute_syndrome_weight(corrupted.tolist())
        assert weight > 0, "Flipped bit did not produce nonzero syndrome weight"