Untitled
unknown
python
3 years ago
5.3 kB
2
Indexable
Never
import torch
import torch.nn as nn


def _init_orthogonal(layer, gain=0.1):
    """Orthogonally initialise ``layer.weight`` (scaled by ``gain``) and zero ``layer.bias``."""
    nn.init.orthogonal_(layer.weight, gain)
    nn.init.zeros_(layer.bias)


class ResidualBlock(nn.Module):
    """Two 3x3 convolutions with a skip connection around them.

    Both convolutions keep the channel count and (via ``padding=1``) the
    spatial size, so the input can be added to the second convolution's
    output before the final ReLU.
    """

    def __init__(self, channels_count):
        """Build the block.

        Args:
            channels_count: number of input and output channels.
        """
        super().__init__()

        self.conv0 = nn.Conv2d(channels_count, channels_count, kernel_size=3, padding=1)
        self.act0 = nn.ReLU()
        self.conv1 = nn.Conv2d(channels_count, channels_count, kernel_size=3, padding=1)
        self.act1 = nn.ReLU()

        _init_orthogonal(self.conv0)
        _init_orthogonal(self.conv1)

    def forward(self, x):
        """Return ``relu(conv1(relu(conv0(x))) + x)``."""
        y = self.act0(self.conv0(x))
        y = self.conv1(y)
        return self.act1(y + x)


class ModelPolicy(nn.Module):
    """Policy head: 1x1 convolution to 8 channels, then a two-layer MLP.

    The output has ``outputs_count`` values per sample; no activation is
    applied to the final linear layer.
    """

    def __init__(self, input_shape, outputs_count):
        """Build the head.

        Args:
            input_shape: ``(channels, height, width)`` of the incoming features.
            outputs_count: size of the final linear layer's output.
        """
        super().__init__()

        channels, height, width = input_shape[0], input_shape[1], input_shape[2]

        self.layers = [
            nn.Conv2d(channels, 8, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * height * width, 512),
            nn.ReLU(),
            nn.Linear(512, outputs_count),
        ]

        # Only the parametrised layers (Conv2d / Linear) expose ``weight``.
        for layer in self.layers:
            if hasattr(layer, "weight"):
                _init_orthogonal(layer)

        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        """Apply the head to a batch of feature maps."""
        return self.model(x)


class ModelCritic(nn.Module):
    """Critic head: a shared feature trunk followed by two scalar heads.

    The two heads are named ``ext_value`` and ``int_value`` (presumably the
    extrinsic and intrinsic value estimates -- confirm against the trainer).
    """

    def __init__(self, input_shape):
        """Build the head.

        Args:
            input_shape: ``(channels, height, width)`` of the incoming features.
        """
        super().__init__()

        channels, height, width = input_shape[0], input_shape[1], input_shape[2]

        self.layers_features = [
            nn.Conv2d(channels, 8, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8 * height * width, 256),
            nn.ReLU(),
            # Kept so the Sequential layout stays as it was; it follows a
            # Linear layer, so for batched input it has nothing left to flatten.
            nn.Flatten(),
        ]
        self.model_features = nn.Sequential(*self.layers_features)

        self.ext_value = nn.Linear(256, 1)
        self.int_value = nn.Linear(256, 1)

        _init_orthogonal(self.layers_features[0])
        _init_orthogonal(self.layers_features[3])
        _init_orthogonal(self.ext_value)
        _init_orthogonal(self.int_value)

    def forward(self, x):
        """Return ``(ext_value, int_value)`` computed from the shared trunk."""
        y = self.model_features(x)
        return self.ext_value(y), self.int_value(y)


class Model(nn.Module):
    """Residual convolutional trunk feeding a policy head and a critic head.

    All three sub-networks are moved to ``self.device`` (CUDA when available,
    otherwise CPU) at construction time, so inputs passed to ``forward`` must
    already live on that device.
    """

    def __init__(self, input_shape, outputs_count, blocks_count=16, channels_count=256):
        """Build the trunk and both heads, then print their structure.

        Args:
            input_shape: ``(channels, height, width)`` of the input state.
            outputs_count: number of policy outputs.
            blocks_count: number of residual blocks in the trunk.
            channels_count: channel width of the trunk.
        """
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # The trunk preserves spatial size, so the heads see the input's H x W.
        features_shape = (channels_count, input_shape[1], input_shape[2])

        # The stem convolution is left with PyTorch's default initialisation.
        self.layers_features = [
            nn.Conv2d(input_shape[0], channels_count, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.layers_features.extend(
            ResidualBlock(channels_count) for _ in range(blocks_count)
        )

        self.model_features = nn.Sequential(*self.layers_features)
        self.model_features.to(self.device)

        self.model_policy = ModelPolicy(features_shape, outputs_count)
        self.model_policy.to(self.device)

        self.model_critic = ModelCritic(features_shape)
        self.model_critic.to(self.device)

        print("model_ppo")
        print(self.model_features)
        print(self.model_policy)
        print(self.model_critic)
        print("\n\n")

    def forward(self, state):
        """Return ``(policy, ext_value, int_value)`` for a batch of states.

        ``state`` must be on ``self.device``.
        """
        features = self.model_features(state)

        policy = self.model_policy(features)
        ext_value, int_value = self.model_critic(features)

        return policy, ext_value, int_value

    def save(self, path):
        """Write the three state dicts to files whose names start with ``path``.

        The file names are appended to ``path`` by plain string concatenation,
        so a directory path needs its trailing separator.
        """
        print("saving ", path)

        torch.save(self.model_features.state_dict(), path + "model_features.pt")
        torch.save(self.model_policy.state_dict(), path + "model_policy.pt")
        torch.save(self.model_critic.state_dict(), path + "model_critic.pt")

    def load(self, path):
        """Load the three state dicts saved by ``save`` and switch to eval mode.

        NOTE(security): ``torch.load`` unpickles the file contents; only load
        checkpoints from trusted sources.
        """
        print("loading ", path)

        self.model_features.load_state_dict(
            torch.load(path + "model_features.pt", map_location=self.device)
        )
        self.model_policy.load_state_dict(
            torch.load(path + "model_policy.pt", map_location=self.device)
        )
        self.model_critic.load_state_dict(
            torch.load(path + "model_critic.pt", map_location=self.device)
        )

        self.model_features.eval()
        self.model_policy.eval()
        self.model_critic.eval()


if __name__ == "__main__":
    state_shape = (10, 8, 8)
    actions_count = 500
    batch_size = 32

    model = Model(state_shape, actions_count)

    # Create the input on the model's device; a CPU tensor would fail with a
    # device mismatch whenever the model was placed on CUDA.
    state = torch.randn((batch_size,) + state_shape, device=model.device)

    policy, ext_value, int_value = model(state)

    print("shape = ", policy.shape, ext_value.shape, int_value.shape)