Untitled
unknown
plain_text
a year ago
2.5 kB
9
Indexable
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import cv2
from tqdm import tqdm
def create_emotion_model(num_ftrs, num_emotions):
    """Build the MLP head applied to backbone features concatenated with emotion inputs.

    Layout: Linear(num_ftrs + num_emotions -> 128), ReLU, Linear(128 -> 64), ReLU,
    Linear(64 -> 2). The final layer has no activation, so outputs are unbounded.
    ``va_predict`` reads output column 0 as arousal and column 1 as valence.

    Args:
        num_ftrs: Width of the backbone feature vector.
        num_emotions: Width of the per-sample emotion vector appended to it.

    Returns:
        An ``nn.Sequential`` whose trainable layers sit at indices 0, 2 and 4
        (state-dict keys ``"0.weight"``, ``"2.weight"``, ``"4.weight"`` and biases).
    """
    widths = [num_ftrs + num_emotions, 128, 64]
    layers = []
    # Each hidden stage is a Linear followed by a ReLU.
    for fan_in, fan_out in zip(widths, widths[1:]):
        layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
    # Plain linear read-out of the two target values.
    layers.append(nn.Linear(widths[-1], 2))
    return nn.Sequential(*layers)
def load_model(model_path, device):
    """Load a headless ResNet-18 feature extractor from a saved state dict.

    The classifier ``fc`` is replaced with ``nn.Identity`` *before* loading.
    The checkpoint at ``model_path`` must therefore come from a ResNet-18 whose
    ``fc`` was removed the same way, with no ``fc.*`` keys. Otherwise the
    (strict) ``load_state_dict`` raises on unexpected keys.

    Args:
        model_path: Path (or file-like object) accepted by ``torch.load``.
        device: Target device, e.g. ``"cuda"`` or ``"cpu"``. It is also used as
            ``map_location``, so checkpoints saved on another device load here.

    Returns:
        ``(model, num_ftrs)``: the model on ``device`` in eval mode, and the
        width of the feature vector it outputs (the original ``fc.in_features``).
    """
    try:
        # torchvision >= 0.13: `weights=None` builds a randomly initialised
        # network without downloading anything; `pretrained=` is deprecated.
        model = models.resnet18(weights=None)
    except TypeError:
        # Older torchvision has no `weights` kwarg; use the legacy flag instead.
        model = models.resnet18(pretrained=False)
    num_ftrs = model.fc.in_features
    # Drop the classifier so forward() returns the features that used to feed fc.
    model.fc = nn.Identity()
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted
    # sources (on torch >= 1.13, weights_only=True restricts this to tensors).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()
    return model, num_ftrs
def load_emotion_model(num_ftrs, num_emotions, model_path, device):
    """Rebuild the emotion head and restore its trained weights.

    Args:
        num_ftrs: Backbone feature width the head was trained with.
        num_emotions: Emotion-vector width the head was trained with.
        model_path: Checkpoint holding the head's ``state_dict``.
        device: Device the head is moved to; also the ``map_location`` for loading.

    Returns:
        The head on ``device``, switched to eval mode.
    """
    head = create_emotion_model(num_ftrs, num_emotions)
    head = head.to(device)
    saved_state = torch.load(model_path, map_location=device)
    head.load_state_dict(saved_state)
    # eval() returns the module itself, so it can be returned directly.
    return head.eval()
def va_predict(val_model_path, val_featmodel_path, faces, emotions):
    """Predict valence, arousal and a derived stress score for each face crop.

    Loads the headless ResNet-18 (``val_featmodel_path``) and the emotion head
    (``val_model_path``), on CUDA when available and CPU otherwise. Each face is
    then run through both models one at a time.

    Args:
        val_model_path: Checkpoint for the emotion head (see ``load_emotion_model``).
        val_featmodel_path: Checkpoint for the headless ResNet-18 (see ``load_model``).
        faces: Sized sequence of face images. Each is passed to
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``, so presumably these are
            HxWx3 BGR arrays as produced by OpenCV (TODO confirm against callers).
        emotions: One emotion input per face, aligned with ``faces``. Each item
            may be a tensor, a Python number or an array-like. It is flattened to
            one row, so its total size must equal ``num_emotions`` (1 here).

    Returns:
        ``(valence_list, arousal_list, stress_list)``: three lists of floats,
        one entry per face. Note the order is valence first, even though the
        head's output column 0 is read as arousal and column 1 as valence.

    Raises:
        ValueError: If ``emotions`` is sized and its length differs from ``faces``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # zip() would otherwise silently drop the unmatched tail.
    if hasattr(emotions, "__len__") and len(emotions) != len(faces):
        raise ValueError(
            f"faces and emotions must have the same length, "
            f"got {len(faces)} and {len(emotions)}"
        )

    # Load the models
    resnet, num_ftrs = load_model(val_featmodel_path, device)
    num_emotions = 1  # Assuming single emotion feature
    emotion_model = load_emotion_model(num_ftrs, num_emotions, val_model_path, device)

    # ImageNet-style preprocessing: resize to 224x224, scale to [0, 1],
    # then per-channel mean/std normalisation.
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    def model_forward(images, emotion_batch):
        """Concatenate backbone features with emotion inputs and apply the head."""
        resnet_features = resnet(images)
        batch_size = resnet_features.size(0)
        # reshape (not view) tolerates non-contiguous inputs. The dtype cast keeps
        # float64/int emotions from tripping the float32 Linear layers.
        emotion_batch = emotion_batch.reshape(batch_size, -1).to(resnet_features.dtype)
        x = torch.cat((resnet_features, emotion_batch), dim=1)
        return emotion_model(x)

    arousal_list, valence_list, stress_list = [], [], []
    with torch.no_grad():
        for face, emotion in tqdm(zip(faces, emotions), total=len(faces)):
            face_pil = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))
            face_tensor = transform(face_pil).unsqueeze(0).to(device)
            # as_tensor accepts tensors, Python numbers and NumPy arrays alike.
            emotion = torch.as_tensor(emotion, device=device)
            output_va = model_forward(face_tensor, emotion)
            # Affine map x/2 + 0.5 sends [-1, 1] to [0, 1]. Presumably the head was
            # trained on targets in [-1, 1]; outputs are not clamped.
            arousal = output_va[0, 0].item() / 2 + 0.5
            valence = output_va[0, 1].item() / 2 + 0.5
            # Stress grows with arousal and shrinks as valence rises.
            stress = (1 - valence) * arousal
            arousal_list.append(arousal)
            valence_list.append(valence)
            stress_list.append(stress)
    return valence_list, arousal_list, stress_list
Editor is loading...
Leave a Comment