engine3.py
quoc14
python
5 months ago
8.4 kB
2
Indexable
FaceRC
import os import torch import pandas as pd from torchvision.transforms import Compose, ToTensor, Normalize import inspect from PIL import Image import gc import sys from face_engine.huggingface_model_utils import load_model_from_local_path # Kiểm tra thiết bị (GPU hoặc CPU) DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' class ModelLoader: def __init__(self, model_path: str): """Khởi tạo và load mô hình từ đường dẫn đã chỉ định.""" self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE) def instance(self): """Trả về mô hình đã load.""" return self.__model class FaceEngine: def __init__(self, csv_file='/home/quoc14/Code/HeThongNhanDang/face_features_400.csv') -> None: """Khởi tạo FaceEngine và thiết lập file CSV dùng để lưu và so sánh.""" self.__database_path = csv_file # File CSV mặc định để lưu đặc trưng khuôn mặt self.__available_models = { "ir18_webface4m": "minchul/cvlface_adaface_ir18_webface4m", "ir50_webface4m": "minchul/cvlface_adaface_ir50_webface4m", "ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m", "ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m", "vit_base_kprpe_webface4m": "minchul/cvlface_adaface_vit_base_kprpe_webface4m", "vit_base_kprpe_webface12m": "minchul/cvlface_adaface_vit_base_webface12m" } self._extractor = None self._aligner = ModelLoader("minchul/cvlface_DFA_resnet50").instance() self.__model_name = "ir101_webface4m" # Mặc định sử dụng mô hình này self.threshold = 0.3 # Ngưỡng để nhận diện self.load_csv() # Tải cơ sở dữ liệu từ file CSV và xử lý các tensor đặc trưng def load_csv(self): """Load file CSV chứa các đặc trưng khuôn mặt và chuyển đổi đặc trưng từ JSON sang tensor.""" if not os.path.exists(self.__database_path): # Tạo file CSV nếu chưa tồn tại columns = ['id', 'feat', 'name', 'image_path', 'date_registered'] df = pd.DataFrame(columns=columns) df.to_csv(self.__database_path, index=False) self.__db = pd.DataFrame(columns=columns) # Khởi tạo database rỗng else: # Chỉ load các cột cần thiết để tiết kiệm bộ nhớ self.__db = pd.read_csv(self.__database_path, dtype={'id': str}, usecols=['id', 'feat']) # Chuyển đổi tất cả các đặc trưng từ chuỗi JSON sang tensor một lần def convert_to_tensor(feat_str): try: # Chuyển chuỗi JSON thành danh sách sau đó thành tensor feat_list = eval(feat_str) if isinstance(feat_list, list): return torch.tensor(feat_list, device=DEVICE) else: raise ValueError("Feature is not a valid list.") except Exception as e: print(f"Error converting feature: {e}") return None self.__db['feat_tensor'] = self.__db['feat'].apply(convert_to_tensor) # Loại bỏ các hàng có đặc trưng không hợp lệ self.__db = self.__db.dropna(subset=['feat_tensor']) def reset_model(self): """Reset mô hình hiện tại và giải phóng bộ nhớ.""" if hasattr(self, '_extractor'): del self._extractor # Dọn dẹp bộ nhớ và giải phóng GPU gc.collect() torch.cuda.empty_cache() # Xóa các module đã tải từ mô hình trước modules_to_delete = [key for key in sys.modules if "model" in key or "transformers" in key] for module in modules_to_delete: del sys.modules[module] def load_model(self, model_key): """Load mô hình dựa trên từ khóa model_key.""" model_path = self.__available_models[model_key] print(f"Đang tải mô hình: {model_key} từ {model_path}") return ModelLoader(model_path).instance() def set_model(self, model_key): """Đặt mô hình hiện tại dựa trên lựa chọn của người dùng.""" if model_key in self.__available_models: self.reset_model() # Reset mô hình trước khi chuyển đổi self._extractor = self.load_model(model_key) # Tải mô hình mới self.__model_name = model_key # Cập nhật tên mô hình hiện tại print(f"Đã chuyển đổi sang mô hình: {model_key}") else: print(f"Mô hình {model_key} không tồn tại!") # Hàm chuyển đổi ảnh PIL thành tensor và chuẩn hóa def pil_to_input(self, pil_image): trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]) return trans(pil_image).unsqueeze(0).to(DEVICE) def get_feat(self, pil_image): """Trích xuất đặc trưng từ ảnh.""" input_tensor = self.pil_to_input(pil_image) aligned_x, _, aligned_ldmks, _, _, _ = self._aligner(input_tensor) # Kiểm tra mô hình xem có cần keypoints không input_signature = inspect.signature(self._extractor.model.net.forward) if input_signature.parameters.get('keypoints') is not None: feat = self._extractor(aligned_x, aligned_ldmks) else: feat = self._extractor(aligned_x) return feat def compute_cosine_similarity(self, feat_input, feats_db): """Tính toàn bộ cosine similarity giữa đặc trưng đầu vào và tất cả các đặc trưng trong cơ sở dữ liệu.""" return torch.nn.functional.cosine_similarity(feat_input, feats_db) def get_id(self, pil_image): """So sánh ảnh với cơ sở dữ liệu và trả về ID nếu tìm thấy.""" if not os.path.exists(self.__database_path): return {"id": None, "status": "not_found"} # Trích xuất đặc trưng của ảnh đầu vào feat_input = self.get_feat(pil_image) # Chuyển danh sách đặc trưng của cơ sở dữ liệu thành tensor để tính toán feats_db = torch.stack(self.__db['feat_tensor'].values.tolist()).to(DEVICE) # Tính toán cosine similarity giữa ảnh đầu vào và tất cả các ảnh trong cơ sở dữ liệu similarities = self.compute_cosine_similarity(feat_input, feats_db) # Lấy ra ảnh có cosine lớn nhất max_sim, idx_max = similarities.max(0) if max_sim.item() > self.threshold: # Đã nhận diện thành công best_match_id = self.__db.iloc[idx_max.item()]['id'] return {"id": best_match_id, "status": "found"} else: return {"id": None, "status": "not_found"} def save_to_db(self, pil_image, user_id): """Lưu đặc trưng của ảnh mới với ID người dùng.""" feature = self.get_feat(pil_image) if not isinstance(feature, torch.Tensor): return "Feature extraction failed, not saving to database." # Chuyển đổi tensor sang danh sách để lưu thành chuỗi JSON feature_list = feature.squeeze().cpu().detach().numpy().tolist() new_row = pd.DataFrame({ 'id': [user_id], 'name': ["quoc"], 'feat': [feature_list], # Lưu đặc trưng dưới dạng danh sách JSON # Để trống nếu không có 'image_path': ["quoc"], # Để trống nếu không có 'date_registered': [pd.Timestamp.now()] }) self.__db = pd.concat([self.__db, new_row], ignore_index=True) self.__db.loc[self.__db['id'] == user_id, 'feat_tensor'] = self.__db.loc[self.__db['id'] == user_id, 'feat'].apply(lambda x: torch.tensor(x, device=DEVICE)) # Thay vì ghi lại toàn bộ file, chỉ ghi thêm dòng mới with open(self.__database_path, 'a') as f: new_row.to_csv(f, header=False, index=False) return f"Data for user {user_id} saved successfully."
Editor is loading...
Leave a Comment