engne1.py

 avatar
quoc14
python
5 months ago
6.4 kB
5
Indexable
FaceRC
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
import inspect
from face_engine.huggingface_model_utils import load_model_from_local_path
import gc
import sys
import time

# Check devices
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


class ModelLoader():
    def __init__(self, model_path: str):
        self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE)
    def instance(self):
        return self.__model

class FaceEngine():
    def __init__(self) -> None:
        self.__csv_directory = '/home/quoc14/Code/HeThongNhanDang/csv_models/'
        os.makedirs(self.__csv_directory, exist_ok=True)  # Tạo thư mục lưu CSV cho các mô hình nếu chưa tồn tại
        
        self.__available_models = {
            "ir18_webface4m" : "minchul/cvlface_adaface_ir18_webface4m",
            "ir50_webface4m" : "minchul/cvlface_adaface_ir50_webface4m",
            "ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m",
            "ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m",
            "vit_base_kprpe_webface4m" : "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
            "vit_base_kprpe_webface12m" : "minchul/cvlface_adaface_vit_base_kprpe_webface12m",
            "vit_base_webface4m": "minchul/cvlface_adaface_vit_base_webface4m"
        }
        self._extractor = None
        self._aligner = ModelLoader("minchul/cvlface_DFA_mobilenet").instance()
        self.__model_name = "ir101_webface4m"  # Mặc định là mô hình này
        self.__database_path = self.get_csv_path(self.__model_name)  # Đường dẫn đến CSV của model hiện tại
        self.threshold = 0.3
        self.load_csv()

    def get_csv_path(self, model_name):
        """Trả về đường dẫn file CSV dựa trên tên model."""
        return os.path.join(self.__csv_directory, f'{model_name}_face_db.csv')

    def reset_model(self):
        if hasattr(self, '_extractor'):
            del self._extractor
    
        # Dọn dẹp bộ nhớ và giải phóng GPU
        gc.collect()
        torch.cuda.empty_cache()

        # Xóa các module đã tải từ mô hình trước
        modules_to_delete = [key for key in sys.modules if "model" in key or "transformers" in key]
        for module in modules_to_delete:
            del sys.modules[module]

    def load_model(self, model_key):
        model_path = self.__available_models[model_key]
        print(f"Đang tải mô hình: {model_key} từ {model_path}")
        return ModelLoader(model_path).instance()

    def set_model(self, model_key):
        if model_key in self.__available_models:
            self.reset_model()
            self._extractor = self.load_model(model_key)
            self.__model_name = model_key
            self.__database_path = self.get_csv_path(model_key)  # Cập nhật đường dẫn CSV cho model mới
            print(f"Đã chuyển đổi sang mô hình: {model_key}")
            self.load_csv()  # Tải CSV tương ứng với mô hình
        else:
            print(f"Mô hình {model_key} không tồn tại!")

    def load_csv(self):
        """Load hoặc tạo mới file CSV cho model hiện tại."""
        if not os.path.exists(self.__database_path):
            # Tạo file CSV nếu chưa tồn tại
            columns = ['id', 'feat']
            df = pd.DataFrame(columns=columns)
            df.to_csv(self.__database_path, index=False)
            self.__next_id = 1  # Bắt đầu ID từ 1
        else:
            self.__db = pd.read_csv(self.__database_path)
            if pd.to_numeric(self.__db['id'], errors='coerce').isna().all():
                self.__next_id = 1
            else:
                self.__next_id = self.__db['id'].max() + 1  # Tính ID kế tiếp

    # Hàm chuẩn hóa ảnh và căn chỉnh
    def pil_to_input(self, pil_image):
        trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
        return trans(pil_image).unsqueeze(0).to(DEVICE)

    def get_feat(self, pil_image):
        # Chuẩn hóa ảnh và căn chỉnh
        input_tensor = self.pil_to_input(pil_image)
        aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, bbox = self._aligner(input_tensor)

        # Đo thời gian bắt đầu trích xuất đặc trưng
        start_time = time.time()

        # Nhận diện đặc trưng (feature)
        input_signature = inspect.signature(self._extractor.model.net.forward)
        if input_signature.parameters.get('keypoints') is not None:
            feat = self._extractor(aligned_x, aligned_ldmks)
        else:
            feat = self._extractor(aligned_x)

        # Đo thời gian kết thúc trích xuất đặc trưng
        end_time = time.time()
        extraction_time = end_time - start_time

        return feat, extraction_time

    def compute_cosine_similarity(self, pil_image_1, pil_image_2):
        feature_1, _ = self.get_feat(pil_image_1)
        feature_2, _ = self.get_feat(pil_image_2)
        cosine_similarity = torch.nn.functional.cosine_similarity(feature_1, feature_2).item()
        return cosine_similarity

    def get_id(self, pil_image):
        if not os.path.exists(self.__database_path):
            return None

        feat_input, extraction_time = self.get_feat(pil_image)
        max_sim = -1
        matched_id = None

        for i, row in self.__db.iterrows():
            feat_db = torch.tensor(eval(row['feat']), device=DEVICE)
            cossim = torch.nn.functional.cosine_similarity(feat_input, feat_db).item()
            if cossim > self.threshold and cossim > max_sim:
                max_sim = cossim
                matched_id = row['id']

        if matched_id:
            return {"id": matched_id, "extraction_time": extraction_time, "status": "found"}
        else:
            return {"id": None, "extraction_time": extraction_time, "status": "not_found"}

    def save_to_db(self, pil_image):
        new_id = self.__next_id
        feature, _ = self.get_feat(pil_image)

        new_row = pd.DataFrame({'id': [new_id], 'feat': [str(feature.squeeze().cpu().detach().numpy().tolist())]})
        self.__db = pd.concat([self.__db, new_row], ignore_index=True)
        self.__db.to_csv(self.__database_path, index=False)
        self.__next_id += 1
        return new_id
Editor is loading...
Leave a Comment