engine2.py

 avatar
quoc14
python
a year ago
6.1 kB
13
Indexable
FaceRC
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
from PIL import Image
import gc
import sys
from facenet_pytorch import MTCNN
from huggingface_model_utils import load_model_from_local_path
import inspect

# Kiểm tra thiết bị (GPU hoặc CPU)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

class ModelLoader():
    def __init__(self, model_path: str):
        """Khởi tạo và load mô hình từ đường dẫn đã chỉ định."""
        self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE)
    
    def instance(self):
        """Trả về mô hình đã load."""
        return self.__model

class FaceEngine():
    def __init__(self, csv_file='/home/quoc14/Code/HeThongNhanDang/face_features_400.csv') -> None:
        """Khởi tạo FaceEngine và thiết lập file CSV dùng để lưu và so sánh."""
        self.__database_path = csv_file  # File CSV mặc định để lưu đặc trưng khuôn mặt
        self.__available_models = {
            "ir18_webface4m": "minchul/cvlface_adaface_ir18_webface4m",
            "ir50_webface4m": "minchul/cvlface_adaface_ir50_webface4m",
            "ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m",
            "ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m",
            "vit_base_kprpe_webface4m": "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
            "vit_base_kprpe_webface12m": "minchul/cvlface_adaface_vit_base_kprpe_webface12m",
            "vit_base_webface4m": "minchul/cvlface_adaface_vit_base_webface4m"
        }
        self._extractor = None
        self._aligner = ModelLoader("minchul/cvlface_DFA_resnet50").instance()
        self.__model_name = "ir101_webface4m"  # Mặc định sử dụng mô hình này
        self.threshold = 0.3  # Ngưỡng để nhận diện
        self.mtcnn = MTCNN(keep_all=False, device=DEVICE)  # Sử dụng MTCNN cho phát hiện khuôn mặt
        self.load_csv()  # Tải cơ sở dữ liệu từ file CSV

    def load_csv(self):
        """Load file CSV chứa các đặc trưng khuôn mặt và chuyển đổi đặc trưng từ JSON sang tensor."""
        if not os.path.exists(self.__database_path):
            # Tạo file CSV nếu chưa tồn tại
            columns = ['id', 'feat']
            df = pd.DataFrame(columns=columns)
            df.to_csv(self.__database_path, index=False)
            self.__db = pd.DataFrame(columns=columns)  # Khởi tạo database rỗng
        else:
            self.__db = pd.read_csv(self.__database_path, dtype={'id': str})
            # Chuyển đổi tất cả các đặc trưng từ chuỗi JSON sang tensor một lần
            self.__db['feat_tensor'] = self.__db['feat'].apply(lambda x: torch.tensor(eval(x), device=DEVICE))

    def reset_model(self):
        """Reset mô hình hiện tại và giải phóng bộ nhớ."""
        if hasattr(self, '_extractor'):
            del self._extractor
        
        # Dọn dẹp bộ nhớ và giải phóng GPU
        gc.collect()
        torch.cuda.empty_cache()

    def load_model(self, model_key):
        """Load mô hình dựa trên từ khóa model_key."""
        model_path = self.__available_models[model_key]
        return ModelLoader(model_path).instance()

    def set_model(self, model_key):
        """Đặt mô hình hiện tại dựa trên lựa chọn của người dùng."""
        if model_key in self.__available_models:
            self.reset_model()  # Reset mô hình trước khi chuyển đổi
            self._extractor = self.load_model(model_key)  # Tải mô hình mới
            self.__model_name = model_key  # Cập nhật tên mô hình hiện tại
        else:
            print(f"Mô hình {model_key} không tồn tại!")

    # Hàm chuẩn hóa tensor từ MTCNN
    def normalize_tensor(self, tensor):
        normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        return normalize(tensor).unsqueeze(0).to(DEVICE)  # Thêm chiều batch

    def get_feat(self, pil_image):
        """Trích xuất đặc trưng từ ảnh."""
        # Dùng MTCNN để phát hiện khuôn mặt trước khi trích xuất
        face_crop = self.mtcnn(pil_image)
        if face_crop is None:
            return None
        input_tensor = self.normalize_tensor(face_crop)
        aligned_x, _, aligned_ldmks, _, _, _ = self._aligner(input_tensor)

        input_signature = inspect.signature(self._extractor.model.net.forward)
        if input_signature.parameters.get('keypoints') is not None:
            feat = self._extractor(aligned_x, aligned_ldmks)
        else:
            feat = self._extractor(aligned_x)
        return feat

    def compute_cosine_similarity(self, feat_input, feats_db):
        """Tính toàn bộ cosine similarity giữa đặc trưng đầu vào và tất cả các đặc trưng trong cơ sở dữ liệu."""
        return torch.nn.functional.cosine_similarity(feat_input, feats_db)

    def get_id(self, pil_image):
        """So sánh ảnh với cơ sở dữ liệu và trả về ID nếu tìm thấy."""
        if not os.path.exists(self.__database_path):
            return {"id": None, "status": "not_found"}

        # Trích xuất đặc trưng của ảnh đầu vào
        feat_input = self.get_feat(pil_image)
        if feat_input is None:
            return {"id": None, "status": "not_found"}

        # Tính toán cosine similarity giữa ảnh đầu vào và tất cả các ảnh trong cơ sở dữ liệu
        feats_db = torch.stack(self.__db['feat_tensor'].values.tolist()).to(DEVICE)
        similarities = self.compute_cosine_similarity(feat_input, feats_db)

        # Lấy ra ảnh có cosine lớn nhất
        max_sim, idx_max = similarities.max(0)
        if max_sim.item() > self.threshold:
            best_match_id = self.__db.iloc[idx_max.item()]['id']
            return {"id": best_match_id, "status": "found"}
        else:
            return {"id": None, "status": "not_found"}

Editor is loading...
Leave a Comment