Untitled
unknown
plain_text
a year ago
2.5 kB
12
Indexable
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
from tqdm import tqdm
def create_emotion_model(num_ftrs, num_emotions):
    """Build the MLP regression head that sits on top of the backbone features.

    The input width is ``num_ftrs + num_emotions``. Two hidden layers of 128
    and 64 units, each followed by ReLU, lead to a 2-unit linear output. No
    activation is applied to the output.

    The Linear layers sit at indices 0, 2 and 4 of the ``nn.Sequential``.
    Those positions determine the state-dict keys (``"0.weight"``, …), so the
    order must match any saved checkpoint.
    """
    in_width = num_ftrs + num_emotions
    layers = [
        nn.Linear(in_width, 128),
        nn.ReLU(),
        nn.Linear(128, 64),
        nn.ReLU(),
        nn.Linear(64, 2),
    ]
    return nn.Sequential(*layers)
def load_model(model_path, device):
    """Load the ResNet-18 feature extractor from a saved state dict.

    The final ``fc`` layer is replaced with ``nn.Identity`` *before* loading.
    The checkpoint at ``model_path`` must therefore contain no ``fc`` weights,
    because ``load_state_dict`` runs in strict mode by default. The returned
    network outputs the pooled features that would otherwise have fed ``fc``.

    Args:
        model_path: Path to a state dict readable by ``torch.load``.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        ``(model, num_ftrs)``: the backbone in eval mode on ``device``, and the
        width of its feature output (the original ``fc.in_features``).
    """
    # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13 in
    # favour of ``weights=None``; kept as-is for compatibility with older releases.
    backbone = models.resnet18(pretrained=False)
    # Record the feature width before the classifier is discarded.
    feature_width = backbone.fc.in_features
    backbone.fc = nn.Identity()
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    backbone.load_state_dict(checkpoint)
    backbone.to(device)
    backbone.eval()
    return backbone, feature_width
def load_emotion_model(num_ftrs, num_emotions, model_path, device):
    """Rebuild the regression head and load its trained weights.

    Args:
        num_ftrs: Width of the backbone feature vector.
        num_emotions: Number of emotion features appended to that vector.
        model_path: Path to the head's state dict (read with ``torch.load``).
        device: Target device for both the weights and the module.

    Returns:
        The head built by ``create_emotion_model``, in eval mode on ``device``.
    """
    head = create_emotion_model(num_ftrs, num_emotions)
    head = head.to(device)
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(model_path, map_location=device)
    head.load_state_dict(state)
    head.eval()
    return head
def va_predict(val_model_path, val_featmodel_path, faces, emotions, num_emotions=1):
    """Predict valence, arousal and a derived stress score for each face crop.

    For each face:
      1. Run the image through the ResNet-18 backbone loaded from
         ``val_featmodel_path``.
      2. Concatenate the feature vector with that face's emotion feature(s).
      3. Pass the result to the regression head loaded from ``val_model_path``.

    Args:
        val_model_path: Path to the state dict of the regression head.
        val_featmodel_path: Path to the state dict of the ResNet-18 backbone
            whose ``fc`` layer was replaced by ``nn.Identity``.
        faces: Sized iterable of face images (``len(faces)`` is used for the
            progress bar). Each image is converted with ``cv2.COLOR_BGR2RGB``.
            They are presumably BGR uint8 arrays as produced by OpenCV; TODO
            confirm with callers.
        emotions: Per-face emotion features, paired with ``faces`` by
            position. Anything ``torch.as_tensor`` accepts works (tensor,
            NumPy array, Python number or list). Each entry must hold
            ``num_emotions`` values. Pairing stops at the shorter of the two
            inputs (``zip`` semantics).
        num_emotions: Number of emotion features per face. Must match the
            value the head was trained with. Defaults to 1, the original
            hard-coded assumption.

    Returns:
        ``(valence_list, arousal_list, stress_list)``: three lists of floats,
        one entry per processed face. The return order differs from the
        model's output order (output 0 is arousal, output 1 is valence).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the backbone and the regression head onto the same device.
    resnet, num_ftrs = load_model(val_featmodel_path, device)
    emotion_model = load_emotion_model(num_ftrs, num_emotions, val_model_path, device)

    # Resize to 224x224, then normalize with the ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def model_forward(images, emotions):
        """Run backbone + head on a batch of images and their emotion features."""
        resnet_features = resnet(images)
        batch_size = resnet_features.size(0)
        # reshape (unlike view) tolerates non-contiguous inputs. Casting to the
        # feature dtype avoids a dtype mismatch in torch.cat / nn.Linear when
        # emotions arrive as float64 (e.g. from NumPy) or as integers.
        emotions = emotions.reshape(batch_size, -1).to(resnet_features.dtype)
        x = torch.cat((resnet_features, emotions), dim=1)
        return emotion_model(x)

    arousal_list, valence_list, stress_list = [], [], []
    # Inference only: disable autograd once for the whole loop.
    with torch.no_grad():
        for face, emotion in tqdm(zip(faces, emotions), total=len(faces)):
            face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))
            face_tensor = transform(face_pil).unsqueeze(0).to(device)
            # Accept tensors, NumPy values and plain numbers alike.
            emotion_tensor = torch.as_tensor(emotion, device=device)
            output_va = model_forward(face_tensor, emotion_tensor)
            # x / 2 + 0.5 maps [-1, 1] onto [0, 1]. Nothing here clamps the raw
            # outputs, so values outside [0, 1] are possible. The head is
            # presumably trained on [-1, 1] targets; confirm.
            arousal = output_va[0, 0].item() / 2 + 0.5
            valence = output_va[0, 1].item() / 2 + 0.5
            # Stress is high when arousal is high and valence is low.
            stress = (1 - valence) * arousal
            arousal_list.append(arousal)
            valence_list.append(valence)
            stress_list.append(stress)
    return valence_list, arousal_list, stress_list
Editor is loading...
Leave a Comment