# CONV w/ tensor core 20240823
import torch
import torch.nn as nn
import torch.optim as optim
import time
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset
import torch.cuda.nvtx as nvtx
# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"CUDA version: {torch.version.cuda}")
class TFLiteMaximum(nn.Module):
    """Learnable per-channel floor, mirroring TFLite's MAXIMUM op."""
    def __init__(self, channels):
        super(TFLiteMaximum, self).__init__()
        # One threshold per channel, broadcast over N, H, W; initialized to
        # zero, so the layer starts out equivalent to ReLU.
        self.threshold = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x):
        return torch.maximum(x, self.threshold)
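# Hedged usage note: with the threshold at its zero init, TFLiteMaximum is
# exactly ReLU, e.g. torch.maximum(torch.tensor([-2., 3.]), torch.tensor(0.))
# gives tensor([0., 3.]); training can then move the floor per channel.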
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Input and output channel counts are multiples of 8 so the FP16
        # convolution can map onto Tensor Cores.
        self.conv = nn.Conv2d(8, 64, kernel_size=3, padding=1, bias=True)
        nn.init.constant_(self.conv.bias, 16)
        self.maximum = TFLiteMaximum(64)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 32x32 input halved by the pool -> 64 channels x 16 x 16 features
        self.fc = nn.Linear(64 * 16 * 16, 16)
    def forward(self, x):
        nvtx.range_push("CONV2D")
        x = self.conv(x)
        nvtx.range_pop()
        nvtx.range_push("SCALE")
        x = x * 0.1
        nvtx.range_pop()
        nvtx.range_push("MAXIMUM")
        x = self.maximum(x)
        nvtx.range_pop()
        nvtx.range_push("MAXPOOL")
        x = self.max_pool(x)
        nvtx.range_pop()
        nvtx.range_push("RESHAPE")
        x = x.view(x.size(0), -1)
        nvtx.range_pop()
        nvtx.range_push("FULLY_CONNECTED")
        x = self.fc(x)
        nvtx.range_pop()
        return x
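# Minimal shape sanity check (a sketch, assuming the same 8-channel 32x32
# input layout as the training data in main()); handy when editing the
# conv/pool stack. Not called anywhere by default.
def _smoke_test():
    m = CNNModel().to(device)
    with torch.no_grad(), autocast():
        out = m(torch.randn(2, 8, 32, 32, device=device))
    assert out.shape == (2, 16), f"unexpected output shape: {out.shape}"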
def train_model(model, criterion, optimizer, train_loader, num_epochs):
    scaler = GradScaler()
    for epoch in range(num_epochs):
        nvtx.range_push(f"Epoch {epoch+1}")
        print(f"Epoch {epoch+1}/{num_epochs}")
        start_time = time.time()
        model.train()
        for batch_x, batch_y in train_loader:
            nvtx.range_push("Batch")
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            with autocast():
                nvtx.range_push("Forward")
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                nvtx.range_pop()  # Forward
            optimizer.zero_grad()
            nvtx.range_push("Backward")
            scaler.scale(loss).backward()
            nvtx.range_pop()  # Backward
            nvtx.range_push("Optimizer step")
            scaler.step(optimizer)
            scaler.update()
            nvtx.range_pop()  # Optimizer step
            nvtx.range_pop()  # Batch
        print(f"Loss: {loss.item():.4f}")
        print(f"Time taken: {time.time() - start_time:.2f} seconds")
        print("------")
        torch.cuda.empty_cache()
        nvtx.range_pop()  # Epoch
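# Sketch of an alternative to the NVTX/nsys flow: capture one training step
# with torch.profiler and scan the kernel table for Tensor Core activity
# (which kernel names indicate Tensor Cores varies by GPU and cuDNN version,
# so that mapping is an assumption to verify on your setup).
def profile_one_step(model, criterion, optimizer, batch_x, batch_y):
    from torch.profiler import profile, ProfilerActivity
    scaler = GradScaler()
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        with autocast():
            loss = criterion(model(batch_x), batch_y)
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))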
def main():
    num_samples = 1024  # Multiple of 32 for better GPU utilization
    batch_size = 32
    model = CNNModel().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters())
    # Generate dummy data (in FP16, with 8 input channels)
    x_train = torch.randn(num_samples, 8, 32, 32, dtype=torch.float16, device=device)
    y_train = torch.randint(0, 16, (num_samples,), device=device)
    train_dataset = TensorDataset(x_train, y_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Warmup run
    print("Warmup run...")
    train_model(model, criterion, optimizer, train_loader, num_epochs=1)
    # Profiled run
    print("Profiled run...")
    train_model(model, criterion, optimizer, train_loader, num_epochs=5)
    # Inference profiling
    model.eval()
    with torch.no_grad(), autocast():
        print("Inference profiling...")
        test_batch = x_train[:batch_size]
        for _ in range(10):  # Run inference 10 times for better profiling
            predictions = model(test_batch)
if __name__ == "__main__":
    main()
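# To view the NVTX ranges above on a timeline, run this script under Nsight
# Systems, e.g. (output and script names below are placeholders):
#   nsys profile -t cuda,nvtx -o conv_tensor_core python conv_tensor_core.py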