# Paste-site header residue (not part of the program):
# Untitled | unknown | python | 4 years ago | 5.3 kB | 11 | Indexable
import torch
import torch.nn as nn
class ResidualBlock(torch.nn.Module):
    """Residual block made of two 3x3 convolutions and an identity skip.

    The block evaluates ``relu(conv1(relu(conv0(x))) + x)``. Both convolutions
    map ``channels_count`` channels to ``channels_count`` channels with
    ``kernel_size=3, padding=1``, so the output has the same shape as the
    input and the skip connection can be a plain element-wise addition.
    """

    def __init__(self, channels_count):
        """Build the two convolutions and initialise their parameters.

        Args:
            channels_count: number of input (and output) channels of both
                convolutions.
        """
        super().__init__()

        def make_conv():
            return torch.nn.Conv2d(
                channels_count, channels_count, kernel_size=3, padding=1
            )

        # Sub-modules are created in the order conv0, act0, conv1, act1 so the
        # printed representation and the state-dict layout stay as they were.
        self.conv0 = make_conv()
        self.act0 = torch.nn.ReLU()
        self.conv1 = make_conv()
        self.act1 = torch.nn.ReLU()

        # Orthogonal weights with gain 0.1 and zero biases, conv0 first.
        for conv in (self.conv0, self.conv1):
            torch.nn.init.orthogonal_(conv.weight, 0.1)
            torch.nn.init.zeros_(conv.bias)

    def forward(self, x):
        """Return ``relu(conv1(relu(conv0(x))) + x)``."""
        hidden = self.act0(self.conv0(x))
        return self.act1(self.conv1(hidden) + x)
class ModelPolicy(torch.nn.Module):
    """Policy head mapping a feature map to one score per action.

    Pipeline: 1x1 convolution down to 8 channels -> ReLU -> flatten ->
    Linear(…, 512) -> ReLU -> Linear(512, outputs_count). No softmax is
    applied; ``forward`` returns the raw output of the last linear layer.
    """

    def __init__(self, input_shape, outputs_count):
        """Build and initialise the head.

        Args:
            input_shape: indexable ``(channels, height, width)`` of the
                incoming feature map (without the batch dimension).
            outputs_count: size of the output vector (number of actions).
        """
        super().__init__()

        # Kept as a plain list attribute because the original exposed it; the
        # modules are registered (and checkpointed) through ``self.model``.
        self.layers = [
            torch.nn.Conv2d(input_shape[0], 8, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(8 * input_shape[1] * input_shape[2], 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, outputs_count),
        ]

        # Orthogonal weights (gain 0.1) and zero biases for every layer that
        # actually owns parameters; activations and Flatten are skipped.
        for layer in self.layers:
            if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear)):
                torch.nn.init.orthogonal_(layer.weight, 0.1)
                torch.nn.init.zeros_(layer.bias)

        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        """Return the ``(batch, outputs_count)`` output for feature map ``x``."""
        return self.model(x)
class ModelCritic(torch.nn.Module):
    """Critic head with a shared trunk and two scalar value outputs.

    Trunk: 1x1 convolution down to 8 channels -> ReLU -> flatten ->
    Linear(…, 256) -> ReLU -> flatten. Two independent ``Linear(256, 1)``
    heads, ``ext_value`` and ``int_value``, read the trunk output
    (presumably extrinsic and intrinsic value estimates — confirm with the
    training code).
    """

    def __init__(self, input_shape):
        """Build and initialise the trunk and both value heads.

        Args:
            input_shape: indexable ``(channels, height, width)`` of the
                incoming feature map (without the batch dimension).
        """
        super().__init__()

        flat_size = 8 * input_shape[1] * input_shape[2]
        self.layers_features = [
            torch.nn.Conv2d(input_shape[0], 8, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            torch.nn.Linear(flat_size, 256),
            torch.nn.ReLU(),
            # Input is already 2-D here, so this Flatten leaves it unchanged;
            # it is kept so the layer list and printed model do not change.
            torch.nn.Flatten(),
        ]
        self.model_features = nn.Sequential(*self.layers_features)

        self.ext_value = torch.nn.Linear(256, 1)
        self.int_value = torch.nn.Linear(256, 1)

        # Orthogonal weights (gain 0.1) and zero biases, in the order
        # trunk conv, trunk linear, ext head, int head.
        trunk_conv = self.layers_features[0]
        trunk_linear = self.layers_features[3]
        for module in (trunk_conv, trunk_linear, self.ext_value, self.int_value):
            torch.nn.init.orthogonal_(module.weight, 0.1)
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        """Return ``(ext_value, int_value)``, each of shape ``(batch, 1)``."""
        features = self.model_features(x)
        ext = self.ext_value(features)
        intr = self.int_value(features)
        return ext, intr
class Model(torch.nn.Module):
    """Residual feature extractor feeding a policy head and a critic head.

    The feature extractor is a 3x3 convolution + ReLU followed by
    ``blocks_count`` ``ResidualBlock`` instances, all with ``channels_count``
    channels. Its output is passed to ``ModelPolicy`` and ``ModelCritic``.
    All three parts are moved to CUDA when available, otherwise they stay on
    the CPU; the chosen device is stored in ``self.device``.
    """

    # File-name suffixes appended to the ``path`` prefix by save()/load(),
    # paired with the attribute holding the corresponding sub-module.
    _CHECKPOINT_FILES = (
        ("model_features.pt", "model_features"),
        ("model_policy.pt", "model_policy"),
        ("model_critic.pt", "model_critic"),
    )

    def __init__(self, input_shape, outputs_count, *, blocks_count=16, channels_count=256):
        """Build the network and print its structure.

        Args:
            input_shape: indexable ``(channels, height, width)`` of one input
                state (without the batch dimension).
            outputs_count: size of the policy output (number of actions).
            blocks_count: number of residual blocks in the feature extractor.
                Defaults to 16, the value that used to be hard-coded.
            channels_count: channel width of the feature extractor. Defaults
                to 256, the value that used to be hard-coded.
        """
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The convolutions are all 3x3 with padding 1, so the heads see the
        # input's spatial size with ``channels_count`` channels.
        features_shape = (channels_count, input_shape[1], input_shape[2])

        self.layers_features = [
            nn.Conv2d(input_shape[0], channels_count, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.layers_features.extend(
            ResidualBlock(channels_count) for _ in range(blocks_count)
        )

        self.model_features = nn.Sequential(*self.layers_features)
        self.model_features.to(self.device)

        self.model_policy = ModelPolicy(features_shape, outputs_count)
        self.model_policy.to(self.device)

        self.model_critic = ModelCritic(features_shape)
        self.model_critic.to(self.device)

        print("model_ppo")
        print(self.model_features)
        print(self.model_policy)
        print(self.model_critic)
        print("\n\n")

    def forward(self, state):
        """Run the network on a batch of states.

        Args:
            state: input tensor. It is not moved here, so it must already be
                on ``self.device``.

        Returns:
            Tuple ``(policy, ext_value, int_value)``: the policy head output
            and the two outputs of the critic head.
        """
        features = self.model_features(state)
        policy = self.model_policy(features)
        ext_value, int_value = self.model_critic(features)
        return policy, ext_value, int_value

    def save(self, path):
        """Write the three sub-module state dicts to files prefixed by ``path``.

        ``path`` is used as a plain string prefix (no separator is inserted),
        producing ``path + "model_features.pt"``, ``path + "model_policy.pt"``
        and ``path + "model_critic.pt"``.
        """
        print("saving ", path)
        for file_name, attr in self._CHECKPOINT_FILES:
            torch.save(getattr(self, attr).state_dict(), path + file_name)

    def load(self, path):
        """Load the three sub-module state dicts written by ``save``.

        Weights are mapped onto ``self.device`` and all three sub-modules are
        switched to eval mode afterwards.
        """
        print("loading ", path)
        # NOTE(security): torch.load may unpickle arbitrary objects depending
        # on the PyTorch version; only load checkpoints from trusted sources.
        for file_name, attr in self._CHECKPOINT_FILES:
            state_dict = torch.load(path + file_name, map_location=self.device)
            getattr(self, attr).load_state_dict(state_dict)

        for _, attr in self._CHECKPOINT_FILES:
            getattr(self, attr).eval()
if __name__ == "__main__":
    # Smoke test: build the default model and push one random batch through it.
    state_shape = (10, 8, 8)
    actions_count = 500
    batch_size = 32

    model = Model(state_shape, actions_count)

    # The model places its weights on model.device (CUDA when available), so
    # the input must be created there too; a CPU tensor would raise a device
    # mismatch error on GPU machines.
    state = torch.randn((batch_size,) + state_shape, device=model.device)

    policy, ext_value, int_value = model(state)
    print("shape = ", policy.shape, ext_value.shape, int_value.shape)