engne1.py
quoc14
python
5 months ago
6.4 kB
5
Indexable
FaceRC
import os import torch import pandas as pd from torchvision.transforms import Compose, ToTensor, Normalize import inspect from face_engine.huggingface_model_utils import load_model_from_local_path import gc import sys import time # Check devices DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' class ModelLoader(): def __init__(self, model_path: str): self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE) def instance(self): return self.__model class FaceEngine(): def __init__(self) -> None: self.__csv_directory = '/home/quoc14/Code/HeThongNhanDang/csv_models/' os.makedirs(self.__csv_directory, exist_ok=True) # Tạo thư mục lưu CSV cho các mô hình nếu chưa tồn tại self.__available_models = { "ir18_webface4m" : "minchul/cvlface_adaface_ir18_webface4m", "ir50_webface4m" : "minchul/cvlface_adaface_ir50_webface4m", "ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m", "ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m", "vit_base_kprpe_webface4m" : "minchul/cvlface_adaface_vit_base_kprpe_webface4m", "vit_base_kprpe_webface12m" : "minchul/cvlface_adaface_vit_base_kprpe_webface12m", "vit_base_webface4m": "minchul/cvlface_adaface_vit_base_webface4m" } self._extractor = None self._aligner = ModelLoader("minchul/cvlface_DFA_mobilenet").instance() self.__model_name = "ir101_webface4m" # Mặc định là mô hình này self.__database_path = self.get_csv_path(self.__model_name) # Đường dẫn đến CSV của model hiện tại self.threshold = 0.3 self.load_csv() def get_csv_path(self, model_name): """Trả về đường dẫn file CSV dựa trên tên model.""" return os.path.join(self.__csv_directory, f'{model_name}_face_db.csv') def reset_model(self): if hasattr(self, '_extractor'): del self._extractor # Dọn dẹp bộ nhớ và giải phóng GPU gc.collect() torch.cuda.empty_cache() # Xóa các module đã tải từ mô hình trước modules_to_delete = [key for key in sys.modules if "model" in key or "transformers" in key] for module in modules_to_delete: del sys.modules[module] def load_model(self, model_key): model_path = self.__available_models[model_key] print(f"Đang tải mô hình: {model_key} từ {model_path}") return ModelLoader(model_path).instance() def set_model(self, model_key): if model_key in self.__available_models: self.reset_model() self._extractor = self.load_model(model_key) self.__model_name = model_key self.__database_path = self.get_csv_path(model_key) # Cập nhật đường dẫn CSV cho model mới print(f"Đã chuyển đổi sang mô hình: {model_key}") self.load_csv() # Tải CSV tương ứng với mô hình else: print(f"Mô hình {model_key} không tồn tại!") def load_csv(self): """Load hoặc tạo mới file CSV cho model hiện tại.""" if not os.path.exists(self.__database_path): # Tạo file CSV nếu chưa tồn tại columns = ['id', 'feat'] df = pd.DataFrame(columns=columns) df.to_csv(self.__database_path, index=False) self.__next_id = 1 # Bắt đầu ID từ 1 else: self.__db = pd.read_csv(self.__database_path) if pd.to_numeric(self.__db['id'], errors='coerce').isna().all(): self.__next_id = 1 else: self.__next_id = self.__db['id'].max() + 1 # Tính ID kế tiếp # Hàm chuẩn hóa ảnh và căn chỉnh def pil_to_input(self, pil_image): trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) return trans(pil_image).unsqueeze(0).to(DEVICE) def get_feat(self, pil_image): # Chuẩn hóa ảnh và căn chỉnh input_tensor = self.pil_to_input(pil_image) aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, bbox = self._aligner(input_tensor) # Đo thời gian bắt đầu trích xuất đặc trưng start_time = time.time() # Nhận diện đặc trưng (feature) input_signature = inspect.signature(self._extractor.model.net.forward) if input_signature.parameters.get('keypoints') is not None: feat = self._extractor(aligned_x, aligned_ldmks) else: feat = self._extractor(aligned_x) # Đo thời gian kết thúc trích xuất đặc trưng end_time = time.time() extraction_time = end_time - start_time return feat, extraction_time def compute_cosine_similarity(self, pil_image_1, pil_image_2): feature_1, _ = self.get_feat(pil_image_1) feature_2, _ = self.get_feat(pil_image_2) cosine_similarity = torch.nn.functional.cosine_similarity(feature_1, feature_2).item() return cosine_similarity def get_id(self, pil_image): if not os.path.exists(self.__database_path): return None feat_input, extraction_time = self.get_feat(pil_image) max_sim = -1 matched_id = None for i, row in self.__db.iterrows(): feat_db = torch.tensor(eval(row['feat']), device=DEVICE) cossim = torch.nn.functional.cosine_similarity(feat_input, feat_db).item() if cossim > self.threshold and cossim > max_sim: max_sim = cossim matched_id = row['id'] if matched_id: return {"id": matched_id, "extraction_time": extraction_time, "status": "found"} else: return {"id": None, "extraction_time": extraction_time, "status": "not_found"} def save_to_db(self, pil_image): new_id = self.__next_id feature, _ = self.get_feat(pil_image) new_row = pd.DataFrame({'id': [new_id], 'feat': [str(feature.squeeze().cpu().detach().numpy().tolist())]}) self.__db = pd.concat([self.__db, new_row], ignore_index=True) self.__db.to_csv(self.__database_path, index=False) self.__next_id += 1 return new_id
Editor is loading...
Leave a Comment