Untitled
unknown
plain_text
10 months ago
2.7 kB
7
Indexable
class Net(nn.Module):
    """Small fully connected classifier mapping 4 input features to 3 logits.

    Architecture, in forward order:
    Linear(4, 42) -> BatchNorm1d -> ReLU -> Linear(42, 42) -> BatchNorm1d
    -> ReLU -> Dropout(0.5) -> Linear(42, 3).

    The output is raw logits (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Submodule attribute names and creation order are kept stable: they
        # determine the state_dict keys and the order in which the RNG is
        # consumed during weight initialisation.
        self.input = nn.Linear(4, 42)
        self.batch_norm1 = nn.BatchNorm1d(42)
        self.activation = nn.ReLU()
        self.hidden = nn.Linear(42, 42)
        self.batch_norm2 = nn.BatchNorm1d(42)
        self.output = nn.Linear(42, 3)
        self.fix_dropout = nn.Dropout(0.5)

    def forward(self, x):
        """Return the logits produced for the batch ``x``.

        ``x`` is fed to ``nn.Linear(4, 42)``, so its last dimension must be 4.
        """
        first = self.activation(self.batch_norm1(self.input(x)))
        second = self.activation(self.batch_norm2(self.hidden(first)))
        # Dropout is applied only once, right before the output projection.
        return self.output(self.fix_dropout(second))
def main():
    """Train ``Net`` on the Iris CSV, print test accuracy per epoch, save weights.

    Reads ``IRIS_DATA.csv`` from the current working directory, makes a
    stratified 80/20 train/test split, trains for 100 epochs with SGD and
    writes the final ``state_dict`` to ``iris_dl_model.pt``.
    """
    # Use the GPU when one is present, otherwise fall back to the CPU instead
    # of crashing on an unconditional ``.cuda()`` call.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ############ Data preparation: marked as "all OK" by the original author ############
    dataset = pd.read_csv("IRIS_DATA.csv")
    features = ["SepalLengthCm", "SepalWidthCm", "PetalLengthCm", "PetalWidthCm"]
    X = torch.from_numpy(dataset[features].values).type(torch.float32)
    # NOTE(review): assumes the "Species" column already holds integer class
    # ids; string labels would make ``torch.from_numpy`` fail -- confirm.
    y = torch.from_numpy(dataset["Species"].values).type(torch.long)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y)
    train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
    test_dataset = torch.utils.data.TensorDataset(X_test, y_test)
    kwargs = {"batch_size": 42, "num_workers": 1, "pin_memory": True, "shuffle": True}
    train_loader = torch.utils.data.DataLoader(train_dataset, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, **kwargs)
    #####################################################################################

    model = Net().to(device)
    # The original lr=1e-6 is far too small for the few hundred updates this
    # loop performs; 1e-2 is a conventional starting point for plain SGD.
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    for epoch in range(100):
        # Evaluation below switches the model to eval mode; switch back so
        # dropout and batch-norm statistics updates are active while training.
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            output = model(data.to(device))
            loss = F.cross_entropy(output, target.to(device))
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if batch_idx % LOG_INTERVAL == 0:
                print(f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)}]\tLoss: {loss.item()}")

        correct = 0
        model.eval()
        with torch.inference_mode():  # torch.no_grad() would work as well
            for data, target in test_loader:
                output = model(data.to(device))
                pred = output.argmax(dim=1)
                # Count hits rather than averaging per-batch means: ``mean()``
                # is not defined for bool tensors, and averaging batch means
                # is biased when the last batch is smaller than the others.
                correct += (pred == target.to(device)).sum().item()
        print("Accuracy: {:.0f}%".format(100.0 * correct / len(test_loader.dataset)))

    torch.save(model.state_dict(), "iris_dl_model.pt")
Leave a Comment