Untitled

 avatar
unknown
plain_text
a year ago
11 kB
6
Indexable
"""
This project was developed by Haofan Wang to support face swapping in a single frame. Multi-frame support will be added soon!

It is built heavily on top of insightface, sd-webui-roop and CodeFormer.
"""

import os
import cv2
import copy
import insightface
import numpy as np
from numpy import dot
from numpy.linalg import norm
from PIL import Image
from typing import Dict, List, Union
from collections import namedtuple
import numpy as np
from queue import PriorityQueue
from model.common.request.schema import SubjectModel, FaceSwapModel, ConditioningModel
from model.common.request.request import Request
from model.common.util.timing import CodeTimer
from model.common.face.analyzer.insightface import FaceAnalyzer
import torch
import facer
import logging

logger = logging.getLogger()


def cosine_similarity(a, b):
    """Return the cosine of the angle between vectors ``a`` and ``b``.

    The result is ``dot(a, b)`` divided by the product of the two Euclidean norms.
    A zero-length vector makes the denominator zero.
    """
    magnitude_product = norm(a) * norm(b)
    return dot(a, b) / magnitude_product


def region_to_full_resolution(region, image_width, image_height):
    """Scale a fractional region to integer pixel values.

    ``region.x`` and ``region.width`` are multiplied by ``image_width``, and
    ``region.y`` and ``region.height`` by ``image_height``. Each product is
    truncated with ``int()``.

    Returns:
    list[int]: ``[x, y, width, height]``. This is a position plus a size,
    NOT a corner-form ``[x1, y1, x2, y2]`` box.
    """
    scaled = (
        region.x * image_width,
        region.y * image_height,
        region.width * image_width,
        region.height * image_height,
    )
    return [int(value) for value in scaled]


def calculate_distance(face_bbox, intended_region_bbox):
    """Distance between the centres of two boxes, normalised by the region's diagonal.

    Both boxes are read as corner form ``[x1, y1, x2, y2]``: a centre is the
    mean of indices 0/2 and 1/3, and the diagonal spans ``x2 - x1`` by ``y2 - y1``.

    Args:
    face_bbox: Box of the detected face.
    intended_region_bbox: Box of the region the face is meant to occupy.

    Returns:
    The Euclidean centre-to-centre distance divided by the diagonal of
    ``intended_region_bbox``. 0 means the centres coincide. A zero-size
    region divides by zero (numpy yields inf/nan rather than raising).
    """
    face_cx = (face_bbox[0] + face_bbox[2]) / 2
    face_cy = (face_bbox[1] + face_bbox[3]) / 2

    rx1, ry1 = intended_region_bbox[0], intended_region_bbox[1]
    rx2, ry2 = intended_region_bbox[2], intended_region_bbox[3]
    region_cx = (rx1 + rx2) / 2
    region_cy = (ry1 + ry2) / 2

    centre_offset = np.sqrt((face_cx - region_cx) ** 2 + (face_cy - region_cy) ** 2)
    region_diagonal = np.sqrt((rx2 - rx1) ** 2 + (ry2 - ry1) ** 2)
    return centre_offset / region_diagonal


def face_score(source_face, target_face, intended_region_bbox, distance_weight, similarity_weight, gender_weight):
    """
    Score how well ``target_face`` suits receiving ``source_face``.

    The score is the weighted sum of three terms:
    - proximity: ``1 - calculate_distance(target_face.bbox, intended_region_bbox)``,
      which goes negative when the face is far from the region;
    - similarity: cosine similarity of the two embeddings, rescaled from [-1, 1] to [0, 1];
    - same sex: 1 when ``source_face.sex == target_face.sex``, else 0.

    Args:
    source_face: Reference face; must expose ``embedding`` and ``sex``.
    target_face: Candidate face; must expose ``bbox``, ``embedding`` and ``sex``.
    intended_region_bbox: Region the source face should land in, read as [x1, y1, x2, y2].
    distance_weight (float): Weight of the proximity term.
    similarity_weight (float): Weight of the similarity term.
    gender_weight (float): Weight of the same-sex term.

    Returns:
    The weighted score; higher means a better match.
    """
    proximity = 1 - calculate_distance(target_face.bbox, intended_region_bbox)
    # Cosine similarity lies in [-1, 1]; shift and halve it onto [0, 1].
    similarity = (cosine_similarity(source_face.embedding, target_face.embedding) + 1) / 2
    same_sex = 1 if source_face.sex == target_face.sex else 0
    return distance_weight * proximity + similarity_weight * similarity + gender_weight * same_sex


def create_mask(vis_img: np.ndarray, threshold: float = 0.0, object_color: int = 255,
                background_color: int = 0) -> np.ndarray:
    """
    Create a binary mask from a given image based on a threshold.

    Elements strictly greater than ``threshold`` become ``object_color``; all
    other elements become ``background_color``. With the defaults this gives
    a white object (255) on a black background (0).

    Args:
    vis_img (np.ndarray): The input image for mask creation. Any array type
        that supports ``>``, ``~`` on the boolean result, and multiplication by
        a scalar works. ``get_face_parsing_mask`` passes a torch tensor.
    threshold (float): Elements above this value are foreground. Defaults to 0.0.
    object_color (int): Value written for foreground elements. Defaults to 255.
    background_color (int): Value written for background elements. Defaults to 0.

    Returns:
    np.ndarray: A mask with the same shape and array type as ``vis_img``,
    holding only ``object_color`` / ``background_color`` values.
    """
    binary_mask = vis_img > threshold
    # Each boolean mask selects exactly one of the two colours per element.
    return binary_mask * object_color + ~binary_mask * background_color


class Swapper:
    """Swap faces from subject reference images onto a target image.

    Uses the insightface model ``inswapper_128.onnx`` for the swap and the
    ``facer`` parser ``farl/lapa/448`` to restrict where the swapped face is
    blended in. Call :meth:`load` with a ``FaceAnalyzer`` before :meth:`process`.
    """

    def __init__(self, device: torch.device, model_rootpath: str) -> None:
        """
        Args:
        device (torch.device): Device passed to the face parser.
        model_rootpath (str): Directory containing ``inswapper_128.onnx`` and ``farl.pt``.
        """
        self._model_rootpath = model_rootpath
        self._device = device
        self._loaded = False
        # Populated by load().
        self._face_analyzer = None
        self._face_swapper = None
        self._face_parser = None

    def load(self, face_analyzer: FaceAnalyzer):
        """Load the swap and face-parsing models, then run one warm-up swap.

        Args:
        face_analyzer (FaceAnalyzer): Face detector, used for both the
            reference images and the target image.
        """
        load_timer = CodeTimer()
        model_path = os.path.join(self._model_rootpath, "inswapper_128.onnx")
        self._face_analyzer = face_analyzer

        get_model_timer = load_timer.start("swapper_get_model")
        cuda_provider = ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"})
        self._face_swapper = insightface.model_zoo.get_model(
            model_path, providers=[cuda_provider])
        get_model_timer.finish()

        face_parser_timer = load_timer.start("swapper_face_parser")
        self._face_parser = facer.face_parser(name='farl/lapa/448',
                                              ckpt_path=os.path.join(self._model_rootpath, "farl.pt"), device=self._device)
        face_parser_timer.finish()

        # Set the flag only once both models exist, so a failed load does not
        # leave the instance looking usable. It must be set before the warm-up,
        # because the warm-up goes through process() -> _ensure_loaded().
        self._loaded = True

        force_load_timer = load_timer.start("swapper_force_load")
        self._force_model_load()
        force_load_timer.finish()
        logger.info("Swapper times %s", load_timer.times())

    def _ensure_loaded(self):
        """Raise RuntimeError unless load() has completed.

        load() needs a FaceAnalyzer argument that is not available here, so
        the models cannot be loaded lazily from this method.
        """
        if not self._loaded:
            raise RuntimeError("Swapper.load(face_analyzer) must be called before process()")

    def _force_model_load(self):
        """Run one swap on a bundled sample image so model initialisation happens during load().

        The same sample image serves as both the face reference and the target.
        """
        # TODO: This is a hack to force the model to load and not on the first inference.
        # It seems to be the "recognition" model that is causing the issue.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        image_path = os.path.join(current_dir, "images/face-female.jpg")
        image = Image.open(image_path)
        # Image.open is lazy. load() reads the pixels now, and for single-frame
        # files opened by path Pillow then closes the underlying file handle.
        image.load()

        face_swap = FaceSwapModel(face_reference_image_url="")
        face_swap._face_reference_image = image
        conditioning = ConditioningModel(prompt="", negative_prompt="")
        subjects = [
            SubjectModel(face_swap=face_swap, conditioning=conditioning)]
        request = Request(subjects=subjects)
        request.process()
        self.process(image, subjects)

    def _get_one_face(self, frame: np.ndarray):
        """Return the left-most face (smallest ``bbox[0]``) detected in ``frame``, or None."""
        faces = self._face_analyzer.process(frame)
        try:
            return min(faces, key=lambda face: face.bbox[0])
        except ValueError:
            # min() of an empty sequence: nothing was detected.
            return None

    def _get_many_faces(self,
                        frame: np.ndarray):
        """
        Return all faces detected in ``frame``, ordered left to right by ``bbox[0]``.

        Returns None if an IndexError is raised during detection.
        """
        try:
            faces = self._face_analyzer.process(frame)
            return sorted(faces, key=lambda face: face.bbox[0])
        except IndexError:
            return None

    def get_face_parsing_mask(self, face_image: torch.Tensor, yv5_faces: dict) -> np.ndarray:
        """
        Generate a face parsing mask for the given face image.

        Args:
        face_image (torch.Tensor): The face image as a PyTorch tensor.
        yv5_faces (dict): A dictionary containing the face data.

        Returns:
        np.ndarray: A uint8 array with a trailing channel axis. Values are 255
        where the summed argmax-class visualisation is above 0, and 0 elsewhere.
        """
        with torch.inference_mode():
            # NOTE(review): these flip process-global TorchScript settings on every call.
            torch._C._jit_set_profiling_executor(False)
            torch._C._jit_set_profiling_mode(False)
            torch.jit._state.disable()
            faces = self._face_parser(face_image, yv5_faces)

        seg_logits = faces['seg']['logits']
        seg_probs = seg_logits.softmax(dim=1)  # nfaces x nclasses x h x w
        n_classes = seg_probs.size(1)
        vis_seg_probs = seg_probs.argmax(dim=1).float() / n_classes * 255
        vis_img = vis_seg_probs.sum(0, keepdim=True)

        result_mask = create_mask(vis_img)
        # create_mask yields an integer tensor here (bool * int). The legacy
        # torch.Tensor(...) constructor rejects tensors whose dtype or device
        # differ from its default float CPU type; as_tensor accepts any of them.
        result_mask = torch.as_tensor(result_mask)

        mask_array = result_mask.squeeze().cpu().numpy()
        image_seg_array = np.uint8(mask_array[..., None])

        return image_seg_array

    def _swap_face(self,
                   source_face,
                   target_face,
                   temp_frame):
        """
        Paste ``source_face`` onto ``target_face`` in ``temp_frame`` and return the blended frame.

        The swapper's own mask is multiplied by the face-parsing mask (scaled to
        [0, 1]), and the result alpha-blends the swapped image over the original frame.
        """
        # NOTE(review): this expects get(..., paste_back=True) to return a
        # 4-tuple (mask, fake BGR frame, face tensor, face dict). Stock
        # insightface returns one image, so presumably a patched swapper is used. Confirm.
        img_mask, bgr_fake, face_image, yv5_faces = self._face_swapper.get(
            temp_frame, target_face, source_face, paste_back=True)

        rgb_fake_mask = self.get_face_parsing_mask(face_image, yv5_faces) / 255
        img_mask = rgb_fake_mask * img_mask
        fake_merged = img_mask * bgr_fake + (1 - img_mask) * temp_frame.astype(np.float32)
        fake_merged = fake_merged.astype(np.uint8)
        return fake_merged

    @staticmethod
    def _intended_region_bbox(subject, frame_width, frame_height):
        """Return the subject's target region as a corner-form ``[x1, y1, x2, y2]`` pixel box.

        When ``subject.region`` is None, the whole frame is returned.
        """
        if subject.region is None:
            return [0, 0, frame_width, frame_height]
        x, y, width, height = region_to_full_resolution(subject.region, frame_width, frame_height)
        # region_to_full_resolution yields [x, y, width, height], while
        # calculate_distance reads boxes as [x1, y1, x2, y2].
        # Assumes region.x / region.y is the top-left corner. TODO confirm.
        return [x, y, x + width, y + height]

    def swap_faces(self, temp_frame, target_faces, subjects: List[SubjectModel], min_score=0.2, distance_weight=0.3, similarity_weight=0.2, gender_weight=0.5):
        """Greedily match subjects to detected target faces and swap them in.

        Every (subject, target face) pair is scored with ``face_score``. Pairs
        scoring below ``min_score`` are dropped. The remaining pairs are applied
        in descending score order, and each subject and each target face is
        used at most once.

        Args:
        temp_frame (np.ndarray): Target image in BGR channel order.
        target_faces (list): Faces detected in ``temp_frame``.
        subjects (List[SubjectModel]): Subjects whose reference face should be
            pasted. A subject is skipped if it has no reference image or no face
            is found in its reference image. For the rest, the detected
            reference face is stored on ``subject.face_swap._face``.
        min_score (float): Minimum pair score required for a swap.
        distance_weight (float): Forwarded to ``face_score``.
        similarity_weight (float): Forwarded to ``face_score``.
        gender_weight (float): Forwarded to ``face_score``.

        Returns:
        np.ndarray: The frame with all selected swaps applied.
        """
        frame_height, frame_width = temp_frame.shape[0], temp_frame.shape[1]
        candidates = []

        # Calculate scores for all possible source-target pairs
        for i, subject in enumerate(subjects):
            face_reference_image = subject.face_swap._face_reference_image
            if face_reference_image is None:
                logger.debug("Skipping subject %s because no face reference image was provided", i)
                continue
            source_face = self._get_one_face(cv2.cvtColor(
                np.array(face_reference_image), cv2.COLOR_RGB2BGR))
            if source_face is None:
                logger.debug("Skipping subject %s because no face was found", i)
                continue
            subject.face_swap._face = source_face
            intended_region_bbox = self._intended_region_bbox(subject, frame_width, frame_height)
            for j, target_face in enumerate(target_faces):
                score = face_score(source_face, target_face, intended_region_bbox, distance_weight,
                                   similarity_weight, gender_weight)
                if score >= min_score:
                    # Negate so an ascending sort puts the highest score first;
                    # ties are broken by (source index, target index).
                    candidates.append((-score, i, j))
                else:
                    logger.debug("Skipping source face %s with target face %s. Score: %s", i, j, score)

        candidates.sort()
        used_source_indices = set()
        used_target_indices = set()

        # Match faces based on the highest scores
        for neg_score, source_index, target_index in candidates:
            if source_index in used_source_indices or target_index in used_target_indices:
                continue
            logger.debug("Swapping source face %s with target face %s. Score: %s",
                         source_index, target_index, -neg_score)
            used_source_indices.add(source_index)
            used_target_indices.add(target_index)
            source_face = subjects[source_index].face_swap._face
            temp_frame = self._swap_face(
                source_face, target_faces[target_index], temp_frame)
        return temp_frame

    def process(self,
                target_img: Image.Image,
                subjects: List[SubjectModel],
                ):
        """Swap each subject's reference face onto matching faces in ``target_img``.

        Args:
        target_img (Image.Image): Image whose faces are replaced. It is
            converted RGB -> BGR before detection.
        subjects (List[SubjectModel]): Subjects carrying the reference faces.

        Returns:
        Image.Image: A new RGB image with the swaps applied, or ``target_img``
        itself when no face is detected.

        Raises:
        RuntimeError: If load() has not completed.
        """
        self._ensure_loaded()
        # read target image
        target_img_np = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)

        # detect faces that will be replaced in the target image
        target_faces = self._get_many_faces(target_img_np)
        # _get_many_faces can return None, so check before iterating.
        if not target_faces:
            logger.debug("No target faces found!")
            return target_img

        for face in target_faces:
            logger.debug("Face detected in target image at %s with score %s", face.bbox, face.det_score)

        result = self.swap_faces(target_img_np, target_faces, subjects)
        return Image.fromarray(cv2.cvtColor(result, cv2.COLOR_BGR2RGB))
Editor is loading...
Leave a Comment