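"""
PPO actor-critic model: a residual convolutional feature trunk shared by a
policy head and a critic with separate extrinsic and intrinsic value heads.
"""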
import torch
import torch.nn as nn

class ResidualBlock(torch.nn.Module):
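    """Two 3x3 convolutions with a skip connection added before the second activation."""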
    def __init__(self, channels_count):
        super(ResidualBlock, self).__init__()

        self.conv0 = torch.nn.Conv2d(channels_count, channels_count, kernel_size=3, padding=1)
        self.act0  = torch.nn.ReLU()

        self.conv1 = torch.nn.Conv2d(channels_count, channels_count, kernel_size=3, padding=1)
        self.act1  = torch.nn.ReLU()

        # small-gain (0.1) orthogonal init with zero bias keeps the residual
        # branch small at the start, so each block begins near an identity mapping
        torch.nn.init.orthogonal_(self.conv0.weight, 0.1)
        torch.nn.init.zeros_(self.conv0.bias)

        torch.nn.init.orthogonal_(self.conv1.weight, 0.1)
        torch.nn.init.zeros_(self.conv1.bias)

    def forward(self, x):
        y = self.conv0(x)
        y = self.act0(y)

        y = self.conv1(y)
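        # add the skip connection before the final activation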
        y = self.act1(y + x)

        return y


class ModelPolicy(torch.nn.Module):
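    """Policy head: a 1x1 convolution to compress channels, then a two-layer MLP ending in one logit per action."""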
    def __init__(self, input_shape, outputs_count):
        super(ModelPolicy, self).__init__()

        self.layers = [
            torch.nn.Conv2d(input_shape[0], 8, kernel_size=1),
            torch.nn.ReLU(),

            torch.nn.Flatten(),

            torch.nn.Linear(8*input_shape[1]*input_shape[2], 512),
            torch.nn.ReLU(),
 
            torch.nn.Linear(512, outputs_count)
        ]
        
        # orthogonal init (gain 0.1) for every layer that has weights; zero biases
        for i in range(len(self.layers)):
            if hasattr(self.layers[i], "weight"):
                torch.nn.init.orthogonal_(self.layers[i].weight, 0.1)
                torch.nn.init.zeros_(self.layers[i].bias)

        self.model = nn.Sequential(*self.layers)
     
    def forward(self, x):
        return self.model(x)



class ModelCritic(torch.nn.Module):
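    """Critic head: compresses features with a small MLP, then outputs extrinsic and intrinsic value estimates."""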
    def __init__(self, input_shape):
        super(ModelCritic, self).__init__()

        self.layers_features = [
            torch.nn.Conv2d(input_shape[0], 8, kernel_size=1),
            torch.nn.ReLU(),

            torch.nn.Flatten(),

            torch.nn.Linear(8*input_shape[1]*input_shape[2], 256),
            torch.nn.ReLU()
        ]

        self.model_features = nn.Sequential(*self.layers_features)

        # separate scalar heads for extrinsic and intrinsic value estimates
        self.ext_value  = torch.nn.Linear(256, 1)
        self.int_value  = torch.nn.Linear(256, 1)
        
        torch.nn.init.orthogonal_(self.layers_features[0].weight, 0.1)
        torch.nn.init.zeros_(self.layers_features[0].bias)

        torch.nn.init.orthogonal_(self.layers_features[3].weight, 0.1)
        torch.nn.init.zeros_(self.layers_features[3].bias)

        torch.nn.init.orthogonal_(self.ext_value.weight, 0.1)
        torch.nn.init.zeros_(self.ext_value.bias)

        torch.nn.init.orthogonal_(self.int_value.weight, 0.1)
        torch.nn.init.zeros_(self.int_value.bias) 

     
    def forward(self, x):
        y = self.model_features(x)
        return self.ext_value(y), self.int_value(y)
     

 
class Model(torch.nn.Module):
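
    """Full PPO model: a deep residual feature trunk shared by the policy head and a twin-headed critic."""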

    def __init__(self, input_shape, outputs_count):
        super(Model, self).__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        blocks_count   = 16
        channels_count = 256
    
        features_shape = (channels_count, input_shape[1], input_shape[2])
        
        self.layers_features = [   
            nn.Conv2d(input_shape[0], channels_count, kernel_size=3, padding=1),
            nn.ReLU()
        ]

        # stack of residual blocks on top of the stem convolution
        for _ in range(blocks_count):
            self.layers_features.append(ResidualBlock(channels_count))

        self.model_features = nn.Sequential(*self.layers_features)
        self.model_features.to(self.device) 

        self.model_policy = ModelPolicy(features_shape, outputs_count)
        self.model_policy.to(self.device)

        self.model_critic = ModelCritic(features_shape)
        self.model_critic.to(self.device) 
        
        print("model_ppo")
        print(self.model_features)
        print(self.model_policy)
        print(self.model_critic)
        print("\n\n")

    def forward(self, state):
        features                = self.model_features(state)

        policy                  = self.model_policy(features)
        ext_value, int_value    = self.model_critic(features)

        return policy, ext_value, int_value

    def save(self, path):
        print("saving ", path)

        torch.save(self.model_features.state_dict(), path + "model_features.pt")
        torch.save(self.model_policy.state_dict(), path + "model_policy.pt")
        torch.save(self.model_critic.state_dict(), path + "model_critic.pt")

    def load(self, path):
        print("loading ", path) 

        self.model_features.load_state_dict(torch.load(path + "model_features.pt", map_location = self.device))
        self.model_policy.load_state_dict(torch.load(path + "model_policy.pt", map_location = self.device))
        self.model_critic.load_state_dict(torch.load(path + "model_critic.pt", map_location = self.device))
        
        self.model_features.eval() 
        self.model_policy.eval() 
        self.model_critic.eval()


if __name__ == "__main__":
    state_shape     = (10, 8, 8)
    actions_count   = 500
    batch_size      = 32 

    model           = Model(state_shape, actions_count)

    # the sub-models are moved to model.device in __init__, so the input must follow
    state           = torch.randn((batch_size, ) + state_shape).to(model.device)
    
    policy, ext_value, int_value = model(state)

    print("shape = ", policy.shape, ext_value.shape, int_value.shape)