# Paste-site metadata (not code): title "Untitled", language "unknown",
# format "plain_text", posted "8 months ago", size 2.5 kB, "6", "Indexable".
import numpy as np
def pivotedOuterProductLU(A, tol=1.0e-14):
    """
    Perform an LU factorization with column pivoting via an
    outer-product (rank-1) update.

    Computes L, U, P such that ``A @ P == L @ U`` (up to rounding), where
    L is unit lower-triangular (1s on the diagonal), U is upper-triangular
    and P is the permutation matrix recording the column swaps.

    At step k the pivot is the entry of largest magnitude in row k among
    columns k..n-1 of the partially reduced matrix; that column is swapped
    into position k before the rank-1 update of the trailing submatrix.

    Parameters
    ----------
    A : array_like, shape (n, n)
        Square matrix to factor (int or float entries; converted to float).
        The caller's array is not modified.
    tol : float, optional
        Absolute threshold below which a pivot is treated as zero.
        Defaults to 1e-14 (the previously hard-coded value).

    Returns
    -------
    L : ndarray, shape (n, n)
        Unit lower-triangular factor.
    U : ndarray, shape (n, n)
        Upper-triangular factor (entries below the diagonal are exactly 0).
    P : ndarray, shape (n, n)
        Column-permutation matrix.

    Raises
    ------
    ValueError
        If A is not a square 2-D array, or if a pivot with absolute value
        below ``tol`` is encountered (the active part of a row is
        numerically zero).
    """
    # np.array copies, so the caller's matrix is never modified. It also
    # accepts nested lists and promotes integer input to float.
    U = np.array(A, dtype=float)
    if U.ndim != 2 or U.shape[0] != U.shape[1]:
        raise ValueError(f"A must be a square 2-D array, got shape {U.shape}.")
    n = U.shape[0]
    L = np.eye(n, dtype=float)
    P = np.eye(n, dtype=float)

    for k in range(n - 1):
        # 1) Pivot: column index of the largest |U[k, j]| for j in k..n-1.
        pivot_col = k + int(np.argmax(np.abs(U[k, k:])))

        # 2) Column swap. Right-multiplying by a permutation only reorders
        #    the columns of A (and hence of U); L is not affected. L's
        #    columns k..n-1 are still unit vectors at this point, so swapping
        #    them would break L's unit-lower-triangular structure.
        if pivot_col != k:
            U[:, [k, pivot_col]] = U[:, [pivot_col, k]]
            P[:, [k, pivot_col]] = P[:, [pivot_col, k]]

        # 3) The pivot is U[k, k].
        pivot_val = U[k, k]
        if abs(pivot_val) < tol:
            raise ValueError("Zero (or very small) pivot encountered. Factorization fails.")

        # 4) Multipliers L[i, k] for i > k.
        L[k + 1:, k] = U[k + 1:, k] / pivot_val

        # 5) Rank-1 (outer-product) update of the trailing submatrix.
        U[k + 1:, k + 1:] -= np.outer(L[k + 1:, k], U[k, k + 1:])

        # Column k below the diagonal is now eliminated; store exact zeros
        # so the returned U really is upper-triangular.
        U[k + 1:, k] = 0.0

    return L, U, P
if __name__ == "__main__":
    # Demo on the 4x4 matrix from the exercise: factor it, print every
    # matrix, then compare both sides of A @ P == L @ U and report the
    # norm of their difference.
    A_test = np.array(
        [
            [4, 3, 2, 1],
            [8, 8, 5, 2],
            [16, 12, 10, 5],
            [32, 24, 20, 11],
        ],
        dtype=float,
    )
    L, U, P = pivotedOuterProductLU(A_test)

    for label, matrix in (("A =\n", A_test), ("L =\n", L), ("U =\n", U), ("P =\n", P)):
        print(label, matrix, "\n")

    # Both sides of the factorization identity.
    lhs = A_test @ P
    rhs = L @ U
    for label, matrix in (("A * P =\n", lhs), ("L * U =\n", rhs)):
        print(label, matrix, "\n")
    print("Difference norm =", np.linalg.norm(lhs - rhs))
# Paste-site footer (not code): "Editor is loading...", "Leave a Comment".