From af6055242eb3f43f406446724f364a30f2a16f64 Mon Sep 17 00:00:00 2001 From: cah Date: Tue, 24 Feb 2026 04:32:16 -0700 Subject: [PATCH] test: add validation tests for existing LDPC model Co-Authored-By: Claude Opus 4.6 --- model/test_ldpc.py | 333 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 333 insertions(+) create mode 100644 model/test_ldpc.py diff --git a/model/test_ldpc.py b/model/test_ldpc.py new file mode 100644 index 0000000..0c6898d --- /dev/null +++ b/model/test_ldpc.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +Validation tests for the LDPC behavioral model (ldpc_sim.py). + +Tests cover code parameters, H-matrix structure, encoding, saturating arithmetic, +min-sum CN update, decoding, and syndrome computation. + +Run: + python3 -m pytest model/test_ldpc.py -v +""" + +import numpy as np +import pytest + +from ldpc_sim import ( + N, K, M, Z, N_BASE, M_BASE, H_BASE, + Q_BITS, Q_MAX, Q_MIN, OFFSET, + build_full_h_matrix, + ldpc_encode, + poisson_channel, + quantize_llr, + decode_layered_min_sum, + compute_syndrome_weight, + cyclic_shift, + sat_add_q, + sat_sub_q, + min_sum_cn_update, +) + + +# ============================================================================= +# Fixtures +# ============================================================================= + +@pytest.fixture(scope="module") +def H(): + """Full expanded H matrix, built once per module.""" + return build_full_h_matrix() + + +# ============================================================================= +# TestCodeParameters +# ============================================================================= + +class TestCodeParameters: + """Verify fundamental code parameters match the design specification.""" + + def test_codeword_length(self): + assert N == 256 + + def test_info_length(self): + assert K == 32 + + def test_parity_checks(self): + assert M == 224 + + def test_lifting_factor(self): + assert Z == 32 + + def test_base_matrix_columns(self): + assert N_BASE == 8 + + def test_base_matrix_rows(self): + assert M_BASE == 7 + + def test_h_base_shape(self): + assert H_BASE.shape == (M_BASE, N_BASE) + assert H_BASE.shape == (7, 8) + + def test_derived_relationships(self): + assert N == N_BASE * Z + assert M == M_BASE * Z + assert K == Z + + +# ============================================================================= +# TestHMatrix +# ============================================================================= + +class TestHMatrix: + """Validate H-matrix structural properties.""" + + def test_h_base_col0_connects_all_rows(self): + """Column 0 (information bits) should connect to all 7 check rows.""" + col0 = H_BASE[:, 0] + for r in range(M_BASE): + assert col0[r] >= 0, f"Row {r} of column 0 is not connected (value={col0[r]})" + + def test_h_base_staircase_parity(self): + """Parity columns 1-7 have a lower-triangular staircase structure. + + For the staircase, row r connects to column r+1 (the diagonal) and + possibly column r (the sub-diagonal). Entries above the staircase + should be -1 (no connection). + """ + # Check that the parity sub-matrix (cols 1..7) is lower-triangular + for r in range(M_BASE): + for c in range(1, N_BASE): + if c > r + 1: + # Strictly above the diagonal+1: must be -1 + assert H_BASE[r, c] == -1, ( + f"H_BASE[{r},{c}]={H_BASE[r, c]}, expected -1 (above staircase)" + ) + + def test_full_h_dimensions(self, H): + """Full H matrix should be (M, N) = (224, 256).""" + assert H.shape == (M, N) + assert H.shape == (224, 256) + + def test_full_h_binary(self, H): + """Full H matrix should contain only 0s and 1s.""" + unique_vals = np.unique(H) + assert set(unique_vals).issubset({0, 1}), f"Non-binary values found: {unique_vals}" + + def test_qc_block_row_weight(self, H): + """Each Z-block row within a base row should have the same weight (QC property). + + For a QC-LDPC code, each Z x Z sub-block is either all-zero or a + circulant permutation matrix. So within each base row, every + expanded row has the same Hamming weight. + """ + for r in range(M_BASE): + row_start = r * Z + weights = [int(H[row_start + z, :].sum()) for z in range(Z)] + assert len(set(weights)) == 1, ( + f"Base row {r}: inconsistent row weights {set(weights)}" + ) + + def test_full_h_rank(self, H): + """H should have full row rank (rank == M == 224) for a valid code.""" + rank = np.linalg.matrix_rank(H.astype(float)) + assert rank == M, f"H rank = {rank}, expected {M}" + + +# ============================================================================= +# TestEncoder +# ============================================================================= + +class TestEncoder: + """Validate the LDPC encoder.""" + + def test_all_zero_info(self, H): + """All-zero info bits should encode to all-zero codeword.""" + info = np.zeros(K, dtype=np.int8) + codeword = ldpc_encode(info, H) + assert np.all(codeword == 0), "All-zero info did not produce all-zero codeword" + + def test_random_info_valid_codeword(self, H): + """Random info bits should encode to a valid codeword (syndrome = 0).""" + np.random.seed(12345) + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + syndrome = H @ codeword % 2 + assert np.all(syndrome == 0), f"Syndrome weight = {syndrome.sum()}" + + def test_encoder_preserves_info_bits(self, H): + """First K positions of codeword should be the original info bits.""" + np.random.seed(99) + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + assert np.array_equal(codeword[:K], info), "Info bits not preserved in codeword" + + def test_twenty_random_messages(self, H): + """20 random messages should all produce valid codewords.""" + np.random.seed(2024) + for i in range(20): + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + syndrome = H @ codeword % 2 + assert np.all(syndrome == 0), f"Message {i}: syndrome weight = {syndrome.sum()}" + + +# ============================================================================= +# TestSaturatingArithmetic +# ============================================================================= + +class TestSaturatingArithmetic: + """Validate saturating add/sub in Q-bit signed arithmetic.""" + + def test_normal_add(self): + assert sat_add_q(5, 3) == 8 + assert sat_add_q(-5, 3) == -2 + assert sat_add_q(0, 0) == 0 + + def test_normal_sub(self): + assert sat_sub_q(5, 3) == 2 + assert sat_sub_q(3, 5) == -2 + assert sat_sub_q(0, 0) == 0 + + def test_positive_overflow_saturation(self): + """Addition overflowing Q_MAX should saturate to Q_MAX.""" + result = sat_add_q(Q_MAX, 1) + assert result == Q_MAX, f"Expected {Q_MAX}, got {result}" + result = sat_add_q(20, 20) + assert result == Q_MAX, f"Expected {Q_MAX}, got {result}" + + def test_negative_overflow_saturation(self): + """Addition underflowing Q_MIN should saturate to Q_MIN.""" + result = sat_add_q(Q_MIN, -1) + assert result == Q_MIN, f"Expected {Q_MIN}, got {result}" + result = sat_add_q(-20, -20) + assert result == Q_MIN, f"Expected {Q_MIN}, got {result}" + + def test_sub_positive_overflow(self): + """Subtracting a large negative from a positive should saturate.""" + result = sat_sub_q(Q_MAX, Q_MIN) + assert result == Q_MAX + + def test_sub_negative_overflow(self): + """Subtracting a large positive from a negative should saturate.""" + result = sat_sub_q(Q_MIN, Q_MAX) + assert result == Q_MIN + + +# ============================================================================= +# TestMinSumCN +# ============================================================================= + +class TestMinSumCN: + """Validate offset min-sum check node update.""" + + def test_all_positive_inputs(self): + """All positive inputs should produce all positive outputs.""" + msgs_in = [10, 5, 8] + msgs_out = min_sum_cn_update(msgs_in) + for j, val in enumerate(msgs_out): + assert val >= 0, f"Output {j} = {val}, expected non-negative" + + def test_one_negative_input_sign_flip(self): + """One negative input among positives should flip signs of other outputs. + + The sign of output j is the XOR of all OTHER input signs. + With one negative at index i, output j (j != i) gets sign=1 (negative), + and output i gets sign=0 (positive, since XOR of all-positive others = 0). + """ + msgs_in = [10, -5, 8] + msgs_out = min_sum_cn_update(msgs_in) + # Output at index 1 (the negative input) should be positive (sign of others XOR = 0) + assert msgs_out[1] >= 0, f"Output[1] = {msgs_out[1]}, expected non-negative" + # Outputs at indices 0 and 2 should be negative (one negative among others) + assert msgs_out[0] <= 0, f"Output[0] = {msgs_out[0]}, expected non-positive" + assert msgs_out[2] <= 0, f"Output[2] = {msgs_out[2]}, expected non-positive" + + def test_magnitude_is_min_of_others_minus_offset(self): + """Output magnitude = min of OTHER magnitudes - OFFSET, clamped to 0.""" + msgs_in = [10, 5, 8] + msgs_out = min_sum_cn_update(msgs_in, offset=OFFSET) + mags_in = [abs(m) for m in msgs_in] + # For index 0: min of others = min(5, 8) = 5, output mag = 5 - OFFSET = 4 + assert abs(msgs_out[0]) == max(0, 5 - OFFSET) + # For index 1 (has min magnitude): min of others = min(10, 8) = 8, output mag = 8 - OFFSET + assert abs(msgs_out[1]) == max(0, 8 - OFFSET) + # For index 2: min of others = min(10, 5) = 5, output mag = 5 - OFFSET = 4 + assert abs(msgs_out[2]) == max(0, 5 - OFFSET) + + def test_offset_clamps_to_zero(self): + """When min magnitude equals offset, output magnitude should be 0.""" + # Use offset=1 (default OFFSET) and magnitude=1 for the minimum + msgs_in = [10, 1, 8] + msgs_out = min_sum_cn_update(msgs_in, offset=1) + # For indices 0 and 2: min of others includes mag=1, so 1-1=0 + assert abs(msgs_out[0]) == 0 + assert abs(msgs_out[2]) == 0 + + +# ============================================================================= +# TestDecoder +# ============================================================================= + +class TestDecoder: + """Validate the layered min-sum decoder.""" + + def test_noiseless_channel(self, H): + """Noiseless channel (LLR = +-20) should decode perfectly in 1 iteration.""" + np.random.seed(42) + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + + # Create noiseless LLRs: positive for bit=0, negative for bit=1 + # Convention: positive LLR -> bit 0 more likely + llr_q = np.where(codeword == 0, 20, -20).astype(np.int8) + + decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30) + assert converged, "Decoder did not converge on noiseless channel" + assert iters == 1, f"Expected 1 iteration, got {iters}" + assert np.array_equal(decoded, info), "Decoded bits do not match info bits" + assert syn_wt == 0 + + def test_high_snr_decoding(self, H): + """High SNR (lambda_s=10) should decode correctly.""" + np.random.seed(7777) + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + + lam_s = 10.0 + lam_b = 0.1 + llr_float, _ = poisson_channel(codeword, lam_s, lam_b) + llr_q = quantize_llr(llr_float) + + decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30) + assert converged, f"Decoder did not converge (syn_wt={syn_wt}, iters={iters})" + assert np.array_equal(decoded, info), ( + f"Decoded bits do not match info bits ({np.sum(decoded != info)} errors)" + ) + + +# ============================================================================= +# TestSyndrome +# ============================================================================= + +class TestSyndrome: + """Validate syndrome weight computation.""" + + def test_valid_codeword_zero_syndrome(self, H): + """A valid codeword should have syndrome weight 0.""" + np.random.seed(555) + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + sw = compute_syndrome_weight(codeword.tolist()) + assert sw == 0, f"Valid codeword has syndrome weight {sw}, expected 0" + + def test_flipped_bit_nonzero_syndrome(self, H): + """Flipping one bit in a valid codeword should give nonzero syndrome weight.""" + np.random.seed(555) + info = np.random.randint(0, 2, K).astype(np.int8) + codeword = ldpc_encode(info, H) + + # Flip a bit + corrupted = codeword.copy() + corrupted[0] = 1 - corrupted[0] + + sw = compute_syndrome_weight(corrupted.tolist()) + assert sw > 0, "Flipped bit did not produce nonzero syndrome weight"