engine.py

 avatar
quoc14
python
5 months ago
4.7 kB
2
Indexable
FaceRC
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
import inspect
from face_engine.huggingface_model_utils import load_model_from_local_path


# Check devices
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class ModelLoader():
    def __init__(self, model_path: str):
        self.__model = load_model_from_local_path(f'face_engine/models/{model_path}').to(DEVICE)
    def instance(self):
        return self.__model

class FaceEngine():
    def __init__(self) -> None:
        self.__database_path = './face_db.csv'
        self.__extractor = ModelLoader("minchul/cvlface_adaface_vit_base_webface4m").instance()
        self.__aligner = ModelLoader("minchul/cvlface_DFA_mobilenet").instance()
        self.threshold = 0.3
        self.load_csv()
        pass

    def reset_csv(self):
        # Tạo một DataFrame rỗng với các cột tiêu đề
        columns = ['id', 'feat']
        df = pd.DataFrame(columns=columns)
        df.to_csv(self.__database_path, index=False)

    def load_csv(self):
        # Kiểm tra nếu file CSV không tồn tại hoặc trống
        if not os.path.exists(self.__database_path) or os.stat(self.__database_path).st_size == 0:
            # Nếu không tồn tại hoặc trống, bắt đầu ID từ 1
            self.__next_id = 1
            self.__db = pd.DataFrame(columns=['id', 'feat'])
            self.__db.to_csv(self.__database_path, index=False)
        else:
            self.__db = pd.read_csv(self.__database_path)
            # Đảm bảo rằng cột 'id' không bị lỗi và có thể chuyển thành số
            if pd.to_numeric(self.__db['id'], errors='coerce').isna().all():
                self.__next_id = 1  # Bắt đầu lại nếu tất cả giá trị trong 'id' là NaN
            else:
                self.__next_id = self.__db['id'].max() + 1  # Tính ID kế tiếp

    # Hàm chuẩn hóa ảnh và căn chỉnh
    def pil_to_input(self, pil_image):
        trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        return trans(pil_image).unsqueeze(0).to(DEVICE)

    def get_feat(self, pil_image):
        # Chuẩn hóa ảnh và căn chỉnh
        input_tensor = self.pil_to_input(pil_image)
        # Căn chỉnh khuôn mặt
        aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, bbox = self.__aligner(input_tensor)

        # Nhận diện đặc trưng (feature)
        input_signature = inspect.signature(self.__extractor.model.net.forward)
        if input_signature.parameters.get('keypoints') is not None:
            feat = self.__extractor(aligned_x, aligned_ldmks)
        else:
            feat = self.__extractor(aligned_x)

        return feat

    def compute_cosine_similarity(self, pil_image_1, pil_image_2):

        # Lấy đặc trưng của hai ảnh
        feature_1 = self.get_feat(pil_image_1)
        feature_2 = self.get_feat(pil_image_2)

        # Tính toán cosine similarity
        cosine_similarity = torch.nn.functional.cosine_similarity(feature_1, feature_2).item()
        return cosine_similarity


    def get_id(self, pil_image):
        if not os.path.exists(self.__database_path):
            return None
        
        # Lấy đặc trưng (feature)
        feat_input = self.get_feat(pil_image)

        # So sánh với từng ảnh trong CSDL
        max_sim = -1  # Biến lưu giá trị cosine similarity lớn nhất
        matched_id = None

        for i, row in self.__db.iterrows():
            # Chuyển đổi đặc trưng từ CSDL thành tensor
            feat_db = torch.tensor(eval(row['feat']), device=DEVICE)

            # Tính cosine similarity giữa ảnh đầu vào và ảnh trong CSDL
            cossim = torch.nn.functional.cosine_similarity(feat_input, feat_db).item()

            # Kiểm tra nếu cosine similarity lớn hơn ngưỡng và lớn hơn giá trị max_sim hiện tại
            if cossim > self.threshold and cossim > max_sim:
                max_sim = cossim
                matched_id = row['id']

        return matched_id

    # Lưu feature mới vào CSDL
    def save_to_db(self, pil_image):
        new_id = self.__next_id

        feature = self.get_feat(pil_image)

        new_row = pd.DataFrame({'id': [new_id], 'feat': [str(feature.squeeze().cpu().detach().numpy().tolist())]})
        self.__db = pd.concat([self.__db, new_row], ignore_index=True)
        self.__db.to_csv(self.__database_path, index=False)

        self.__next_id = self.__next_id + 1
        
        return new_id
Editor is loading...
Leave a Comment