diff --git a/model/__pycache__/ldpc_sim.cpython-313.pyc b/model/__pycache__/ldpc_sim.cpython-313.pyc new file mode 100644 index 0000000..94e7d62 Binary files /dev/null and b/model/__pycache__/ldpc_sim.cpython-313.pyc differ diff --git a/model/ldpc_sim.py b/model/ldpc_sim.py index 1165f1e..d3b0dba 100644 --- a/model/ldpc_sim.py +++ b/model/ldpc_sim.py @@ -34,15 +34,25 @@ OFFSET = 1 # min-sum offset (integer) # Base matrix: H_BASE[row][col] = cyclic shift, -1 = no connection # This must match the RTL exactly! +# +# Structure: IRA (Irregular Repeat-Accumulate) with staircase parity +# Column 0: information bits, connected to ALL 7 rows (dv=7, high protection) +# Columns 1-7: parity bits with lower-triangular staircase (dv=2, except col7=1) +# +# Encoding: purely sequential (lower triangular parity submatrix) +# Row 0 → solve p1; Row 1 → solve p2 (using p1); ... Row 6 → solve p7 +# +# VN degree distribution: col0=7, cols1-6=2, col7=1 +# Average VN degree = (7*32 + 2*192 + 1*32) / 256 = 2.5 H_BASE = np.array([ - [ 0, 5, 11, 17, 23, 29, 3, 9], - [15, 0, 21, 7, 13, 19, 25, 31], - [10, 20, 0, 30, 8, 16, 24, 2], - [27, 14, 1, 0, 18, 6, 12, 22], - [ 4, 28, 16, 12, 0, 26, 8, 20], - [19, 9, 31, 25, 15, 0, 21, 11], - [22, 26, 6, 14, 30, 10, 0, 18], -], dtype=np.int8) + [ 0, 5, -1, -1, -1, -1, -1, -1], # row 0: info(0) + p1(5) + [11, 3, 0, -1, -1, -1, -1, -1], # row 1: info(11) + p1(3) + p2(0) + [17, -1, 7, 0, -1, -1, -1, -1], # row 2: info(17) + p2(7) + p3(0) + [23, -1, -1, 13, 0, -1, -1, -1], # row 3: info(23) + p3(13) + p4(0) + [29, -1, -1, -1, 19, 0, -1, -1], # row 4: info(29) + p4(19) + p5(0) + [ 3, -1, -1, -1, -1, 25, 0, -1], # row 5: info(3) + p5(25) + p6(0) + [ 9, -1, -1, -1, -1, -1, 31, 0], # row 6: info(9) + p6(31) + p7(0) +], dtype=np.int16) def build_full_h_matrix(): @@ -50,62 +60,67 @@ def build_full_h_matrix(): H = np.zeros((M, N), dtype=np.int8) for r in range(M_BASE): for c in range(N_BASE): - shift = H_BASE[r, c] + shift = int(H_BASE[r, c]) if shift < 0: continue # null sub-matrix # Cyclic permutation matrix of size Z with shift for z in range(Z): - H[r * Z + z, c * Z + (z + shift) % Z] = 1 + col_idx = int(c * Z + (z + shift) % Z) + H[r * Z + z, col_idx] = 1 return H +def cyclic_shift(vec, shift): + """Cyclic right-shift a Z-length vector by 'shift' positions.""" + s = int(shift) % len(vec) + if s == 0: + return vec.copy() + return np.concatenate([vec[-s:], vec[:-s]]) + + def ldpc_encode(info_bits, H): """ - Systematic encoding: info bits are the first K bits of codeword. - Solve H * c^T = 0 for parity bits given info bits. + Sequential encoding exploiting the IRA staircase structure. - For a systematic code, H = [H_p | H_i] where H_p is invertible. - c = [info | parity], H_p * parity^T = H_i * info^T (mod 2) + QC-LDPC convention: H_BASE[r][c] = s means sub-block is circulant P_s + where P_s applied to vector x gives: result[i] = x[(i+s) % Z] (left shift by s). - This uses dense GF(2) Gaussian elimination. Fine for small codes. + Lower-triangular parity submatrix allows top-to-bottom sequential solve: + Row 0: P_{h00}*info + P_{h01}*p1 = 0 -> p1 = P_{h01}^-1 * P_{h00} * info + Row r: P_{hr0}*info + P_{hrr}*p[r] + P_{hr,r+1}*p[r+1] = 0 + -> p[r+1] = P_{hr,r+1}^-1 * (P_{hr0}*info + P_{hrr}*p[r]) + + P_s * x = np.roll(x, -s) (left circular shift) + P_s^-1*y = np.roll(y, +s) (right circular shift) """ - # info_bits goes in columns 0..K-1 (first base column = info) - # Parity bits in columns K..N-1 + assert len(info_bits) == K - # We need to solve: H[:,K:] * p = H[:,:K] * info (mod 2) - H_p = H[:, K:].copy() # M x (N-K) = 224 x 224 - H_i = H[:, :K].copy() # M x K = 224 x 32 + def apply_p(x, s): + """Apply circulant P_s: result[i] = x[(i+s)%Z].""" + return np.roll(x, -int(s)) - syndrome = H_i @ info_bits % 2 # M-vector + def inv_p(y, s): + """Apply P_s inverse: P_{-s}.""" + return np.roll(y, int(s)) - # Gaussian elimination on H_p to solve for parity - n_parity = N - K # 224 - assert H_p.shape == (M, n_parity) + # Parity blocks: p[c] is a Z-length binary vector for base column c (1..7) + p = [np.zeros(Z, dtype=np.int8) for _ in range(N_BASE)] - # Augmented matrix [H_p | syndrome] - aug = np.hstack([H_p, syndrome.reshape(-1, 1)]).astype(np.int8) + # Step 1: Solve row 0 for p[1] + # P_{H[0][0]} * info + P_{H[0][1]} * p1 = 0 + # p1 = P_{H[0][1]}^-1 * P_{H[0][0]} * info + accum = apply_p(info_bits, H_BASE[0, 0]) + p[1] = inv_p(accum, H_BASE[0, 1]) - # Forward elimination - pivot_row = 0 - for col in range(n_parity): - # Find pivot - found = False - for row in range(pivot_row, M): - if aug[row, col] == 1: - aug[[pivot_row, row]] = aug[[row, pivot_row]] - found = True - break - if not found: - continue # skip this column (rank deficient) + # Step 2: Solve rows 1-6 sequentially for p[2]..p[7] + for r in range(1, M_BASE): + # P_{H[r][0]}*info + P_{H[r][r]}*p[r] + P_{H[r][r+1]}*p[r+1] = 0 + accum = apply_p(info_bits, H_BASE[r, 0]) + accum = (accum + apply_p(p[r], H_BASE[r, r])) % 2 + p[r + 1] = inv_p(accum, H_BASE[r, r + 1]) - # Eliminate - for row in range(M): - if row != pivot_row and aug[row, col] == 1: - aug[row] = (aug[row] + aug[pivot_row]) % 2 - pivot_row += 1 - - parity = aug[:n_parity, -1] # solution - codeword = np.concatenate([info_bits, parity]) + # Assemble codeword: [info | p1 | p2 | ... | p7] + codeword = np.concatenate([info_bits] + [p[c] for c in range(1, N_BASE)]) # Verify check = H @ codeword % 2 @@ -140,18 +155,22 @@ def poisson_channel(codeword, lam_s, lam_b): # Compute exact LLR for each observation # P(y|1) = (lam_s+lam_b)^y * exp(-(lam_s+lam_b)) / y! # P(y|0) = lam_b^y * exp(-lam_b) / y! - # LLR = y * log((lam_s+lam_b)/lam_b) - lam_s + # + # Convention: LLR = log(P(y|0) / P(y|1)) (positive = bit 0 more likely) + # This matches decoder: positive belief -> hard decision 0, negative -> 1 + # + # LLR = lam_s - y * log((lam_s + lam_b) / lam_b) llr = np.zeros(n, dtype=np.float64) for i in range(n): y = photon_counts[i] if lam_b > 0: - llr[i] = y * np.log((lam_s + lam_b) / lam_b) - lam_s + llr[i] = lam_s - y * np.log((lam_s + lam_b) / lam_b) else: # No background: click = definitely bit 1, no click = definitely bit 0 if y > 0: - llr[i] = 100.0 # strong positive (bit=1) + llr[i] = -100.0 # strong negative (bit=1 likely) else: - llr[i] = -lam_s # no photons, likely bit=0 + llr[i] = lam_s # no photons, likely bit=0 return llr, photon_counts @@ -246,37 +265,37 @@ def decode_layered_min_sum(llr_q, max_iter=30, early_term=True): for iteration in range(max_iter): # Process each base matrix row (layer) for row in range(M_BASE): + # Find which columns are connected in this row + connected_cols = [c for c in range(N_BASE) if H_BASE[row, c] >= 0] + dc = len(connected_cols) + # Step 1: Compute VN->CN messages by subtracting old CN->VN - vn_to_cn = [[0]*Z for _ in range(N_BASE)] - for col in range(N_BASE): + # vn_to_cn[col_pos][z] where col_pos indexes into connected_cols + vn_to_cn = [[0]*Z for _ in range(dc)] + for ci, col in enumerate(connected_cols): shift = int(H_BASE[row, col]) - if shift < 0: - continue for z in range(Z): shifted_z = (z + shift) % Z bit_idx = col * Z + shifted_z old_msg = msg[row][col][z] - vn_to_cn[col][z] = sat_sub_q(beliefs[bit_idx], old_msg) + vn_to_cn[ci][z] = sat_sub_q(beliefs[bit_idx], old_msg) - # Step 2: CN min-sum update - cn_to_vn = [[0]*Z for _ in range(N_BASE)] + # Step 2: CN min-sum update (only over connected columns) + cn_to_vn = [[0]*Z for _ in range(dc)] for z in range(Z): - # Gather messages from all columns for this check node - cn_inputs = [vn_to_cn[col][z] for col in range(N_BASE)] + cn_inputs = [vn_to_cn[ci][z] for ci in range(dc)] cn_outputs = min_sum_cn_update(cn_inputs) - for col in range(N_BASE): - cn_to_vn[col][z] = cn_outputs[col] + for ci in range(dc): + cn_to_vn[ci][z] = cn_outputs[ci] # Step 3: Update beliefs and store new messages - for col in range(N_BASE): + for ci, col in enumerate(connected_cols): shift = int(H_BASE[row, col]) - if shift < 0: - continue for z in range(Z): shifted_z = (z + shift) % Z bit_idx = col * Z + shifted_z - new_msg = cn_to_vn[col][z] - extrinsic = vn_to_cn[col][z] + new_msg = cn_to_vn[ci][z] + extrinsic = vn_to_cn[ci][z] beliefs[bit_idx] = sat_add_q(extrinsic, new_msg) msg[row][col][z] = new_msg