Untitled

mail@pastecode.io avatar
unknown
plain_text
2 years ago
1.7 kB
22
Indexable
Never
import numpy as np
import pywt
from PIL import Image
import matplotlib.pyplot as plt

def F2Coeffs(F, levels=7):
    """Rebuild a PyWavelets 2-D multilevel coefficient list from a flat vector.

    In MATLAB the coefficients are flattened into one vector laid out as
    ``[cAn, cHn, cVn, cDn, ..., cH1, cV1, cD1]``, with every matrix stored in
    column-major order. This function reshapes that vector into the list form
    ``[cAn, (cHn, cVn, cDn), ..., (cH1, cV1, cD1)]``. For an explanation of that
    list see
    https://pywavelets.readthedocs.io/en/latest/ref/2d-dwt-and-idwt.html#d-multilevel-decomposition-using-wavedec2

    The coarsest approximation ``cAn`` is taken to be 1x1. The detail matrices at
    level ``n`` are ``2**(levels - n)`` x ``2**(levels - n)``. The vector therefore
    holds ``4**levels`` entries in total (16384 for the default ``levels=7``,
    i.e. a 128x128 image).

    Args:
        F: coefficient vector of shape ``(N,)``, ``(N, 1)`` or ``(1, N)``. It must
            contain at least ``4**levels`` entries (``np.reshape`` raises
            ``ValueError`` otherwise).
        levels: number of decomposition levels. Defaults to 7.

    Returns:
        A list whose first element is the 1x1 approximation array and whose
        remaining elements are ``(cH, cV, cD)`` tuples, ordered from the coarsest
        level to the finest.
    """
    F = np.ravel(F)
    Coeffs = [np.reshape(F[0], (1, 1))]  # cAn
    pt = 1
    for n in range(levels, 0, -1):
        dim = 2 ** (levels - n)
        size = 3 * dim * dim
        # MATLAB stores each matrix column-major. A row-major reshape followed by
        # swapping the last two axes recovers the intended (row, col) layout.
        block = np.reshape(F[pt:pt + size], (3, dim, dim)).transpose(0, 2, 1)
        cHn, cVn, cDn = block
        Coeffs.append((cHn, cVn, cDn))
        pt += size
    return Coeffs


def calculate_error(X, Y):
    """Return the ord-2 norm of the residual ``X - Y``.

    ``numpy.linalg.norm(..., 2)`` is the Euclidean norm for 1-D input and the
    spectral norm (largest singular value) for 2-D input. For a single-column
    residual the two coincide.
    """
    residual = X - Y
    return np.linalg.norm(residual, ord=2)

def get_max_col_norm(R):
    """Return the largest Euclidean (L2) norm among the columns of ``R``."""
    column_norms = np.linalg.norm(R, axis=0)
    return column_norms.max()

def _hard_threshold(x, k):
    """Zero, in place, all but the ``k`` largest-magnitude rows of column vector ``x``."""
    # Clamp so that k >= len(x) keeps every entry. Without the clamp, a negative
    # slice bound would zero the wrong entries.
    n_zero = max(x.shape[0] - k, 0)
    if n_zero:
        smallest = np.argsort(np.abs(x[:, 0]))[:n_zero]
        x[smallest] = 0
    return x


def iht(R, P, K, alpha=1, n_iter=500, eta=1, *, verbose=True):
    """Estimate a K-sparse ``F`` with ``P ≈ R @ F`` by iterative hard thresholding.

    Each iteration takes the step ``F + eta * R.T @ (P - R @ F)``, which is a
    gradient step on ``0.5 * ||P - R F||^2``. It then keeps only the ``K``
    largest-magnitude entries of the result.

    Args:
        R: ``(M, N)`` measurement matrix.
        P: measurements, shape ``(M,)`` or ``(M, 1)``.
        K: number of nonzero entries to keep in the estimate.
        alpha: ``R`` is multiplied by ``alpha`` before iterating. The iterates
            then solve ``P ≈ (alpha R) F'``, so the returned value is
            ``alpha * F'``, which is expressed relative to the unscaled ``R``.
        n_iter: number of iterations.
        eta: step size.
        verbose: if True (the default, matching the previous behavior), prints
            the max column norm of the scaled ``R`` and the per-iteration change
            in ``F``. It also plots the residual-norm history with
            ``plt.show()``, which may block, and prints the final residual.

    Returns:
        ``(N, 1)`` array holding the sparse estimate.
    """
    R = R * alpha
    P = np.asarray(P)
    if P.ndim == 1:
        # A 1-D P would broadcast against the (M, 1) product R @ F into an
        # (M, M) matrix and silently corrupt every update.
        P = P.reshape(-1, 1)
    if verbose:
        # NOTE(review): IHT with eta=1 is generally only stable when the
        # spectral norm of alpha*R is below 1. This printout presumably serves
        # as a quick check of the scaling; confirm.
        print(get_max_col_norm(R), alpha)
    F = np.zeros((R.shape[1], 1))
    error_PRF = []
    for _ in range(n_iter):
        new_F = _hard_threshold(F + eta * (R.T @ (P - R @ F)), K)
        if verbose:
            error_PRF.append(calculate_error(P, R @ new_F))
            print(np.linalg.norm(new_F - F, 2))
        F = new_F
    if verbose:
        plt.plot(error_PRF)
        plt.show()
        if error_PRF:  # empty when n_iter == 0
            print(error_PRF[-1])
    return F * alpha