# cartpole v-1
#
# cartpole v-1
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 2.6 kB
# 8
# Indexable
# Never
import gym
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import pygad.kerasga
import pygad


def fitness_func(solution, sol_idx):
    """Score one candidate weight vector by playing a single episode.

    Loads ``solution`` into the module-level ``model``, then plays one episode
    in the module-level ``env``, always taking the action with the largest
    network output. The episode stops when the environment reports it is over
    or after 1000 steps, whichever comes first.

    Args:
        solution: Flat weight vector proposed by the genetic algorithm.
        sol_idx: Index of the solution inside the population (unused).

    Returns:
        The sum of the rewards collected during the episode.

    NOTE(review): recent PyGAD releases appear to call the fitness function
    as ``fitness_func(ga_instance, solution, sol_idx)``. The two-argument
    form is kept here because PyGAD inspects the argument count - confirm
    against the installed PyGAD version before changing it.
    """
    max_steps = 1000  # hard cap so a never-ending episode cannot stall the GA

    weights = pygad.kerasga.model_weights_as_matrix(model=model, weights_vector=solution)
    model.set_weights(weights=weights)

    # Newer Gym/Gymnasium releases return (observation, info) from reset();
    # older ones return the observation alone. Accept both.
    reset_result = env.reset()
    observation = reset_result[0] if isinstance(reset_result, tuple) else reset_result

    sum_reward = 0
    for _ in range(max_steps):
        state = np.reshape(observation, [1, observation_space_size])
        # verbose=0 keeps Keras from printing a progress bar on every step.
        q_values = model.predict(state, verbose=0)
        action = int(np.argmax(q_values[0]))

        # Newer step() returns (obs, reward, terminated, truncated, info);
        # the legacy API returns (obs, reward, done, info).
        step_result = env.step(action)
        if len(step_result) == 5:
            observation, reward, terminated, truncated, _info = step_result
            done = terminated or truncated
        else:
            observation, reward, done, _info = step_result

        sum_reward += reward
        if done:
            break

    return sum_reward


def callback_generation(ga_instance):
    """Print the generation counter and the best fitness after a generation.

    Args:
        ga_instance: The running GA object. Must expose
            ``generations_completed`` and ``best_solution()``; when it also
            exposes ``last_generation_fitness`` that cached value is reused.
    """
    print(f"Generation = {ga_instance.generations_completed}")

    # Calling best_solution() with no argument makes PyGAD re-evaluate the
    # fitness of the whole population, i.e. replay one episode per solution
    # just to print a number. Reuse the fitness values the GA already computed
    # when the instance provides them; otherwise fall back to the plain call.
    cached_fitness = getattr(ga_instance, "last_generation_fitness", None)
    if cached_fitness is None:
        best = ga_instance.best_solution()
    else:
        best = ga_instance.best_solution(pop_fitness=cached_fitness)

    print(f"Fitness    = {best[1]}")


# Environment: read the observation width and the number of discrete actions
# straight from the spaces so the network below is sized to match.
env = gym.make("CartPole-v1")
observation_space_size = env.observation_space.shape[0]
action_space_size = env.action_space.n

# Policy network: two hidden ReLU layers of 16 units and one linear output
# per action; fitness_func picks the action with the largest output.
model = Sequential([
    Dense(16, input_shape=(observation_space_size,), activation='relu'),
    Dense(16, activation='relu'),
    Dense(action_space_size, activation='linear'),
])
model.summary()

# Build 10 candidate weight vectors from the Keras model; they seed the GA
# below through keras_ga.population_weights.
keras_ga = pygad.kerasga.KerasGA(model=model, num_solutions=10)

# Genetic-algorithm settings, gathered in one place and unpacked into pygad.GA.
ga_settings = {
    "num_generations": 25,
    "num_parents_mating": 5,
    "initial_population": keras_ga.population_weights,
    "fitness_func": fitness_func,
    "parent_selection_type": "sss",
    "crossover_type": "single_point",
    "mutation_type": "random",
    "mutation_percent_genes": 10,
    "keep_parents": -1,
    "on_generation": callback_generation,
}
ga_instance = pygad.GA(**ga_settings)

# Evolve the population for the configured number of generations.
ga_instance.run()

# Newer PyGAD releases expose the fitness curve as plot_fitness(); fall back
# to the older plot_result() name when plot_fitness() is not available.
plot_fitness_curve = getattr(ga_instance, "plot_fitness", None) or ga_instance.plot_result
plot_fitness_curve(title="PyGAD & Keras - Iteration vs. Fitness", linewidth=4)

# best_solution() with no argument re-evaluates every solution, which replays
# one episode per solution and reports a fresh score rather than the one the
# GA used. Reuse the fitness of the last generation when the instance
# provides it; otherwise fall back to the plain call.
final_fitness = getattr(ga_instance, "last_generation_fitness", None)
if final_fitness is None:
    solution, solution_fitness, solution_idx = ga_instance.best_solution()
else:
    solution, solution_fitness, solution_idx = ga_instance.best_solution(pop_fitness=final_fitness)
print(f"Fitness value of the best solution = {solution_fitness}")
print(f"Index of the best solution : {solution_idx}")

# Load the winning weights into the model before persisting it.
model_weights_matrix = pygad.kerasga.model_weights_as_matrix(model=model, weights_vector=solution)
model.set_weights(weights=model_weights_matrix)
# NOTE(review): despite the name, model.save() stores the whole model, not
# only the weights. Newer Keras releases may require an explicit ".keras" or
# ".h5" extension here - confirm against the installed version before
# changing the path, since other scripts may load "cartpole_weights".
model.save("cartpole_weights")