test: add validation tests for existing LDPC model
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
333
model/test_ldpc.py
Normal file
333
model/test_ldpc.py
Normal file
@@ -0,0 +1,333 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Validation tests for the LDPC behavioral model (ldpc_sim.py).
|
||||
|
||||
Tests cover code parameters, H-matrix structure, encoding, saturating arithmetic,
|
||||
min-sum CN update, decoding, and syndrome computation.
|
||||
|
||||
Run:
|
||||
python3 -m pytest model/test_ldpc.py -v
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from ldpc_sim import (
|
||||
N, K, M, Z, N_BASE, M_BASE, H_BASE,
|
||||
Q_BITS, Q_MAX, Q_MIN, OFFSET,
|
||||
build_full_h_matrix,
|
||||
ldpc_encode,
|
||||
poisson_channel,
|
||||
quantize_llr,
|
||||
decode_layered_min_sum,
|
||||
compute_syndrome_weight,
|
||||
cyclic_shift,
|
||||
sat_add_q,
|
||||
sat_sub_q,
|
||||
min_sum_cn_update,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Fixtures
|
||||
# =============================================================================
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def H():
|
||||
"""Full expanded H matrix, built once per module."""
|
||||
return build_full_h_matrix()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestCodeParameters
|
||||
# =============================================================================
|
||||
|
||||
class TestCodeParameters:
|
||||
"""Verify fundamental code parameters match the design specification."""
|
||||
|
||||
def test_codeword_length(self):
|
||||
assert N == 256
|
||||
|
||||
def test_info_length(self):
|
||||
assert K == 32
|
||||
|
||||
def test_parity_checks(self):
|
||||
assert M == 224
|
||||
|
||||
def test_lifting_factor(self):
|
||||
assert Z == 32
|
||||
|
||||
def test_base_matrix_columns(self):
|
||||
assert N_BASE == 8
|
||||
|
||||
def test_base_matrix_rows(self):
|
||||
assert M_BASE == 7
|
||||
|
||||
def test_h_base_shape(self):
|
||||
assert H_BASE.shape == (M_BASE, N_BASE)
|
||||
assert H_BASE.shape == (7, 8)
|
||||
|
||||
def test_derived_relationships(self):
|
||||
assert N == N_BASE * Z
|
||||
assert M == M_BASE * Z
|
||||
assert K == Z
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestHMatrix
|
||||
# =============================================================================
|
||||
|
||||
class TestHMatrix:
|
||||
"""Validate H-matrix structural properties."""
|
||||
|
||||
def test_h_base_col0_connects_all_rows(self):
|
||||
"""Column 0 (information bits) should connect to all 7 check rows."""
|
||||
col0 = H_BASE[:, 0]
|
||||
for r in range(M_BASE):
|
||||
assert col0[r] >= 0, f"Row {r} of column 0 is not connected (value={col0[r]})"
|
||||
|
||||
def test_h_base_staircase_parity(self):
|
||||
"""Parity columns 1-7 have a lower-triangular staircase structure.
|
||||
|
||||
For the staircase, row r connects to column r+1 (the diagonal) and
|
||||
possibly column r (the sub-diagonal). Entries above the staircase
|
||||
should be -1 (no connection).
|
||||
"""
|
||||
# Check that the parity sub-matrix (cols 1..7) is lower-triangular
|
||||
for r in range(M_BASE):
|
||||
for c in range(1, N_BASE):
|
||||
if c > r + 1:
|
||||
# Strictly above the diagonal+1: must be -1
|
||||
assert H_BASE[r, c] == -1, (
|
||||
f"H_BASE[{r},{c}]={H_BASE[r, c]}, expected -1 (above staircase)"
|
||||
)
|
||||
|
||||
def test_full_h_dimensions(self, H):
|
||||
"""Full H matrix should be (M, N) = (224, 256)."""
|
||||
assert H.shape == (M, N)
|
||||
assert H.shape == (224, 256)
|
||||
|
||||
def test_full_h_binary(self, H):
|
||||
"""Full H matrix should contain only 0s and 1s."""
|
||||
unique_vals = np.unique(H)
|
||||
assert set(unique_vals).issubset({0, 1}), f"Non-binary values found: {unique_vals}"
|
||||
|
||||
def test_qc_block_row_weight(self, H):
|
||||
"""Each Z-block row within a base row should have the same weight (QC property).
|
||||
|
||||
For a QC-LDPC code, each Z x Z sub-block is either all-zero or a
|
||||
circulant permutation matrix. So within each base row, every
|
||||
expanded row has the same Hamming weight.
|
||||
"""
|
||||
for r in range(M_BASE):
|
||||
row_start = r * Z
|
||||
weights = [int(H[row_start + z, :].sum()) for z in range(Z)]
|
||||
assert len(set(weights)) == 1, (
|
||||
f"Base row {r}: inconsistent row weights {set(weights)}"
|
||||
)
|
||||
|
||||
def test_full_h_rank(self, H):
|
||||
"""H should have full row rank (rank == M == 224) for a valid code."""
|
||||
rank = np.linalg.matrix_rank(H.astype(float))
|
||||
assert rank == M, f"H rank = {rank}, expected {M}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestEncoder
|
||||
# =============================================================================
|
||||
|
||||
class TestEncoder:
|
||||
"""Validate the LDPC encoder."""
|
||||
|
||||
def test_all_zero_info(self, H):
|
||||
"""All-zero info bits should encode to all-zero codeword."""
|
||||
info = np.zeros(K, dtype=np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
assert np.all(codeword == 0), "All-zero info did not produce all-zero codeword"
|
||||
|
||||
def test_random_info_valid_codeword(self, H):
|
||||
"""Random info bits should encode to a valid codeword (syndrome = 0)."""
|
||||
np.random.seed(12345)
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
syndrome = H @ codeword % 2
|
||||
assert np.all(syndrome == 0), f"Syndrome weight = {syndrome.sum()}"
|
||||
|
||||
def test_encoder_preserves_info_bits(self, H):
|
||||
"""First K positions of codeword should be the original info bits."""
|
||||
np.random.seed(99)
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
assert np.array_equal(codeword[:K], info), "Info bits not preserved in codeword"
|
||||
|
||||
def test_twenty_random_messages(self, H):
|
||||
"""20 random messages should all produce valid codewords."""
|
||||
np.random.seed(2024)
|
||||
for i in range(20):
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
syndrome = H @ codeword % 2
|
||||
assert np.all(syndrome == 0), f"Message {i}: syndrome weight = {syndrome.sum()}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestSaturatingArithmetic
|
||||
# =============================================================================
|
||||
|
||||
class TestSaturatingArithmetic:
|
||||
"""Validate saturating add/sub in Q-bit signed arithmetic."""
|
||||
|
||||
def test_normal_add(self):
|
||||
assert sat_add_q(5, 3) == 8
|
||||
assert sat_add_q(-5, 3) == -2
|
||||
assert sat_add_q(0, 0) == 0
|
||||
|
||||
def test_normal_sub(self):
|
||||
assert sat_sub_q(5, 3) == 2
|
||||
assert sat_sub_q(3, 5) == -2
|
||||
assert sat_sub_q(0, 0) == 0
|
||||
|
||||
def test_positive_overflow_saturation(self):
|
||||
"""Addition overflowing Q_MAX should saturate to Q_MAX."""
|
||||
result = sat_add_q(Q_MAX, 1)
|
||||
assert result == Q_MAX, f"Expected {Q_MAX}, got {result}"
|
||||
result = sat_add_q(20, 20)
|
||||
assert result == Q_MAX, f"Expected {Q_MAX}, got {result}"
|
||||
|
||||
def test_negative_overflow_saturation(self):
|
||||
"""Addition underflowing Q_MIN should saturate to Q_MIN."""
|
||||
result = sat_add_q(Q_MIN, -1)
|
||||
assert result == Q_MIN, f"Expected {Q_MIN}, got {result}"
|
||||
result = sat_add_q(-20, -20)
|
||||
assert result == Q_MIN, f"Expected {Q_MIN}, got {result}"
|
||||
|
||||
def test_sub_positive_overflow(self):
|
||||
"""Subtracting a large negative from a positive should saturate."""
|
||||
result = sat_sub_q(Q_MAX, Q_MIN)
|
||||
assert result == Q_MAX
|
||||
|
||||
def test_sub_negative_overflow(self):
|
||||
"""Subtracting a large positive from a negative should saturate."""
|
||||
result = sat_sub_q(Q_MIN, Q_MAX)
|
||||
assert result == Q_MIN
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestMinSumCN
|
||||
# =============================================================================
|
||||
|
||||
class TestMinSumCN:
|
||||
"""Validate offset min-sum check node update."""
|
||||
|
||||
def test_all_positive_inputs(self):
|
||||
"""All positive inputs should produce all positive outputs."""
|
||||
msgs_in = [10, 5, 8]
|
||||
msgs_out = min_sum_cn_update(msgs_in)
|
||||
for j, val in enumerate(msgs_out):
|
||||
assert val >= 0, f"Output {j} = {val}, expected non-negative"
|
||||
|
||||
def test_one_negative_input_sign_flip(self):
|
||||
"""One negative input among positives should flip signs of other outputs.
|
||||
|
||||
The sign of output j is the XOR of all OTHER input signs.
|
||||
With one negative at index i, output j (j != i) gets sign=1 (negative),
|
||||
and output i gets sign=0 (positive, since XOR of all-positive others = 0).
|
||||
"""
|
||||
msgs_in = [10, -5, 8]
|
||||
msgs_out = min_sum_cn_update(msgs_in)
|
||||
# Output at index 1 (the negative input) should be positive (sign of others XOR = 0)
|
||||
assert msgs_out[1] >= 0, f"Output[1] = {msgs_out[1]}, expected non-negative"
|
||||
# Outputs at indices 0 and 2 should be negative (one negative among others)
|
||||
assert msgs_out[0] <= 0, f"Output[0] = {msgs_out[0]}, expected non-positive"
|
||||
assert msgs_out[2] <= 0, f"Output[2] = {msgs_out[2]}, expected non-positive"
|
||||
|
||||
def test_magnitude_is_min_of_others_minus_offset(self):
|
||||
"""Output magnitude = min of OTHER magnitudes - OFFSET, clamped to 0."""
|
||||
msgs_in = [10, 5, 8]
|
||||
msgs_out = min_sum_cn_update(msgs_in, offset=OFFSET)
|
||||
mags_in = [abs(m) for m in msgs_in]
|
||||
# For index 0: min of others = min(5, 8) = 5, output mag = 5 - OFFSET = 4
|
||||
assert abs(msgs_out[0]) == max(0, 5 - OFFSET)
|
||||
# For index 1 (has min magnitude): min of others = min(10, 8) = 8, output mag = 8 - OFFSET
|
||||
assert abs(msgs_out[1]) == max(0, 8 - OFFSET)
|
||||
# For index 2: min of others = min(10, 5) = 5, output mag = 5 - OFFSET = 4
|
||||
assert abs(msgs_out[2]) == max(0, 5 - OFFSET)
|
||||
|
||||
def test_offset_clamps_to_zero(self):
|
||||
"""When min magnitude equals offset, output magnitude should be 0."""
|
||||
# Use offset=1 (default OFFSET) and magnitude=1 for the minimum
|
||||
msgs_in = [10, 1, 8]
|
||||
msgs_out = min_sum_cn_update(msgs_in, offset=1)
|
||||
# For indices 0 and 2: min of others includes mag=1, so 1-1=0
|
||||
assert abs(msgs_out[0]) == 0
|
||||
assert abs(msgs_out[2]) == 0
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestDecoder
|
||||
# =============================================================================
|
||||
|
||||
class TestDecoder:
|
||||
"""Validate the layered min-sum decoder."""
|
||||
|
||||
def test_noiseless_channel(self, H):
|
||||
"""Noiseless channel (LLR = +-20) should decode perfectly in 1 iteration."""
|
||||
np.random.seed(42)
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
|
||||
# Create noiseless LLRs: positive for bit=0, negative for bit=1
|
||||
# Convention: positive LLR -> bit 0 more likely
|
||||
llr_q = np.where(codeword == 0, 20, -20).astype(np.int8)
|
||||
|
||||
decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30)
|
||||
assert converged, "Decoder did not converge on noiseless channel"
|
||||
assert iters == 1, f"Expected 1 iteration, got {iters}"
|
||||
assert np.array_equal(decoded, info), "Decoded bits do not match info bits"
|
||||
assert syn_wt == 0
|
||||
|
||||
def test_high_snr_decoding(self, H):
|
||||
"""High SNR (lambda_s=10) should decode correctly."""
|
||||
np.random.seed(7777)
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
|
||||
lam_s = 10.0
|
||||
lam_b = 0.1
|
||||
llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
|
||||
llr_q = quantize_llr(llr_float)
|
||||
|
||||
decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30)
|
||||
assert converged, f"Decoder did not converge (syn_wt={syn_wt}, iters={iters})"
|
||||
assert np.array_equal(decoded, info), (
|
||||
f"Decoded bits do not match info bits ({np.sum(decoded != info)} errors)"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# TestSyndrome
|
||||
# =============================================================================
|
||||
|
||||
class TestSyndrome:
|
||||
"""Validate syndrome weight computation."""
|
||||
|
||||
def test_valid_codeword_zero_syndrome(self, H):
|
||||
"""A valid codeword should have syndrome weight 0."""
|
||||
np.random.seed(555)
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
sw = compute_syndrome_weight(codeword.tolist())
|
||||
assert sw == 0, f"Valid codeword has syndrome weight {sw}, expected 0"
|
||||
|
||||
def test_flipped_bit_nonzero_syndrome(self, H):
|
||||
"""Flipping one bit in a valid codeword should give nonzero syndrome weight."""
|
||||
np.random.seed(555)
|
||||
info = np.random.randint(0, 2, K).astype(np.int8)
|
||||
codeword = ldpc_encode(info, H)
|
||||
|
||||
# Flip a bit
|
||||
corrupted = codeword.copy()
|
||||
corrupted[0] = 1 - corrupted[0]
|
||||
|
||||
sw = compute_syndrome_weight(corrupted.tolist())
|
||||
assert sw > 0, "Flipped bit did not produce nonzero syndrome weight"
|
||||
Reference in New Issue
Block a user