cartpole v-1
cartpole v-1unknown
python
4 years ago
2.6 kB
10
Indexable
# Evolve the weights of a small Keras network that plays CartPole-v1 using
# PyGAD's genetic algorithm. Each GA solution is a flat weight vector; its
# fitness is the total reward collected in one episode when the network picks
# the action with the largest output at every step. The best weights found are
# loaded back into the model and saved to "cartpole_weights".
import gym
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import pygad.kerasga
import pygad

# Upper bound on the number of steps played per fitness evaluation.
MAX_EPISODE_STEPS = 1000


def _reset_env(environment):
    """Reset *environment* and return only its initial observation.

    gym >= 0.26 returns ``(observation, info)`` from ``reset()``; older
    releases return the observation alone. Both forms are accepted.
    """
    result = environment.reset()
    if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
        return result[0]
    return result


def _step_env(environment, action):
    """Apply *action* to *environment* and return ``(observation, reward, done)``.

    gym >= 0.26 returns ``(obs, reward, terminated, truncated, info)`` from
    ``step()``; older releases return ``(obs, reward, done, info)``. With the
    newer API the episode is over when either flag is set.
    """
    result = environment.step(action)
    if len(result) == 5:
        observation, reward, terminated, truncated, _info = result
        return observation, reward, terminated or truncated
    observation, reward, done, _info = result
    return observation, reward, done


def _greedy_action(network, state):
    """Return the index of the largest output of *network* for a one-row *state* batch."""
    # Calling the model directly is far cheaper than model.predict() for a
    # single-sample batch evaluated once per environment step.
    q_values = np.asarray(network(state, training=False))
    return int(np.argmax(q_values[0]))


def fitness_func(solution, sol_idx):
    """Return the total reward earned by the network encoded in *solution*.

    Loads *solution* into the shared ``model``, plays one episode of ``env``
    choosing the arg-max network output at each step, and sums the rewards.
    The episode ends when the environment reports it is done or after
    MAX_EPISODE_STEPS steps.

    Args:
        solution: flat weight vector produced by PyGAD.
        sol_idx: index of *solution* in the population (not used here).
            NOTE(review): newer PyGAD releases call the fitness function as
            ``fitness_func(ga_instance, solution, sol_idx)`` — confirm against
            the installed version.

    Returns:
        The summed episode reward.
    """
    # ``model``, ``env`` and ``observation_space_size`` are module-level names
    # that are only read here, so no ``global`` declaration is needed.
    model_weights_matrix = pygad.kerasga.model_weights_as_matrix(model=model,
                                                                 weights_vector=solution)
    model.set_weights(weights=model_weights_matrix)

    # Play one game with the candidate weights.
    observation = _reset_env(env)
    sum_reward = 0
    done = False
    steps = 0
    while not done and steps < MAX_EPISODE_STEPS:
        state = np.reshape(observation, [1, observation_space_size])
        action = _greedy_action(model, state)
        observation, reward, done = _step_env(env, action)
        sum_reward += reward
        steps += 1
    return sum_reward


def callback_generation(ga_instance):
    """Print the completed-generation counter and the best fitness so far.

    NOTE(review): per the PyGAD docs, best_solution() without arguments
    recomputes the whole population's fitness (one full episode per
    solution). Newer releases accept
    ``pop_fitness=ga_instance.last_generation_fitness`` to reuse the values
    already computed — confirm against the installed version before switching.
    """
    print("Generation = {generation}".format(generation=ga_instance.generations_completed))
    print("Fitness = {fitness}".format(fitness=ga_instance.best_solution()[1]))


env = gym.make("CartPole-v1")
observation_space_size = env.observation_space.shape[0]
action_space_size = env.action_space.n

# Small fully connected network: one linear output per discrete action.
model = Sequential()
model.add(Dense(16, input_shape=(observation_space_size,), activation='relu'))
model.add(Dense(16, activation='relu'))
model.add(Dense(action_space_size, activation='linear'))
model.summary()

# Initial GA population: 10 flat weight vectors derived from ``model``.
keras_ga = pygad.kerasga.KerasGA(model=model, num_solutions=10)

ga_instance = pygad.GA(num_generations=25,
                       num_parents_mating=5,
                       initial_population=keras_ga.population_weights,
                       fitness_func=fitness_func,
                       parent_selection_type="sss",
                       crossover_type="single_point",
                       mutation_type="random",
                       mutation_percent_genes=10,
                       keep_parents=-1,
                       on_generation=callback_generation)
ga_instance.run()

ga_instance.plot_result(title="PyGAD & Keras - Iteration vs. Fitness", linewidth=4)

solution, solution_fitness, solution_idx = ga_instance.best_solution()
print("Fitness value of the best solution = {solution_fitness}".format(solution_fitness=solution_fitness))
print("Index of the best solution : {solution_idx}".format(solution_idx=solution_idx))

# Load the best weights into the model and persist it.
model_weights_matrix = pygad.kerasga.model_weights_as_matrix(model=model, weights_vector=solution)
model.set_weights(weights=model_weights_matrix)
model.save("cartpole_weights")
Editor is loading...