334 lines
13 KiB
Python
334 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Validation tests for the LDPC behavioral model (ldpc_sim.py).
|
|
|
|
Tests cover code parameters, H-matrix structure, encoding, saturating arithmetic,
|
|
min-sum CN update, decoding, and syndrome computation.
|
|
|
|
Run:
|
|
python3 -m pytest model/test_ldpc.py -v
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from ldpc_sim import (
|
|
N, K, M, Z, N_BASE, M_BASE, H_BASE,
|
|
Q_BITS, Q_MAX, Q_MIN, OFFSET,
|
|
build_full_h_matrix,
|
|
ldpc_encode,
|
|
poisson_channel,
|
|
quantize_llr,
|
|
decode_layered_min_sum,
|
|
compute_syndrome_weight,
|
|
cyclic_shift,
|
|
sat_add_q,
|
|
sat_sub_q,
|
|
min_sum_cn_update,
|
|
)
|
|
|
|
|
|
# =============================================================================
|
|
# Fixtures
|
|
# =============================================================================
|
|
|
|
@pytest.fixture(scope="module")
def H():
    """Expanded (M x N) parity-check matrix shared by every test in this module.

    Module scope means ``build_full_h_matrix()`` runs once instead of once
    per test that requests the fixture.
    """
    matrix = build_full_h_matrix()
    return matrix
|
|
|
|
|
|
# =============================================================================
|
|
# TestCodeParameters
|
|
# =============================================================================
|
|
|
|
class TestCodeParameters:
    """Pin the fundamental code dimensions to the design specification.

    The expected values describe a 7 x 8 base matrix lifted by Z = 32,
    giving N = 256 coded bits, K = 32 information bits and M = 224 checks.
    """

    def test_codeword_length(self):
        expected_n = 256
        assert N == expected_n

    def test_info_length(self):
        expected_k = 32
        assert K == expected_k

    def test_parity_checks(self):
        expected_m = 224
        assert M == expected_m

    def test_lifting_factor(self):
        expected_z = 32
        assert Z == expected_z

    def test_base_matrix_columns(self):
        expected_cols = 8
        assert N_BASE == expected_cols

    def test_base_matrix_rows(self):
        expected_rows = 7
        assert M_BASE == expected_rows

    def test_h_base_shape(self):
        shape = H_BASE.shape
        # Must agree both with the exported dimensions and the literal spec.
        assert shape == (M_BASE, N_BASE)
        assert shape == (7, 8)

    def test_derived_relationships(self):
        # Lifting multiplies every base dimension by Z; K equals one lifted column.
        assert N == N_BASE * Z
        assert M == M_BASE * Z
        assert K == Z
|
|
|
|
|
|
# =============================================================================
|
|
# TestHMatrix
|
|
# =============================================================================
|
|
|
|
class TestHMatrix:
|
|
"""Validate H-matrix structural properties."""
|
|
|
|
def test_h_base_col0_connects_all_rows(self):
|
|
"""Column 0 (information bits) should connect to all 7 check rows."""
|
|
col0 = H_BASE[:, 0]
|
|
for r in range(M_BASE):
|
|
assert col0[r] >= 0, f"Row {r} of column 0 is not connected (value={col0[r]})"
|
|
|
|
def test_h_base_staircase_parity(self):
|
|
"""Parity columns 1-7 have a lower-triangular staircase structure.
|
|
|
|
For the staircase, row r connects to column r+1 (the diagonal) and
|
|
possibly column r (the sub-diagonal). Entries above the staircase
|
|
should be -1 (no connection).
|
|
"""
|
|
# Check that the parity sub-matrix (cols 1..7) is lower-triangular
|
|
for r in range(M_BASE):
|
|
for c in range(1, N_BASE):
|
|
if c > r + 1:
|
|
# Strictly above the diagonal+1: must be -1
|
|
assert H_BASE[r, c] == -1, (
|
|
f"H_BASE[{r},{c}]={H_BASE[r, c]}, expected -1 (above staircase)"
|
|
)
|
|
|
|
def test_full_h_dimensions(self, H):
|
|
"""Full H matrix should be (M, N) = (224, 256)."""
|
|
assert H.shape == (M, N)
|
|
assert H.shape == (224, 256)
|
|
|
|
def test_full_h_binary(self, H):
|
|
"""Full H matrix should contain only 0s and 1s."""
|
|
unique_vals = np.unique(H)
|
|
assert set(unique_vals).issubset({0, 1}), f"Non-binary values found: {unique_vals}"
|
|
|
|
def test_qc_block_row_weight(self, H):
|
|
"""Each Z-block row within a base row should have the same weight (QC property).
|
|
|
|
For a QC-LDPC code, each Z x Z sub-block is either all-zero or a
|
|
circulant permutation matrix. So within each base row, every
|
|
expanded row has the same Hamming weight.
|
|
"""
|
|
for r in range(M_BASE):
|
|
row_start = r * Z
|
|
weights = [int(H[row_start + z, :].sum()) for z in range(Z)]
|
|
assert len(set(weights)) == 1, (
|
|
f"Base row {r}: inconsistent row weights {set(weights)}"
|
|
)
|
|
|
|
def test_full_h_rank(self, H):
|
|
"""H should have full row rank (rank == M == 224) for a valid code."""
|
|
rank = np.linalg.matrix_rank(H.astype(float))
|
|
assert rank == M, f"H rank = {rank}, expected {M}"
|
|
|
|
|
|
# =============================================================================
|
|
# TestEncoder
|
|
# =============================================================================
|
|
|
|
class TestEncoder:
    """Check that ldpc_encode produces systematic, valid codewords."""

    @staticmethod
    def _syndrome(H, codeword):
        # Parity checks evaluated modulo 2; all zeros means a valid codeword.
        return (H @ codeword) % 2

    def test_all_zero_info(self, H):
        """All-zero info bits should encode to all-zero codeword."""
        codeword = ldpc_encode(np.zeros(K, dtype=np.int8), H)
        assert not np.any(codeword), "All-zero info did not produce all-zero codeword"

    def test_random_info_valid_codeword(self, H):
        """Random info bits should encode to a valid codeword (syndrome = 0)."""
        np.random.seed(12345)
        message = np.random.randint(0, 2, K).astype(np.int8)
        syndrome = self._syndrome(H, ldpc_encode(message, H))
        assert not syndrome.any(), f"Syndrome weight = {syndrome.sum()}"

    def test_encoder_preserves_info_bits(self, H):
        """First K positions of codeword should be the original info bits."""
        np.random.seed(99)
        message = np.random.randint(0, 2, K).astype(np.int8)
        systematic_part = ldpc_encode(message, H)[:K]
        assert np.array_equal(systematic_part, message), "Info bits not preserved in codeword"

    def test_twenty_random_messages(self, H):
        """20 random messages should all produce valid codewords."""
        np.random.seed(2024)
        for idx in range(20):
            message = np.random.randint(0, 2, K).astype(np.int8)
            syndrome = self._syndrome(H, ldpc_encode(message, H))
            assert not syndrome.any(), f"Message {idx}: syndrome weight = {syndrome.sum()}"
|
|
|
|
|
|
# =============================================================================
|
|
# TestSaturatingArithmetic
|
|
# =============================================================================
|
|
|
|
class TestSaturatingArithmetic:
    """Exercise the saturating Q-bit signed add/sub helpers.

    In-range operations must be exact; results beyond [Q_MIN, Q_MAX] must
    clamp to the nearest bound.
    """

    def test_normal_add(self):
        cases = ((5, 3, 8), (-5, 3, -2), (0, 0, 0))
        for lhs, rhs, want in cases:
            assert sat_add_q(lhs, rhs) == want

    def test_normal_sub(self):
        cases = ((5, 3, 2), (3, 5, -2), (0, 0, 0))
        for lhs, rhs, want in cases:
            assert sat_sub_q(lhs, rhs) == want

    def test_positive_overflow_saturation(self):
        """Addition overflowing Q_MAX should saturate to Q_MAX."""
        for lhs, rhs in ((Q_MAX, 1), (20, 20)):
            got = sat_add_q(lhs, rhs)
            assert got == Q_MAX, f"Expected {Q_MAX}, got {got}"

    def test_negative_overflow_saturation(self):
        """Addition underflowing Q_MIN should saturate to Q_MIN."""
        for lhs, rhs in ((Q_MIN, -1), (-20, -20)):
            got = sat_add_q(lhs, rhs)
            assert got == Q_MIN, f"Expected {Q_MIN}, got {got}"

    def test_sub_positive_overflow(self):
        """Subtracting a large negative from a positive should saturate."""
        assert sat_sub_q(Q_MAX, Q_MIN) == Q_MAX

    def test_sub_negative_overflow(self):
        """Subtracting a large positive from a negative should saturate."""
        assert sat_sub_q(Q_MIN, Q_MAX) == Q_MIN
|
|
|
|
|
|
# =============================================================================
|
|
# TestMinSumCN
|
|
# =============================================================================
|
|
|
|
class TestMinSumCN:
    """Validate offset min-sum check node update."""

    def test_all_positive_inputs(self):
        """All positive inputs should produce all positive outputs."""
        msgs_in = [10, 5, 8]
        msgs_out = min_sum_cn_update(msgs_in)
        for j, val in enumerate(msgs_out):
            assert val >= 0, f"Output {j} = {val}, expected non-negative"

    def test_one_negative_input_sign_flip(self):
        """One negative input among positives should flip signs of other outputs.

        The sign of output j is the XOR of all OTHER input signs.
        With one negative at index i, output j (j != i) gets sign=1 (negative),
        and output i gets sign=0 (positive, since XOR of all-positive others = 0).
        """
        msgs_in = [10, -5, 8]
        msgs_out = min_sum_cn_update(msgs_in)
        # Output at index 1 (the negative input) should be positive (sign of others XOR = 0)
        assert msgs_out[1] >= 0, f"Output[1] = {msgs_out[1]}, expected non-negative"
        # Outputs at indices 0 and 2 should be negative (one negative among others)
        assert msgs_out[0] <= 0, f"Output[0] = {msgs_out[0]}, expected non-positive"
        assert msgs_out[2] <= 0, f"Output[2] = {msgs_out[2]}, expected non-positive"

    def test_magnitude_is_min_of_others_minus_offset(self):
        """Output magnitude = min of OTHER magnitudes - OFFSET, clamped to 0.

        Expected magnitudes are derived from the inputs and OFFSET rather than
        hard-coded, so the test stays correct whatever value OFFSET has.
        """
        msgs_in = [10, 5, 8]
        msgs_out = min_sum_cn_update(msgs_in, offset=OFFSET)
        mags_in = [abs(m) for m in msgs_in]

        assert len(msgs_out) == len(msgs_in), (
            f"Expected {len(msgs_in)} outputs, got {len(msgs_out)}"
        )
        for j, out in enumerate(msgs_out):
            # Extrinsic rule: the message for edge j excludes edge j's own input.
            min_others = min(mags_in[:j] + mags_in[j + 1:])
            expected = max(0, min_others - OFFSET)
            assert abs(out) == expected, (
                f"Output {j}: |{out}| != max(0, {min_others} - {OFFSET})"
            )

    def test_offset_clamps_to_zero(self):
        """When min magnitude equals offset, output magnitude should be 0."""
        # Pass offset=1 explicitly and make the minimum input magnitude 1.
        msgs_in = [10, 1, 8]
        msgs_out = min_sum_cn_update(msgs_in, offset=1)
        # For indices 0 and 2: min of others includes mag=1, so 1-1=0
        assert abs(msgs_out[0]) == 0
        assert abs(msgs_out[2]) == 0
|
|
|
|
|
|
# =============================================================================
|
|
# TestDecoder
|
|
# =============================================================================
|
|
|
|
class TestDecoder:
    """Validate the layered min-sum decoder."""

    def test_noiseless_channel(self, H):
        """Noiseless channel (LLR = +-20) should decode perfectly in 1 iteration.

        The +-20 LLRs are clipped to [Q_MIN, Q_MAX], the saturation bounds of
        the model's Q-bit arithmetic, so the decoder never receives an input
        value that its quantized datapath could not represent.
        """
        np.random.seed(42)
        info = np.random.randint(0, 2, K).astype(np.int8)
        codeword = ldpc_encode(info, H)

        # Create noiseless LLRs: positive for bit=0, negative for bit=1
        # Convention: positive LLR -> bit 0 more likely
        llr = np.where(codeword == 0, 20, -20)
        llr_q = np.clip(llr, Q_MIN, Q_MAX).astype(np.int8)

        decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30)
        assert converged, "Decoder did not converge on noiseless channel"
        assert iters == 1, f"Expected 1 iteration, got {iters}"
        assert np.array_equal(decoded, info), "Decoded bits do not match info bits"
        assert syn_wt == 0

    def test_high_snr_decoding(self, H):
        """High SNR (lambda_s=10) should decode correctly."""
        np.random.seed(7777)
        info = np.random.randint(0, 2, K).astype(np.int8)
        codeword = ldpc_encode(info, H)

        lam_s = 10.0
        lam_b = 0.1
        llr_float, _ = poisson_channel(codeword, lam_s, lam_b)
        llr_q = quantize_llr(llr_float)

        decoded, converged, iters, syn_wt = decode_layered_min_sum(llr_q, max_iter=30)
        assert converged, f"Decoder did not converge (syn_wt={syn_wt}, iters={iters})"
        assert np.array_equal(decoded, info), (
            f"Decoded bits do not match info bits ({np.sum(decoded != info)} errors)"
        )
|
|
|
|
|
|
# =============================================================================
|
|
# TestSyndrome
|
|
# =============================================================================
|
|
|
|
class TestSyndrome:
    """Validate syndrome weight computation."""

    @staticmethod
    def _random_codeword(H, seed=555):
        """Encode a reproducible random message and return its codeword."""
        np.random.seed(seed)
        info = np.random.randint(0, 2, K).astype(np.int8)
        return ldpc_encode(info, H)

    def test_valid_codeword_zero_syndrome(self, H):
        """A valid codeword should have syndrome weight 0."""
        codeword = self._random_codeword(H)
        sw = compute_syndrome_weight(codeword.tolist())
        assert sw == 0, f"Valid codeword has syndrome weight {sw}, expected 0"

    def test_flipped_bit_nonzero_syndrome(self, H):
        """Flipping one bit in a valid codeword should give nonzero syndrome weight.

        Since H @ c = 0 (mod 2) for the valid codeword, flipping bit i makes
        the syndrome equal to column i of H, so the syndrome weight must
        equal that column's Hamming weight.
        """
        codeword = self._random_codeword(H)

        flip_pos = 0
        corrupted = codeword.copy()
        corrupted[flip_pos] = 1 - corrupted[flip_pos]

        sw = compute_syndrome_weight(corrupted.tolist())
        assert sw > 0, "Flipped bit did not produce nonzero syndrome weight"
        expected = int(np.sum(H[:, flip_pos]))
        assert sw == expected, (
            f"Syndrome weight {sw} != weight of H column {flip_pos} ({expected})"
        )
|