feat: add SC-LDPC chain construction
Implement spatially-coupled LDPC code construction with: - split_protograph(): split base matrix edges into w components - build_sc_chain(): build full SC-LDPC H matrix with L positions - sc_encode(): GF(2) Gaussian elimination encoder for SC chain Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
212
model/sc_ldpc.py
Normal file
212
model/sc_ldpc.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Spatially-Coupled LDPC (SC-LDPC) Code Construction and Decoding
|
||||||
|
|
||||||
|
SC-LDPC codes achieve threshold saturation: the BP threshold approaches
|
||||||
|
the MAP threshold, closing a significant portion of the gap to Shannon limit.
|
||||||
|
|
||||||
|
Construction: replicate a protograph base matrix along a chain of L positions
|
||||||
|
with coupling width w, creating a convolutional-like structure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
def split_protograph(B, w=2, seed=None):
|
||||||
|
"""
|
||||||
|
Split a protograph base matrix into w component matrices.
|
||||||
|
|
||||||
|
Each edge in B (entry >= 0) is randomly assigned to exactly one of the
|
||||||
|
w component matrices. The component that receives the edge gets value 0
|
||||||
|
(circulant shift assigned later during chain construction), while all
|
||||||
|
other components get -1 (no connection) at that position.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
|
||||||
|
w: Coupling width (number of component matrices).
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of w component matrices, each with shape (m_base, n_base).
|
||||||
|
Component values: 0 where edge is assigned, -1 otherwise.
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
m_base, n_base = B.shape
|
||||||
|
|
||||||
|
# Initialize all components to -1 (no connection)
|
||||||
|
components = [np.full((m_base, n_base), -1, dtype=np.int16) for _ in range(w)]
|
||||||
|
|
||||||
|
for r in range(m_base):
|
||||||
|
for c in range(n_base):
|
||||||
|
if B[r, c] >= 0:
|
||||||
|
# Randomly assign this edge to one component
|
||||||
|
chosen = rng.integers(0, w)
|
||||||
|
components[chosen][r, c] = 0
|
||||||
|
|
||||||
|
return components
|
||||||
|
|
||||||
|
|
||||||
|
def build_sc_chain(B, L=20, w=2, z=32, seed=None):
|
||||||
|
"""
|
||||||
|
Build the full SC-LDPC parity-check matrix as a dense binary matrix.
|
||||||
|
|
||||||
|
The chain has L CN positions and L+w-1 VN positions. Position t's CNs
|
||||||
|
connect to VN positions t, t+1, ..., t+w-1 using the w component matrices.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
B: Base matrix (m_base x n_base) where B[r,c] >= 0 means connected.
|
||||||
|
L: Chain length (number of CN positions).
|
||||||
|
w: Coupling width.
|
||||||
|
z: Lifting factor (circulant size).
|
||||||
|
seed: Random seed for reproducibility.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
H_full: Binary parity-check matrix of shape
|
||||||
|
(L * m_base * z) x ((L + w - 1) * n_base * z).
|
||||||
|
components: List of w component matrices from split_protograph.
|
||||||
|
meta: Dictionary with construction metadata.
|
||||||
|
"""
|
||||||
|
rng = np.random.default_rng(seed)
|
||||||
|
m_base, n_base = B.shape
|
||||||
|
|
||||||
|
# Split the protograph into w components
|
||||||
|
# Use a sub-seed derived from the main seed for splitting
|
||||||
|
split_seed = int(rng.integers(0, 2**31)) if seed is not None else None
|
||||||
|
components = split_protograph(B, w=w, seed=split_seed)
|
||||||
|
|
||||||
|
n_cn_positions = L
|
||||||
|
n_vn_positions = L + w - 1
|
||||||
|
|
||||||
|
total_rows = n_cn_positions * m_base * z
|
||||||
|
total_cols = n_vn_positions * n_base * z
|
||||||
|
|
||||||
|
H_full = np.zeros((total_rows, total_cols), dtype=np.int8)
|
||||||
|
|
||||||
|
# For each CN position t, for each component i, connect CN group t
|
||||||
|
# to VN group t+i using component B_i with random circulant shifts.
|
||||||
|
for t in range(L):
|
||||||
|
for i in range(w):
|
||||||
|
vn_pos = t + i # VN position this component connects to
|
||||||
|
comp = components[i]
|
||||||
|
|
||||||
|
for r in range(m_base):
|
||||||
|
for c in range(n_base):
|
||||||
|
if comp[r, c] >= 0:
|
||||||
|
# This entry is connected; assign a random circulant shift
|
||||||
|
shift = int(rng.integers(0, z))
|
||||||
|
|
||||||
|
# Place the z x z circulant permutation matrix
|
||||||
|
# CN rows: [t * m_base * z + r * z, ... + (r+1)*z)
|
||||||
|
# VN cols: [vn_pos * n_base * z + c * z, ... + (c+1)*z)
|
||||||
|
for zz in range(z):
|
||||||
|
row_idx = t * m_base * z + r * z + zz
|
||||||
|
col_idx = vn_pos * n_base * z + c * z + (zz + shift) % z
|
||||||
|
H_full[row_idx, col_idx] = 1
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
'L': L,
|
||||||
|
'w': w,
|
||||||
|
'z': z,
|
||||||
|
'm_base': m_base,
|
||||||
|
'n_base': n_base,
|
||||||
|
'total_rows': total_rows,
|
||||||
|
'total_cols': total_cols,
|
||||||
|
'n_cn_positions': n_cn_positions,
|
||||||
|
'n_vn_positions': n_vn_positions,
|
||||||
|
'rate_design': 1.0 - (total_rows / total_cols),
|
||||||
|
}
|
||||||
|
|
||||||
|
return H_full, components, meta
|
||||||
|
|
||||||
|
|
||||||
|
def sc_encode(info_bits, H_full, k_total):
|
||||||
|
"""
|
||||||
|
Encode using GF(2) Gaussian elimination on the SC-LDPC parity-check matrix.
|
||||||
|
|
||||||
|
Places info_bits in the first k_total positions of the codeword and solves
|
||||||
|
for the remaining parity bits such that H_full * codeword = 0 (mod 2).
|
||||||
|
|
||||||
|
Handles rank-deficient H matrices (common with SC-LDPC boundary effects)
|
||||||
|
by leaving free variables as zero.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
info_bits: Information bits array of length k_total.
|
||||||
|
H_full: Binary parity-check matrix (m x n).
|
||||||
|
k_total: Number of information bits.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
codeword: Binary codeword array of length n (= H_full.shape[1]).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If encoding fails (syndrome is nonzero).
|
||||||
|
"""
|
||||||
|
m_full, n_full = H_full.shape
|
||||||
|
n_parity = n_full - k_total
|
||||||
|
|
||||||
|
assert len(info_bits) == k_total, (
|
||||||
|
f"info_bits length {len(info_bits)} != k_total {k_total}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Split H into information and parity parts
|
||||||
|
# H_full = [H_info | H_parity]
|
||||||
|
H_info = H_full[:, :k_total]
|
||||||
|
H_parity = H_full[:, k_total:]
|
||||||
|
|
||||||
|
# Compute RHS: H_info * info_bits mod 2
|
||||||
|
rhs = (H_info @ info_bits) % 2
|
||||||
|
|
||||||
|
# Solve H_parity * p = rhs (mod 2) via Gaussian elimination
|
||||||
|
# Build augmented matrix [H_parity | rhs]
|
||||||
|
aug = np.zeros((m_full, n_parity + 1), dtype=np.int8)
|
||||||
|
aug[:, :n_parity] = H_parity.copy()
|
||||||
|
aug[:, n_parity] = rhs
|
||||||
|
|
||||||
|
# Forward elimination with partial pivoting
|
||||||
|
pivot_row = 0
|
||||||
|
pivot_cols = []
|
||||||
|
|
||||||
|
for col in range(n_parity):
|
||||||
|
# Find a pivot row for this column
|
||||||
|
found = -1
|
||||||
|
for r in range(pivot_row, m_full):
|
||||||
|
if aug[r, col] == 1:
|
||||||
|
found = r
|
||||||
|
break
|
||||||
|
|
||||||
|
if found < 0:
|
||||||
|
# No pivot for this column (rank deficient) - skip
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Swap pivot row into position
|
||||||
|
if found != pivot_row:
|
||||||
|
aug[[pivot_row, found]] = aug[[found, pivot_row]]
|
||||||
|
|
||||||
|
# Eliminate all other rows with a 1 in this column
|
||||||
|
for r in range(m_full):
|
||||||
|
if r != pivot_row and aug[r, col] == 1:
|
||||||
|
aug[r] = (aug[r] + aug[pivot_row]) % 2
|
||||||
|
|
||||||
|
pivot_cols.append(col)
|
||||||
|
pivot_row += 1
|
||||||
|
|
||||||
|
# Back-substitute: extract parity bit values from pivot columns
|
||||||
|
parity = np.zeros(n_parity, dtype=np.int8)
|
||||||
|
for i, col in enumerate(pivot_cols):
|
||||||
|
parity[col] = aug[i, n_parity]
|
||||||
|
|
||||||
|
# Assemble codeword: [info_bits | parity]
|
||||||
|
codeword = np.concatenate([info_bits.astype(np.int8), parity])
|
||||||
|
|
||||||
|
# Verify syndrome
|
||||||
|
syndrome = (H_full @ codeword) % 2
|
||||||
|
if not np.all(syndrome == 0):
|
||||||
|
raise ValueError(
|
||||||
|
f"SC-LDPC encoding failed: syndrome weight = {syndrome.sum()}, "
|
||||||
|
f"H rank ~{len(pivot_cols)}, H rows = {m_full}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return codeword
|
||||||
63
model/test_sc_ldpc.py
Normal file
63
model/test_sc_ldpc.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Tests for SC-LDPC construction."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(__file__))
|
||||||
|
|
||||||
|
|
||||||
|
class TestSCLDPCConstruction:
|
||||||
|
"""Tests for SC-LDPC chain construction."""
|
||||||
|
|
||||||
|
def test_split_preserves_edges(self):
|
||||||
|
"""Split B into components, verify sum(B_i >= 0) == (B >= 0) for each position."""
|
||||||
|
from sc_ldpc import split_protograph
|
||||||
|
from ldpc_sim import H_BASE
|
||||||
|
components = split_protograph(H_BASE, w=2, seed=42)
|
||||||
|
assert len(components) == 2
|
||||||
|
# Each edge in B should appear in exactly one component
|
||||||
|
for r in range(H_BASE.shape[0]):
|
||||||
|
for c in range(H_BASE.shape[1]):
|
||||||
|
if H_BASE[r, c] >= 0:
|
||||||
|
count = sum(1 for comp in components if comp[r, c] >= 0)
|
||||||
|
assert count == 1, f"Edge ({r},{c}) in {count} components, expected 1"
|
||||||
|
else:
|
||||||
|
for comp in components:
|
||||||
|
assert comp[r, c] < 0, f"Edge ({r},{c}) should be absent"
|
||||||
|
|
||||||
|
def test_chain_dimensions(self):
|
||||||
|
"""Build chain with L=5, w=2, Z=32. Verify H dimensions."""
|
||||||
|
from sc_ldpc import build_sc_chain
|
||||||
|
from ldpc_sim import H_BASE
|
||||||
|
m_base, n_base = H_BASE.shape
|
||||||
|
L, w, z = 5, 2, 32
|
||||||
|
H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)
|
||||||
|
expected_rows = L * m_base * z # 5*7*32 = 1120
|
||||||
|
expected_cols = (L + w - 1) * n_base * z # 6*8*32 = 1536
|
||||||
|
assert H_full.shape == (expected_rows, expected_cols), (
|
||||||
|
f"Expected ({expected_rows}, {expected_cols}), got {H_full.shape}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_chain_has_coupled_structure(self):
|
||||||
|
"""Verify non-zero blocks only appear at positions t and t+1 for w=2."""
|
||||||
|
from sc_ldpc import build_sc_chain
|
||||||
|
from ldpc_sim import H_BASE
|
||||||
|
m_base, n_base = H_BASE.shape
|
||||||
|
L, w, z = 5, 2, 32
|
||||||
|
H_full, components, meta = build_sc_chain(H_BASE, L=L, w=w, z=z, seed=42)
|
||||||
|
# For each CN position t, check which VN positions have connections
|
||||||
|
for t in range(L):
|
||||||
|
cn_rows = slice(t * m_base * z, (t + 1) * m_base * z)
|
||||||
|
for v in range(L + w - 1):
|
||||||
|
vn_cols = slice(v * n_base * z, (v + 1) * n_base * z)
|
||||||
|
block = H_full[cn_rows, vn_cols]
|
||||||
|
has_connections = np.any(block != 0)
|
||||||
|
if t <= v <= t + w - 1:
|
||||||
|
# Should have connections (t connects to t..t+w-1)
|
||||||
|
assert has_connections, f"CN pos {t} should connect to VN pos {v}"
|
||||||
|
else:
|
||||||
|
# Should NOT have connections
|
||||||
|
assert not has_connections, f"CN pos {t} should NOT connect to VN pos {v}"
|
||||||
Reference in New Issue
Block a user