engine3.py
quoc14
python
a year ago
8.4 kB
11
Indexable
FaceRC
import os
import torch
import pandas as pd
from torchvision.transforms import Compose, ToTensor, Normalize
import inspect
from PIL import Image
import gc
import sys
from face_engine.huggingface_model_utils import load_model_from_local_path
# Kiểm tra thiết bị (GPU hoặc CPU)
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
class ModelLoader:
def __init__(self, model_path: str):
"""Khởi tạo và load mô hình từ đường dẫn đã chỉ định."""
self.__model = load_model_from_local_path(f'/home/quoc14/Code/HeThongNhanDang/model/{model_path}').to(DEVICE)
def instance(self):
"""Trả về mô hình đã load."""
return self.__model
class FaceEngine:
def __init__(self, csv_file='/home/quoc14/Code/HeThongNhanDang/face_features_400.csv') -> None:
"""Khởi tạo FaceEngine và thiết lập file CSV dùng để lưu và so sánh."""
self.__database_path = csv_file # File CSV mặc định để lưu đặc trưng khuôn mặt
self.__available_models = {
"ir18_webface4m": "minchul/cvlface_adaface_ir18_webface4m",
"ir50_webface4m": "minchul/cvlface_adaface_ir50_webface4m",
"ir101_webface4m": "minchul/cvlface_adaface_ir101_webface4m",
"ir101_webface12m": "minchul/cvlface_adaface_ir101_webface12m",
"vit_base_kprpe_webface4m": "minchul/cvlface_adaface_vit_base_kprpe_webface4m",
"vit_base_kprpe_webface12m": "minchul/cvlface_adaface_vit_base_webface12m"
}
self._extractor = None
self._aligner = ModelLoader("minchul/cvlface_DFA_resnet50").instance()
self.__model_name = "ir101_webface4m" # Mặc định sử dụng mô hình này
self.threshold = 0.3 # Ngưỡng để nhận diện
self.load_csv() # Tải cơ sở dữ liệu từ file CSV và xử lý các tensor đặc trưng
def load_csv(self):
"""Load file CSV chứa các đặc trưng khuôn mặt và chuyển đổi đặc trưng từ JSON sang tensor."""
if not os.path.exists(self.__database_path):
# Tạo file CSV nếu chưa tồn tại
columns = ['id', 'feat', 'name', 'image_path', 'date_registered']
df = pd.DataFrame(columns=columns)
df.to_csv(self.__database_path, index=False)
self.__db = pd.DataFrame(columns=columns) # Khởi tạo database rỗng
else:
# Chỉ load các cột cần thiết để tiết kiệm bộ nhớ
self.__db = pd.read_csv(self.__database_path, dtype={'id': str}, usecols=['id', 'feat'])
# Chuyển đổi tất cả các đặc trưng từ chuỗi JSON sang tensor một lần
def convert_to_tensor(feat_str):
try:
# Chuyển chuỗi JSON thành danh sách sau đó thành tensor
feat_list = eval(feat_str)
if isinstance(feat_list, list):
return torch.tensor(feat_list, device=DEVICE)
else:
raise ValueError("Feature is not a valid list.")
except Exception as e:
print(f"Error converting feature: {e}")
return None
self.__db['feat_tensor'] = self.__db['feat'].apply(convert_to_tensor)
# Loại bỏ các hàng có đặc trưng không hợp lệ
self.__db = self.__db.dropna(subset=['feat_tensor'])
def reset_model(self):
"""Reset mô hình hiện tại và giải phóng bộ nhớ."""
if hasattr(self, '_extractor'):
del self._extractor
# Dọn dẹp bộ nhớ và giải phóng GPU
gc.collect()
torch.cuda.empty_cache()
# Xóa các module đã tải từ mô hình trước
modules_to_delete = [key for key in sys.modules if "model" in key or "transformers" in key]
for module in modules_to_delete:
del sys.modules[module]
def load_model(self, model_key):
"""Load mô hình dựa trên từ khóa model_key."""
model_path = self.__available_models[model_key]
print(f"Đang tải mô hình: {model_key} từ {model_path}")
return ModelLoader(model_path).instance()
def set_model(self, model_key):
"""Đặt mô hình hiện tại dựa trên lựa chọn của người dùng."""
if model_key in self.__available_models:
self.reset_model() # Reset mô hình trước khi chuyển đổi
self._extractor = self.load_model(model_key) # Tải mô hình mới
self.__model_name = model_key # Cập nhật tên mô hình hiện tại
print(f"Đã chuyển đổi sang mô hình: {model_key}")
else:
print(f"Mô hình {model_key} không tồn tại!")
# Hàm chuyển đổi ảnh PIL thành tensor và chuẩn hóa
def pil_to_input(self, pil_image):
trans = Compose([ToTensor(), Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
return trans(pil_image).unsqueeze(0).to(DEVICE)
def get_feat(self, pil_image):
"""Trích xuất đặc trưng từ ảnh."""
input_tensor = self.pil_to_input(pil_image)
aligned_x, _, aligned_ldmks, _, _, _ = self._aligner(input_tensor)
# Kiểm tra mô hình xem có cần keypoints không
input_signature = inspect.signature(self._extractor.model.net.forward)
if input_signature.parameters.get('keypoints') is not None:
feat = self._extractor(aligned_x, aligned_ldmks)
else:
feat = self._extractor(aligned_x)
return feat
def compute_cosine_similarity(self, feat_input, feats_db):
"""Tính toàn bộ cosine similarity giữa đặc trưng đầu vào và tất cả các đặc trưng trong cơ sở dữ liệu."""
return torch.nn.functional.cosine_similarity(feat_input, feats_db)
def get_id(self, pil_image):
"""So sánh ảnh với cơ sở dữ liệu và trả về ID nếu tìm thấy."""
if not os.path.exists(self.__database_path):
return {"id": None, "status": "not_found"}
# Trích xuất đặc trưng của ảnh đầu vào
feat_input = self.get_feat(pil_image)
# Chuyển danh sách đặc trưng của cơ sở dữ liệu thành tensor để tính toán
feats_db = torch.stack(self.__db['feat_tensor'].values.tolist()).to(DEVICE)
# Tính toán cosine similarity giữa ảnh đầu vào và tất cả các ảnh trong cơ sở dữ liệu
similarities = self.compute_cosine_similarity(feat_input, feats_db)
# Lấy ra ảnh có cosine lớn nhất
max_sim, idx_max = similarities.max(0)
if max_sim.item() > self.threshold:
# Đã nhận diện thành công
best_match_id = self.__db.iloc[idx_max.item()]['id']
return {"id": best_match_id, "status": "found"}
else:
return {"id": None, "status": "not_found"}
def save_to_db(self, pil_image, user_id):
"""Lưu đặc trưng của ảnh mới với ID người dùng."""
feature = self.get_feat(pil_image)
if not isinstance(feature, torch.Tensor):
return "Feature extraction failed, not saving to database."
# Chuyển đổi tensor sang danh sách để lưu thành chuỗi JSON
feature_list = feature.squeeze().cpu().detach().numpy().tolist()
new_row = pd.DataFrame({
'id': [user_id],
'name': ["quoc"],
'feat': [feature_list], # Lưu đặc trưng dưới dạng danh sách JSON
# Để trống nếu không có
'image_path': ["quoc"], # Để trống nếu không có
'date_registered': [pd.Timestamp.now()]
})
self.__db = pd.concat([self.__db, new_row], ignore_index=True)
self.__db.loc[self.__db['id'] == user_id, 'feat_tensor'] = self.__db.loc[self.__db['id'] == user_id, 'feat'].apply(lambda x: torch.tensor(x, device=DEVICE))
# Thay vì ghi lại toàn bộ file, chỉ ghi thêm dòng mới
with open(self.__database_path, 'a') as f:
new_row.to_csv(f, header=False, index=False)
return f"Data for user {user_id} saved successfully."
Editor is loading...
Leave a Comment