Untitled
unknown
plain_text
7 days ago
2.5 kB
4
Indexable
import numpy as np


def pivotedOuterProductLU(A, tol=1.0e-14):
    """
    Perform an LU factorization with column pivoting via an outer-product
    (rank-1) update.

    Input:
        A   : an n x n matrix (NumPy array or array-like, float or int).
              It is copied and converted to float; the caller's array is
              never modified.
        tol : pivots with absolute value below this are treated as zero
              (default 1.0e-14, the historical hard-coded threshold).

    Output:
        L, U, P (all n x n float arrays) so that A @ P = L @ U, where
        L is unit lower-triangular (1s on the diagonal),
        U is upper-triangular,
        P is the permutation matrix for the column swaps.

    Raises:
        ValueError if A is not square, or if a pivot smaller than `tol` is
        encountered in steps 0..n-2. The last diagonal entry U[n-1, n-1] is
        not checked, so a singular A can still factor with a zero there.
    """
    # Fresh float copy: integer input is promoted and the caller's A is untouched.
    U = np.array(A, dtype=float)
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        raise ValueError("A must be a square (n x n) matrix.")
    n = U.shape[0]

    L = np.eye(n, dtype=float)
    P = np.eye(n, dtype=float)

    for k in range(n - 1):
        # 1) Pivot search along row k, among columns k..n-1.
        pivot_col = k + int(np.argmax(np.abs(U[k, k:])))

        # 2) Swap columns k and pivot_col in U and record it in P.
        #    L is NOT touched: column swaps act on the right of A and do not
        #    change the row-elimination multipliers already stored in L.
        #    (Fancy-indexed RHS is a copy, so the swap is safe.)
        if pivot_col != k:
            U[:, [k, pivot_col]] = U[:, [pivot_col, k]]
            P[:, [k, pivot_col]] = P[:, [pivot_col, k]]

        # 3) The pivot is U[k, k].
        pivot_val = U[k, k]
        if abs(pivot_val) < tol:
            raise ValueError(
                "Zero (or very small) pivot encountered. Factorization fails."
            )

        # 4) Multipliers L(i, k) for i > k.
        L[k + 1:, k] = U[k + 1:, k] / pivot_val

        # 5) Rank-1 update of the trailing submatrix:
        #    U[i, j] -= L[i, k] * U[k, j]  for i, j > k.
        U[k + 1:, k + 1:] -= np.outer(L[k + 1:, k], U[k, k + 1:])

        # The entries below the pivot are eliminated by construction; zero
        # them explicitly so U is truly upper-triangular and L @ U == A @ P.
        U[k + 1:, k] = 0.0

    return L, U, P


if __name__ == "__main__":
    # Matrix from the exercise
    A_test = np.array([
        [ 4,  3,  2,  1],
        [ 8,  8,  5,  2],
        [16, 12, 10,  5],
        [32, 24, 20, 11],
    ], dtype=float)

    L, U, P = pivotedOuterProductLU(A_test)

    print("A =\n", A_test, "\n")
    print("L =\n", L, "\n")
    print("U =\n", U, "\n")
    print("P =\n", P, "\n")

    # Check that A*P and L*U match
    lhs = A_test @ P
    rhs = L @ U
    print("A * P =\n", lhs, "\n")
    print("L * U =\n", rhs, "\n")
    print("Difference norm =", np.linalg.norm(lhs - rhs))
Editor is loading...
Leave a Comment