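# Evaluation script: align faces with MediaPipe Face Mesh (FFHQ-style crop),
# matte them with MODNet, then score each model's renders against reference
# photos using normalized landmark error, PSNR, SSIM, LPIPS, and ArcFace
# identity similarity.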
import cv2 as cv
import mediapipe as mp
import scipy.io as sio
import numpy as np
import glob
# import natsort
import os
from pathlib import Path
import timeit
import scipy
import scipy.ndimage
import PIL
import PIL.Image
from PIL import Image

# import tensorflow as tf
import torch

import skimage.metrics
import lpips
import matplotlib.pyplot as plt

import sys
import shutil
import json

from sklearn.preprocessing import normalize

sys.path.append("/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/eg3d/dataset_preprocessing/ffhq/Deep3DFaceRecon_pytorch/insightface/recognition/")
from arcface_torch.backbones import get_model


MN = {
    'silhouette': [
        10,  338, 297, 332, 284, 251, 389, 356, 454, 323, 361, 288,
        397, 365, 379, 378, 400, 377, 152, 148, 176, 149, 150, 136,
        172, 58,  132, 93,  234, 127, 162, 21,  54,  103, 67,  109
    ],

    'lipsUpperOuter': [61, 185, 40, 39, 37, 0, 267, 269, 270, 409, 291],
    'lipsLowerOuter': [146, 91, 181, 84, 17, 314, 405, 321, 375, 291],
    'lipsUpperInner': [78, 191, 80, 81, 82, 13, 312, 311, 310, 415, 308],
    'lipsLowerInner': [78, 95, 88, 178, 87, 14, 317, 402, 318, 324, 308],

    'rightEyeUpper0': [246, 161, 160, 159, 158, 157, 173],
    'rightEyeLower0': [33, 7, 163, 144, 145, 153, 154, 155, 133],
    'rightEyeUpper1': [247, 30, 29, 27, 28, 56, 190],
    'rightEyeLower1': [130, 25, 110, 24, 23, 22, 26, 112, 243],
    'rightEyeUpper2': [113, 225, 224, 223, 222, 221, 189],
    'rightEyeLower2': [226, 31, 228, 229, 230, 231, 232, 233, 244],
    'rightEyeLower3': [143, 111, 117, 118, 119, 120, 121, 128, 245],

    'rightEyebrowUpper': [156, 70, 63, 105, 66, 107, 55, 193],
    'rightEyebrowLower': [35, 124, 46, 53, 52, 65],

    'rightEyeIris': [473, 474, 475, 476, 477],

    'leftEyeUpper0': [466, 388, 387, 386, 385, 384, 398],
    'leftEyeLower0': [263, 249, 390, 373, 374, 380, 381, 382, 362],
    'leftEyeUpper1': [467, 260, 259, 257, 258, 286, 414],
    'leftEyeLower1': [359, 255, 339, 254, 253, 252, 256, 341, 463],
    'leftEyeUpper2': [342, 445, 444, 443, 442, 441, 413],
    'leftEyeLower2': [446, 261, 448, 449, 450, 451, 452, 453, 464],
    'leftEyeLower3': [372, 340, 346, 347, 348, 349, 350, 357, 465],

    'leftEyebrowUpper': [383, 300, 293, 334, 296, 336, 285, 417],
    'leftEyebrowLower': [265, 353, 276, 283, 282, 295],

    'leftEyeIris': [468, 469, 470, 471, 472],

    'midwayBetweenEyes': [168],

    'noseTip': [1],
    'noseBottom': [2],
    'noseRightCorner': [98],
    'noseLeftCorner': [327],

    'rightCheek': [205],
    'leftCheek': [425]
}
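
# MN (above) groups MediaPipe Face Mesh landmark indices by facial region,
# mirroring the keypoint table from MediaPipe's JavaScript face-landmarks
# API. The iris indices 468-477 are produced only when FaceMesh is created
# with refine_landmarks=True.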

def calc_alignment_coefficients(pa, pb):
    # Solve for the 8 projective-transform coefficients that map the four
    # points pa onto the four points pb (via the normal equations).
    matrix = []
    for p1, p2 in zip(pa, pb):
        matrix.append([p1[0], p1[1], 1, 0, 0, 0, -p2[0] * p1[0], -p2[0] * p1[1]])
        matrix.append([0, 0, 0, p1[0], p1[1], 1, -p2[1] * p1[0], -p2[1] * p1[1]])

    a = np.asarray(matrix, dtype=float)
    b = np.asarray(pb, dtype=float).reshape(8)

    # Plain ndarray algebra (np.matrix is deprecated); equivalent to
    # inv(a.T @ a) @ a.T @ b.
    res = np.linalg.solve(a.T @ a, a.T @ b)
    return res.reshape(8)
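
# The eight values above are the coefficients (a, b, c, d, e, f, g, h) of
#   x' = (a*x + b*y + c) / (g*x + h*y + 1)
#   y' = (d*x + e*y + f) / (g*x + h*y + 1)
# Sanity-check sketch: mapping the unit square onto itself should recover
# the identity transform [1, 0, 0, 0, 1, 0, 0, 0]:
#   corners = [[0, 0], [1, 0], [1, 1], [0, 1]]
#   calc_alignment_coefficients(corners, corners)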

def perspectiveTransform(perspectiveMatrix, sourceCorners, sourcePoints):
    '''
    Apply a 3x3 projective transform to a set of 2D points.

    perspectiveMatrix: (3, 3) homography matrix
    sourceCorners: unused; kept for call-site compatibility
    sourcePoints: (n, 2) array of points to transform
    Returns the transformed points as an (n, 2) array.
    '''
    augment = np.ones((sourcePoints.shape[0], 1))
    projective_points_in = np.concatenate((sourcePoints, augment), axis=1).T

    # projective_points has shape (3, n)
    projective_points = perspectiveMatrix.dot(projective_points_in)

    # Divide each column by its last (homogeneous) coordinate, assumed
    # non-zero, to obtain the target points.
    target_points = np.true_divide(projective_points, projective_points[-1, :])

    # Return the points in row form, shape (n, 2).
    return target_points[:2, :].T
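
# Minimal usage sketch: an identity homography leaves points unchanged,
#   perspectiveTransform(np.eye(3), None, np.array([[10., 20.]]))
# returns [[10., 20.]] (sourceCorners is unused, so None is fine).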

def align_face(lm, img, output_size, file_name=None):
    """
    FFHQ-style face alignment driven by MediaPipe landmarks.

    :param lm: (n, 2) array of landmarks in pixel coordinates
    :param img: input PIL Image
    :param output_size: side length of the square aligned output
    :param file_name: output file stem; names containing 'ours' also get a mask
    :return: transformed landmarks, plus a 512x512 crop-box mask for 'ours' inputs
    """

    # lm_chin = lm[0: 17]  # left-right
    # lm_eyebrow_left = lm[17: 22]  # left-right
    # lm_eyebrow_right = lm[22: 27]  # left-right
    # lm_nose = lm[27: 31]  # top-down
    # lm_nostrils = lm[31: 36]  # top-down

    # left_eye, right_eye = face_mesh.calc_around_eye_bbox(lm)
    # print(face_mesh.calc_eye_dist(face_result))

    # # Iris detection
    # left_iris, right_iris = detect_iris(image, iris_detector, left_eye,
    #                                     right_eye)

    # # Compute the minimum enclosing circle of each iris
    # left_center, left_radius = calc_min_enc_losingCircle(left_iris)
    # right_center, right_radius = calc_min_enc_losingCircle(right_iris)
    #
    #
    # lm_eye_right = left_eye.copy()
    # lm_eye_left = right_eye.copy()
    lm_eye_right = np.concatenate((lm[MN['leftEyeLower1'], :], lm[MN['leftEyeUpper1'], :]), axis=0)  # left-clockwise
    lm_eye_left = np.concatenate((lm[MN['rightEyeLower1'], :], lm[MN['rightEyeUpper1'], :]), axis=0)  # left-clockwise
    lm_mouth_outer = np.concatenate((lm[MN['lipsLowerOuter'], :], lm[MN['lipsUpperOuter'], :]), axis=0)  # left-clockwise
    lm_mouth_inner = np.concatenate((lm[MN['lipsLowerInner'], :], lm[MN['lipsUpperInner'], :]), axis=0)  # left-clockwise

    # Calculate auxiliary vectors.
    eye_left = np.mean(lm_eye_left, axis=0)

    # print(eye_left)
    eye_right = np.mean(lm_eye_right, axis=0)


    # print(eye_right)
    eye_avg = (eye_left + eye_right) * 0.5
    eye_to_eye = eye_right - eye_left
    # print(eye_to_eye)
    mouth_left = lm_mouth_outer[10]
    mouth_right = lm_mouth_outer[-1]
    mouth_avg = (mouth_left + mouth_right) * 0.5
    eye_to_mouth = mouth_avg - eye_avg

    # print(eye_to_mouth)
    # Choose oriented crop rectangle.
    x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
    x /= np.hypot(*x)
    x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
    y = np.flipud(x) * [-1, 1]
    c = eye_avg + eye_to_mouth * 0.1
    quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
    qsize = np.hypot(*x) * 2
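
    # FFHQ-style oriented crop: x and y span a rectangle sized from the
    # eye-to-eye and eye-to-mouth distances; quad lists its corners in the
    # order that maps to output corners (0,0), (0,S), (S,S), (S,0).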

    # read image
    # img = PIL.Image.open(filepath)

    transform_size = output_size
    enable_padding = False

    # annotated_image = np.array(img.copy())
    # for i in range(lm_eye_right.shape[0]):
    #     annotated_image = cv.circle(annotated_image, (int(lm_eye_right[i,0]), int(lm_eye_right[i,1])), radius=0, color=(0, 0, 255), thickness=10)
    # for i in range(lm_eye_left.shape[0]):
    #     annotated_image = cv.circle(annotated_image, (int(lm_eye_left[i,0]), int(lm_eye_left[i,1])), radius=0, color=(0, 0, 255), thickness=10)
    # for i in range(lm_mouth_outer.shape[0]):
    #     annotated_image = cv.circle(annotated_image, (int(lm_mouth_outer[i,0]), int(lm_mouth_outer[i,1])), radius=0, color=(0, 0, 255), thickness=10)
    #     # annotated_image = cv.circle(annotated_image, (int(lm_eye_left[i,0]), int(lm_eye_left[i,1])), radius=0, color=(0, 0, 255), thickness=10)
    #
    #
    # annotated_image = cv.circle(annotated_image, (int(eye_left[0]), int(eye_left[1])), radius=0, color=(0, 255, 0), thickness=10)
    # annotated_image = cv.circle(annotated_image, (int(eye_right[0]), int(eye_right[1])), radius=0, color=(0, 255, 0), thickness=10)
    # annotated_image = cv.circle(annotated_image, (int(mouth_left[0]), int(mouth_left[1])), radius=0, color=(0, 255, 0), thickness=10)
    # annotated_image = cv.circle(annotated_image, (int(mouth_right[0]), int(mouth_right[1])), radius=0, color=(0, 255, 0), thickness=10)
    #
    # # annotated_image = cv.circle(annotated_image, (int(lm_mouth_outer[10,0]), int(lm_mouth_outer[10,1])), radius=0, color=(0, 255, 0), thickness=10)
    # # annotated_image = cv.circle(annotated_image, (int(lm_mouth_outer[-1,0]), int(lm_mouth_outer[-1,1])), radius=0, color=(0, 255, 0), thickness=10)
    #
    # name = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/' + file_name + '_debug1.jpg'
    # print(name)
    # cv.imwrite(name, annotated_image)

    # Shrink.
    shrink = int(np.floor(qsize / output_size * 0.5))
    if shrink > 1:
        rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
        img = img.resize(rsize, PIL.Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10
        quad /= shrink
        qsize /= shrink
        lm /= shrink

    # Crop.
    border = max(int(np.rint(qsize * 0.1)), 3)
    crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
            int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
    crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
            min(crop[2] + border, img.size[0]), min(crop[3] + border, img.size[1]))
    if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
        img = img.crop(crop)
        quad -= crop[0:2]
        # print(lm[1,:])
        lm -= np.array(crop[0:2])
        # print(lm[1,:])

    # Pad.
    pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
           int(np.ceil(max(quad[:, 1]))))
    pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
           max(pad[3] - img.size[1] + border, 0))

    if enable_padding and max(pad) > border - 4:
        pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
        img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
        h, w, _ = img.shape
        y, x, _ = np.ogrid[:h, :w, :1]
        mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
                          1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
        blur = qsize * 0.02
        img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
        img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
        img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
        quad += pad[:2]
        lm += np.array(pad[:2])

    # Debug visualization of the (shrunk/cropped) landmarks; nothing is
    # written out unless the imwrite below is re-enabled.
    annotated_image = np.array(img.copy())
    for i in range(lm.shape[0]):
        annotated_image = cv.circle(annotated_image, (int(lm[i, 0]), int(lm[i, 1])), radius=0, color=(255, 0, 0), thickness=10)
    # cv.imwrite(os.path.splitext(file)[0] + '_debug.jpg', annotated_image)

    # print(quad + 0.5)
    # Transform.
    img = img.transform((transform_size, transform_size), PIL.Image.QUAD, (quad + 0.5).flatten(), PIL.Image.BILINEAR)
    forward_transforms = calc_alignment_coefficients(quad + 0.5,
                                                     [[0, 0], [0, transform_size], [transform_size, transform_size], [transform_size, 0]])
    forward_transforms = np.concatenate((forward_transforms, np.array((1,))))

    lm = perspectiveTransform(np.array(forward_transforms).reshape((3,3)), np.array([[0, 0], [0, transform_size], [transform_size, transform_size], [transform_size, 0]]),
                              lm)
    if 'ours' in file_name:
        our_512_box = perspectiveTransform(np.array(forward_transforms).reshape((3,3)), np.array([[0, 0], [0, transform_size], [transform_size, transform_size], [transform_size, 0]]),
                                           np.array([[0, 0], [0, 511], [511, 511], [511, 0]]))
        our_512_box = our_512_box.clip(0, 511)


    name = os.path.join(align_folder, file_name + '.png')
    cv.imwrite(name, np.array(img.copy()))
    # annotated_image = np.array(img.copy())
    # for i in range(lm.shape[0]):
    #     annotated_image = cv.circle(annotated_image, (int(lm[i,0]), int(lm[i,1])), radius=0, color=(255, 0, 0), thickness=10)
    #     # annotated_image = cv.circle(np.array(img.copy()), (int(face_landmarks.landmark[159].x*w), int(face_landmarks.landmark[159].y*h)), radius=0, color=(0, 255, 0), thickness=10)
    # # print(os.path.splitext(os.path.basename(file_name))[0])
    # name = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/' + file_name + '_debug2.jpg'
    # print(name)
    # cv.imwrite(name, annotated_image)

    if output_size < transform_size:
        img = img.resize((output_size, output_size), PIL.Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10

    mask = np.zeros((512, 512))
    if 'ours' in file_name:
        # fillConvexPoly needs int32 point coordinates; the mask marks the
        # region of the aligned crop covered by the original 512x512 image.
        return lm, cv.fillConvexPoly(mask, our_512_box.astype(np.int32), 1)

    # Otherwise return only the transformed landmarks.
    return lm


def extract_foreground(image, matte):
    # calculate display resolution
    w, h = image.width, image.height
    # rw, rh = 800, int(h * 800 / (3 * w))

    # obtain predicted foreground
    image = np.asarray(image)
    if len(image.shape) == 2:
        image = image[:, :, None]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, 0:3]
    matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
    foreground = image * matte + np.full(image.shape, 255) * (1 - matte)

    # combine image, foreground, and alpha into one line
    # combined = np.concatenate((image, foreground, matte * 255), axis=1)
    # combined = Image.fromarray(np.uint8(combined)).resize((rw, rh))
    return foreground
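
# extract_foreground composites onto white: out = alpha * image + (1 - alpha) * 255,
# where alpha = matte / 255 broadcast across the three channels.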


def get_face_feat(img, net):
    img = cv.resize(img, (112, 112))
    img = np.transpose(img, (2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float()
    img.div_(255).sub_(0.5).div_(0.5)

    net.eval()
    feat = net(img).detach().numpy()
    feat = normalize(feat)

    return feat
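
# get_face_feat follows the arcface_torch preprocessing convention: resize to
# 112x112, HWC -> CHW, scale pixels to [-1, 1]. Both images in a comparison
# pass through the same path, so channel order only needs to be consistent.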

def get_faces_sim(img1, img2, net):

    feat1 = get_face_feat(img1, net)
    feat2 = get_face_feat(img2, net)
    sim = np.dot(feat1, feat2.T)
    # feat1/feat2 have shape (1, 512), so sim is (1, 1); return a plain scalar.
    return float(sim.squeeze())
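
# Since get_face_feat L2-normalizes its embeddings, the dot product above is
# the cosine similarity between the two identities (range [-1, 1]); squeezing
# the (1, 1) result to a float keeps the metric-array assignment tidy.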



# specify filenames here
file_extension = 'png'
# results_folder = 'inversion/output/landmark/'
# end

net = get_model('r50', fp16=False)
# map_location='cpu' keeps the checkpoint loadable on CPU-only nodes; all
# inference below runs on CPU tensors.
net.load_state_dict(torch.load('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/eg3d/backbone.pth', map_location='cpu'))
net.eval()

mp_drawing = mp.solutions.drawing_utils
mp_face_mesh = mp.solutions.face_mesh

# Path(results_folder).mkdir(parents=True, exist_ok=True)

max_numoffaces = 3
bool_notracking = True
pts3D = np.zeros((468, 3))
#
# # For static images:
# filenames = glob.glob('test5.' + file_extension)

# IMAGE_FILES = natsort.natsorted(filenames)

loss_fn = lpips.LPIPS(net='alex')

ref_path_list = []
# ref_list = ['2', '3', '4', '6', '8', '12', '16']
# dist = np.array([0.6096, 0.9144, 1.2192, 1.8288, 2.4384, 3.6576, 4.8768])
ref_list = ['16']
dist = np.array([4.8768])
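
# ref_list entries are camera distances in feet; dist holds the same
# distances in meters (16 ft * 0.3048 m/ft = 4.8768 m).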

# 28  40  65  100  130  200  300
# 1000 * (0.6096, 0.9144, 1.2192, 1.8288, 2.4384, 3.6576, 4.8768)  # distances in mm

# face_names = ['62RT-060329']
# model_list = ['ours', 'fried', 'Chen+ours']
# model_list = ['ours']
# model_list = [sys.argv[1]]
# print(model_list)

# file_list = glob('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/PoseEstimate/inputs/tmp/crop/**.png')
# file_list_name = [os.path.splitext(os.path.basename(f))[0] for f in file_list]

indir = '/net/per610a/export/das18a/satoh-lab/wangzx/dataset/Perspective/'
file_list = glob.glob(indir + 'CMDP_*/**/*_2.jpg')

with open('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/sigresult/mapping.json', 'r') as f:
    dict_sig = json.load(f)

# import pdb; pdb.set_trace()

inver_dict_sig = {}
for key_ in dict_sig:
    inver_dict_sig[dict_sig[key_]] = key_


face_names = [os.path.splitext(os.path.basename(dict_sig[f]))[0] for f in dict_sig]
# face_names = ['X5FH-050622']
#
# face_names = ['8S7R-060329',  'VND0-060329', 'AMOH-060329',  'XWQM-060208',  'P1EV-060208',
#               '9QAD-050615',  'KKMI-060329', 'J6XX-060208', 'QZGT-060329', '5ZAZ-060208',
#               'X5FH-050622',  '92RT-050419',   'QQXY-060329', '3XKB-060208', 'VIMC-050615']
# face_names = ['3XKB-060208']

# face_names = ['62RT-060329', 'X5FH-050622']
# model_list = ['fried']

theirs_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/sigresult/theirs'
ours_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output_compare/config43/dolly-slider-single'
chen_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output_compare/config43'
pti_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output_compare/config44'
wacv_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output_compare/config44'

imp_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/sigresult/imp'  # implementation
our_model_name = 'ours'
our_transfered_model_name = 'ours_transferred'
chen_model_name = 'original_pti+chen'
wacv_model_name = 'wacv'
pti_model_name = 'original_pti'

# our_new_folder = ''

main_result_folder = os.path.join('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/sigresult/eval4')
os.makedirs(main_result_folder, exist_ok=True)

target_folder_dict = {os.path.splitext(os.path.basename(f))[0]: '/'.join(f.split('/')[-3:-1]) + '/' for f in file_list}
# target_folder_dict = {'62RT-060329': 'CMDP_1/15_62RT/',
#                       'X5FH-050622': 'CMDP_2/29_X5FH'}

model_list = ['input', 'sig', 'imp', 'ours', 'ours_trans', 'chen', 'wacv', 'pti']
# model_list = ['sig']

psnr = np.zeros((len(model_list), len(face_names), len(ref_list)))
ssim = np.zeros((len(model_list), len(face_names), len(ref_list)))
face_id = np.zeros((len(model_list), len(face_names), len(ref_list)))
lpips_score = np.zeros((len(model_list), len(face_names), len(ref_list)))
lmks = np.zeros((len(model_list), len(face_names), len(ref_list)))
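# Each metric array above is indexed [model, face, distance].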


drawing_spec = mp_drawing.DrawingSpec(thickness=1, circle_radius=1)

res = {}

# import pdb; pdb.set_trace()
# for k in range(len(model_list)):
#     model_name = model_list[k]
for j in range(len(face_names)):
    face_name = face_names[j] + '_2'
    face_result_folder = os.path.join(main_result_folder, face_name)  + '/'
    target_folder = '/net/per610a/export/das18a/satoh-lab/wangzx/dataset/Perspective/' + target_folder_dict[face_name]
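
    # IMAGE_FILES maps '<model><distance-in-feet>' keys (e.g. 'ours16') to
    # the corresponding image path for this face.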

    IMAGE_FILES = {}
    for ref_id in range(len(dist)):
        IMAGE_FILES['sig' + ref_list[ref_id]]     = os.path.join(theirs_folder, inver_dict_sig[face_name[:-2]] + '.jpg')
        IMAGE_FILES['imp' + ref_list[ref_id]]     = os.path.join(imp_folder, inver_dict_sig[face_name[:-2]] + f'_{int(ref_list[ref_id]) // 2}_out.jpg')
        IMAGE_FILES['reference' + ref_list[ref_id]] = os.path.join(target_folder, face_name[:-2] + '_' + ref_list[ref_id] + '.jpg')
        IMAGE_FILES['input' + ref_list[ref_id]] = os.path.join(target_folder, face_name[:-2] + '_2.jpg')

        chen_old = os.path.join(chen_folder, face_name, '_'.join((face_name, chen_model_name, ref_list[ref_id] + '.png')))
        IMAGE_FILES['chen' + ref_list[ref_id]] = chen_old

        wacv_old = os.path.join(wacv_folder, face_name, '_'.join((face_name, wacv_model_name, ref_list[ref_id] + '.png')))
        IMAGE_FILES['wacv' + ref_list[ref_id]] = wacv_old

        pti_old = os.path.join(pti_folder, face_name, '_'.join((face_name, pti_model_name, ref_list[ref_id] + '.png')))
        IMAGE_FILES['pti' + ref_list[ref_id]] = pti_old

        # our_old = os.path.join(ours_folder, face_name, '_'.join((face_name, our_model_name, ref_list[ref_id] + '.png')))
        our_old = os.path.join(ours_folder, '_'.join((face_name, our_model_name+ '.png')))
        # our_new = os.path.join(our_new_folder,  '_'.join((face_name, model_name, ref_list[ref_id] + '.png')))
        # shutil.copy(our_old, our_new)
        IMAGE_FILES['ours' + ref_list[ref_id]] = our_old
        IMAGE_FILES['ours_trans' + ref_list[ref_id]] = os.path.join(ours_folder, '_'.join((face_name, our_transfered_model_name+ '.png')))


    align_folder = os.path.join(face_result_folder, 'align/')
    if not os.path.exists(align_folder):
        os.makedirs(align_folder)

    matting_folder = os.path.join(face_result_folder, 'matting/')
    if not os.path.exists(matting_folder):
        os.makedirs(matting_folder)

    fore_folder = os.path.join(face_result_folder, 'fore/')
    if not os.path.exists(fore_folder):
        os.makedirs(fore_folder)

    with mp_face_mesh.FaceMesh(
            static_image_mode=bool_notracking,
            max_num_faces=max_numoffaces,
            min_detection_confidence=0.1,
            min_tracking_confidence=0.1) as face_mesh:
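        # static_image_mode=True runs detection independently per image (no
        # tracking); refine_landmarks is not enabled, so only the 468 base
        # landmarks are available and the iris entries in MN go unused.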

        for idx, file_name in enumerate(IMAGE_FILES):
            print(IMAGE_FILES[file_name])
            # import pdb; pdb.set_trace()
            image = cv.imread(IMAGE_FILES[file_name])
            # Convert the BGR image to RGB before processing.
            start = timeit.default_timer()
            results = face_mesh.process(cv.cvtColor(image, cv.COLOR_BGR2RGB))
            stop = timeit.default_timer()
            print('Time: ', stop - start)

            # Collect face-mesh landmarks for this image.
            if not results.multi_face_landmarks:
                # No face detected: res[file_name] stays unset, so the
                # landmark-error lookup below would KeyError for this entry.
                continue
            annotated_image = image.copy()
            num_face = len(results.multi_face_landmarks)
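            # Proxy for face size: horizontal gap between the upper-eyelid
            # landmarks 386 (left eye) and 159 (right eye); the largest
            # detected face is kept below.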
            PupilDistance_faces = np.zeros((num_face,))
            for i in range(num_face):
                face_landmarks = results.multi_face_landmarks[i]
                PupilDistance_faces[i] = face_landmarks.landmark[386].x - face_landmarks.landmark[159].x
            # find the max face
            # print(PupilDistance_faces)
            ind = np.argmax(PupilDistance_faces)
            face_landmarks = results.multi_face_landmarks[ind]
            # print(len(face_landmarks.landmark))
            for i in range(468):
                pts3D[i, 0] = face_landmarks.landmark[i].x
                pts3D[i, 1] = face_landmarks.landmark[i].y
                pts3D[i, 2] = face_landmarks.landmark[i].z
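            # Face Mesh outputs normalized [0, 1] coordinates; x and y are
            # scaled to pixels below (z stays in MediaPipe's relative units).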

            h, w = annotated_image.shape[0], annotated_image.shape[1]
            pts3D_rescale = pts3D[:,:2].copy()
            pts3D_rescale[:, 0] = pts3D[:, 0] * w
            pts3D_rescale[:, 1] = pts3D[:, 1] * h
            im_pil = PIL.Image.fromarray(image.copy())
            if 'our' in file_name:
                lm, out_box = align_face(pts3D_rescale, im_pil, 512, file_name=file_name)
            else:
                lm = align_face(pts3D_rescale, im_pil, 512, file_name=file_name)
            res[file_name] = lm

    for i, xk in enumerate(ref_list):
        for kk in range(len(model_list)):
            # L2 norm of the landmark displacement against the reference,
            # normalized by the 512-px aligned image size.
            lmks[kk, j, i] = np.linalg.norm(res[model_list[kk] + xk] - res['reference' + xk]) / 512
            nm = model_list[kk]
            print(f"{nm} normalized landmark error:", lmks[kk, j, i])




    # Run MODNet portrait matting on the aligned crops so that the metrics
    # below compare foreground regions only.
    os.chdir('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/MODNet')
    os.system('python -m demo.image_matting.colab.inference --input-path %s --output-path %s --ckpt-path ./pretrained/modnet_photographic_portrait_matting.ckpt' % (
        align_folder,
        matting_folder
    ))
    os.chdir('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/scripts')



    for k in range(len(model_list)):
        model_name = model_list[k]
        for i, xk in enumerate(ref_list):
            input_image_name = os.path.join(align_folder, model_name + xk + '.png')
            input_image = Image.open(input_image_name)
            input_matte_name = os.path.join(matting_folder, model_name + xk + '.png')
            print(input_matte_name)
            input_matte = Image.open(input_matte_name)
            # out_box is the crop-box mask that align_face returned for this
            # face's 'ours' image; every model is masked with the same box so
            # the metrics compare identical regions (outside is white).
            box3 = np.repeat(out_box[:, :, None], 3, axis=2)
            input_fore = extract_foreground(input_image, input_matte) * (box3 > 0) + (box3 == 0) * 255
            # Save each model's masked foreground under its own name.
            Image.fromarray(np.uint8(input_fore)).save(os.path.join(fore_folder, model_name + xk + '_fore.png'))

            ref_image_name = os.path.join(align_folder, 'reference' + xk + '.png')
            ref_image = Image.open(ref_image_name)
            ref_matte_name = os.path.join(matting_folder, 'reference' + xk + '.png')
            ref_matte = Image.open(ref_matte_name)
            ref_fore = extract_foreground(ref_image, ref_matte) * (box3 > 0) + (box3 == 0) * 255
            Image.fromarray(np.uint8(ref_fore)).save(os.path.join(fore_folder, 'reference' + xk + '_fore.png'))

            face_id[k, j, i] = get_faces_sim(np.uint8(ref_fore), np.uint8(input_fore), net)
            psnr[k, j, i] = skimage.metrics.peak_signal_noise_ratio(np.uint8(ref_fore), np.uint8(input_fore))
            # multichannel= was deprecated in scikit-image 0.19; channel_axis replaces it.
            ssim[k, j, i] = skimage.metrics.structural_similarity(np.uint8(ref_fore), np.uint8(input_fore), channel_axis=-1)
            print("psnr:", psnr[k, j, i])
            print("ssim:", ssim[k, j, i])
            print("face_id:", face_id[k, j, i])


            # LPIPS expects inputs in [-1, 1] by default; normalize=True lets
            # it rescale these [0, 1] tensors internally.
            ref_fore_torch = torch.tensor(ref_fore / 255.).float()
            input_fore_torch = torch.tensor(input_fore / 255.).float()
            d = loss_fn.forward(ref_fore_torch.permute([2, 0, 1]).unsqueeze(0),
                                input_fore_torch.permute([2, 0, 1]).unsqueeze(0), normalize=True)
            lpips_score[k, j, i] = d.squeeze().detach().cpu().numpy()
            print("lpips:", lpips_score[k, j, i])


# sio.savemat('/net/per610a/export/das18a/satoh-lab/wangzx/src/Perspective/HFGI3D/inversion/output/landmark/irirs_dist.mat', {'dist1': dist1,
#                                                                                                                             'dist2': dist2,
#                                                                                                                             'z_mean': z_mean})


# with open(os.path.join(main_result_folder, model_list[k] + '_face_names.txt'), 'w') as fp:
#     fp.write("\n".join(str(item) for item in face_names))

# all_in_one = {}
# all_in_one['face_id'] = face_id
# all_in_one['lmk'] = lmks
# all_in_one['psnr'] = psnr
# all_in_one['ssim'] = ssim
# all_in_one['lpips'] = lpips_score


for k in range(len(model_list)):
    np.save(os.path.join(main_result_folder, model_list[k] + '_lmk.npy'), lmks[k,:,:])
    np.save(os.path.join(main_result_folder, model_list[k] + '_psnr.npy'), psnr[k,:,:])
    np.save(os.path.join(main_result_folder, model_list[k] + '_ssim.npy'), ssim[k,:,:])
    np.save(os.path.join(main_result_folder, model_list[k] + '_face_id.npy'), face_id[k,:,:])
    np.save(os.path.join(main_result_folder, model_list[k] + '_lpips.npy'), lpips_score[k,:,:])

    print(model_list[k] + ' lmk:', lmks[k,:,:].mean())
    print(model_list[k] + ' psnr:', psnr[k,:,:].mean())
    print(model_list[k] + ' face_id:', face_id[k,:,:].mean())
    print(model_list[k] + ' ssim:', ssim[k,:,:].mean())
    print(model_list[k] + ' lpips:', lpips_score[k,:,:].mean())
    print("----------------------------------")

# One LaTeX-style table row per model: name & lmk & psnr & ssim & lpips & face_id.
for k in range(len(model_list)):
    print("{} & {:.3f}  & {:.3f}  & {:.3f}  & {:.3f}  & {:.3f}".format(
        model_list[k], lmks[k,:,:].mean(), psnr[k,:,:].mean(),
        ssim[k,:,:].mean(), lpips_score[k,:,:].mean(), face_id[k,:,:].mean()))


#
#     for j in range(len(face_names)):
#         face_result_folder = os.path.join(main_result_folder, face_names[j])  + '/'
#
#         plt.plot(dist, lmks[:, j,:].T)
#         plt.xlabel('distance [m]')
#         plt.ylabel('normalized landmark error')
#         plt.legend(model_list)
#         plt.savefig(os.path.join(face_result_folder, face_names[j] + '_' + model_name[k] + '_lmk.png'))
#         plt.close()
#
#         plt.plot(dist, psnr[:,j,:].T)
#         plt.xlabel('distance [m]')
#         plt.ylabel('PSNR [dB]')
#         plt.legend(model_list)
#         plt.savefig(os.path.join(face_result_folder, face_names[j] + '_' + model_name[k] + '_psnr.png'))
#         plt.close()
#
#         plt.plot(dist, ssim[:, j,:].T)
#         plt.xlabel('distance [m]')
#         plt.ylabel('SSIM')
#         plt.legend(model_list)
#         plt.savefig(os.path.join(face_result_folder, face_names[j] + '_' + model_name[k] + '_ssim.png'))
#         plt.close()
#
#         plt.plot(dist, lpips_score[:, j,:].T)
#         plt.xlabel('distance [m]')
#         plt.ylabel('LPIPS')
#         plt.legend(model_list)
#         plt.savefig(os.path.join(face_result_folder, face_names[j] + '_' + model_name[k] + '_lpips.png'))
#         plt.close()

# One color per model; wrap around if there are more models than colors.
line_color = ['black', 'red', 'blue', 'green', 'orange', 'purple', 'brown', 'gray']
fig, ax = plt.subplots(figsize=(7, 4))
for i in range(lmks.shape[0]):
    # Mean/std across faces at each camera distance.
    plt.errorbar(dist, lmks[i, ...].mean(axis=0).squeeze().T, yerr=lmks[i, ...].std(axis=0).squeeze().T,
                 fmt='o', color=line_color[i % len(line_color)],
                 ecolor='lightgray', elinewidth=3, capsize=0)
plt.legend(model_list)
plt.xlabel('distance [m]')
plt.ylabel('normalized landmark error')
# model_name still holds the last entry of model_list from the loop above.
plt.savefig(os.path.join(main_result_folder, model_name + '_all_lmk.png'))
plt.close()

fig, ax = plt.subplots(figsize=(7, 4))
for i in range(psnr.shape[0]):
    plt.errorbar(dist, psnr[i, ...].mean(axis=0).squeeze().T, yerr=psnr[i, ...].std(axis=0).squeeze().T,
                 fmt='o', color=line_color[i % len(line_color)],
                 ecolor='lightgray', elinewidth=3, capsize=0)
plt.legend(model_list)
plt.xlabel('distance [m]')
plt.ylabel('PSNR [dB]')
plt.savefig(os.path.join(main_result_folder, model_name + '_all_psnr.png'))
plt.close()


fig, ax = plt.subplots(figsize=(7, 4))
for i in range(ssim.shape[0]):
    plt.errorbar(dist, ssim[i, ...].mean(axis=0).squeeze().T, yerr=ssim[i, ...].std(axis=0).squeeze().T,
                 fmt='o', color=line_color[i % len(line_color)],
                 ecolor='lightgray', elinewidth=3, capsize=0)
plt.legend(model_list)
plt.xlabel('distance [m]')
plt.ylabel('SSIM')
plt.savefig(os.path.join(main_result_folder, model_name + '_all_ssim.png'))
plt.close()


fig, ax = plt.subplots(figsize=(7, 4))
for i in range(lpips_score.shape[0]):
    plt.errorbar(dist, lpips_score[i, ...].mean(axis=0).squeeze().T, yerr=lpips_score[i, ...].std(axis=0).squeeze().T,
                 fmt='o', color=line_color[i % len(line_color)],
                 ecolor='lightgray', elinewidth=3, capsize=0)
plt.legend(model_list)
plt.xlabel('distance [m]')
plt.ylabel('LPIPS')
plt.savefig(os.path.join(main_result_folder, model_name + '_all_lpips.png'))
plt.close()