import os
import torch
import torch.nn as nn
from torch.optim import Adam
from dataset import load_tiny_imagenet
from tqdm import tqdm
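# Note: load_tiny_imagenet is assumed to be a project-local helper in dataset.py (not a
# torchvision API) that returns (train_loader, test_loader) DataLoaders over 64x64
# Tiny ImageNet images with 200 classes; its implementation is not shown here.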
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Initialize a residual block.
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride of the first convolution.
        """
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # If the input and output channel counts differ, or the stride is not 1,
        # use a 1x1 downsample projection to match the identity to the output shape.
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out
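# Illustrative only: a quick shape sanity check for ResidualBlock (not called anywhere
# in training). With stride=2 and mismatched channels, the 1x1 downsample branch projects
# the identity to the new channel count and halved spatial size before the addition.
def _check_residual_block_shapes():
    block = ResidualBlock(64, 128, stride=2)
    x = torch.randn(2, 64, 64, 64)       # (batch, channels, height, width)
    y = block(x)
    assert y.shape == (2, 128, 32, 32)   # channels doubled, spatial size halved
    return y.shape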
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        # Define the convolutional layers with nn.Sequential
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 1024, stride=2)  # final feature map is 4x4
        )

    def forward(self, x):
        return self.encoder(x)
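# Illustrative only: with 64x64 Tiny ImageNet inputs, the stride-1 stem keeps 64x64 and
# the four stride-2 residual blocks halve the spatial size to 32 -> 16 -> 8 -> 4,
# so the encoder output is (batch, 1024, 4, 4) as noted above.
def _check_encoder_shapes():
    encoder = Encoder()
    features = encoder(torch.randn(2, 3, 64, 64))
    assert features.shape == (2, 1024, 4, 4)
    return features.shape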
# Classifier definition
class Classifier(nn.Module):
    def __init__(self, num_classes=200):
        super(Classifier, self).__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = x.mean((2, 3))  # Global Average Pooling
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x
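# Illustrative only: the x.mean((2, 3)) above is plain global average pooling; it is
# equivalent to nn.AdaptiveAvgPool2d(1) followed by flattening, reducing (B, 1024, 4, 4)
# encoder features to (B, 1024) before the fully connected layers.
def _check_global_average_pooling():
    features = torch.randn(2, 1024, 4, 4)
    pooled_mean = features.mean((2, 3))
    pooled_adaptive = nn.AdaptiveAvgPool2d(1)(features).flatten(1)
    assert torch.allclose(pooled_mean, pooled_adaptive, atol=1e-6)
    return pooled_mean.shape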
# Top-N accuracy computation
def eval_top_n_accuracy(predictions, labels, n=1):
    _, top_n_preds = predictions.topk(n, dim=1, largest=True, sorted=True)
    correct = top_n_preds.eq(labels.view(-1, 1).expand_as(top_n_preds))
    return correct.sum().item() / labels.size(0)
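# Illustrative only: a tiny worked example of the top-N helper. Sample 0 is correct at
# top-1 (class 1 has the largest logit); sample 1 is wrong at top-1 but its true class
# falls inside the top-2 predictions, so top-1 accuracy is 0.5 and top-2 accuracy is 1.0.
def _check_top_n_accuracy():
    logits = torch.tensor([[0.1, 0.7, 0.2],
                           [0.6, 0.3, 0.1]])
    labels = torch.tensor([1, 1])
    assert eval_top_n_accuracy(logits, labels, n=1) == 0.5
    assert eval_top_n_accuracy(logits, labels, n=2) == 1.0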
# Train the model
def train_model(encoder, classifier, train_loader, test_loader, device, epochs=10, lr=1e-3, top_n=5):
    encoder.to(device)
    classifier.to(device)
    # Define the loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(list(encoder.parameters()) + list(classifier.parameters()), lr=lr)
    for epoch in range(epochs):
        encoder.train()
        classifier.train()
        running_loss = 0.0
        total_samples = 0
        total_correct = 0
        for images, labels in tqdm(train_loader):
            images, labels = images.to(device), labels.to(device)
            # Forward pass
            features = encoder(images)
            outputs = classifier(features)
            loss = criterion(outputs, labels)
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate training loss and accuracy
            running_loss += loss.item() * labels.size(0)
            _, preds = outputs.max(1)
            total_correct += preds.eq(labels).sum().item()
            total_samples += labels.size(0)
        train_loss = running_loss / total_samples
        train_acc = total_correct / total_samples
        # Evaluate the model
        encoder.eval()
        classifier.eval()
        total_top1_acc = 0
        total_topn_acc = 0
        total_samples = 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                features = encoder(images)
                outputs = classifier(features)
                total_top1_acc += eval_top_n_accuracy(outputs, labels, n=1) * labels.size(0)
                total_topn_acc += eval_top_n_accuracy(outputs, labels, n=top_n) * labels.size(0)
                total_samples += labels.size(0)
        top1_acc = total_top1_acc / total_samples
        topn_acc = total_topn_acc / total_samples
        print(f"Epoch [{epoch+1}/{epochs}] - Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
              f"Top-1 Acc: {top1_acc:.4f}, Top-{top_n} Acc: {topn_acc:.4f}")
if __name__ == "__main__":
    # Path to the Tiny ImageNet dataset
    data_dir = "./datasets/tiny-imagenet-200/"
    batch_size = 256
    epochs = 20
    learning_rate = 1e-4
    top_n = 5
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the dataset
    train_loader, test_loader = load_tiny_imagenet(data_dir, batch_size)
    # Initialize the models
    encoder = Encoder()
    classifier = Classifier(num_classes=200)
    # Train the model
    train_model(encoder, classifier, train_loader, test_loader, device, epochs, learning_rate, top_n)