Untitled
unknown
julia
3 years ago
3.1 kB
32
No Index
# A toy GAN model used to check how training time changes with the BLAS thread count.
using MKL, LinearAlgebra
using Flux, Statistics, Random, MLDatasets, CUDA, Dates
using Flux.Data: DataLoader
using Flux.Losses: logitbinarycrossentropy
using Base: @kwdef

"""
    Args_Gan(; latent_dim, η_gen, η_dic, epoch, batch_size, seed)

Hyper-parameters for the toy GAN.

# Keywords
- `latent_dim`: length of the noise vector fed to the generator.
- `η_gen`: learning rate of the generator optimiser.
- `η_dic`: learning rate of the discriminator optimiser.
- `epoch`: number of passes over the training set.
- `batch_size`: mini-batch size used by the `DataLoader`.
- `seed`: RNG seed; the global RNG is only seeded when `seed > 0`.
"""
@kwdef struct Args_Gan
    latent_dim::Int = 100
    η_gen::Float64 = 1e-3
    η_dic::Float64 = 1e-3
    epoch::Int = 1
    batch_size::Int = 128
    seed::Int = 1234
end

# Leaky ReLU with slope 0.2, shared by both discriminator convolutions.
leakyrelu02(x) = leakyrelu(x, 0.2f0)

"""
    Discriminator()

Build the discriminator: two strided convolutions (28×28 → 14×14 → 7×7) with
dropout, followed by a dense layer that emits one logit per image.
"""
function Discriminator()
    return Chain(
        Conv((4, 4), 1 => 6, leakyrelu02; stride=2, pad=1),
        Dropout(0.25),
        Conv((4, 4), 6 => 12, leakyrelu02; stride=2, pad=1),
        Dropout(0.25),
        x -> reshape(x, 7 * 7 * 12, :),
        Dense(7 * 7 * 12, 1),
    )
end

"""
    Generator(latent_dim::Int)

Build the generator: a dense projection of the `latent_dim`-long noise vector to
a 7×7×25 feature map, then transposed convolutions up to 28×28×1, squashed into
`[-1, 1]` with `tanh`.
"""
function Generator(latent_dim::Int)
    return Chain(
        Dense(latent_dim, 7 * 7 * 25),
        BatchNorm(7 * 7 * 25, relu),
        x -> reshape(x, 7, 7, 25, :),
        ConvTranspose((5, 5), 25 => 12; stride=1, pad=2),
        BatchNorm(12, relu),
        ConvTranspose((4, 4), 12 => 4; stride=2, pad=1),
        BatchNorm(4, relu),
        ConvTranspose((4, 4), 4 => 1; stride=2, pad=1),
        x -> tanh.(x),
    )
end

"""
    loss_disc(trueInput, fakeInput)

Discriminator loss: logit binary cross-entropy of `trueInput` against label 1
plus that of `fakeInput` against label 0. Both arguments are discriminator
logits.
"""
function loss_disc(trueInput, fakeInput)
    loss_t = logitbinarycrossentropy(trueInput, 1)
    loss_f = logitbinarycrossentropy(fakeInput, 0)
    return loss_t + loss_f
end

"""
    loss_gen(fakeResult)

Generator loss: logit binary cross-entropy of the discriminator logits for
generated images against label 1.
"""
loss_gen(fakeResult) = logitbinarycrossentropy(fakeResult, 1)

"""
    get_data(args)

Load the MNIST training images, rescale them to `[-1, 1]`, pair every image with
a fixed Gaussian noise vector of length `args.latent_dim`, and return a shuffled
`DataLoader` yielding `(images, noise)` batches of size `args.batch_size`.
"""
function get_data(args)
    xtrain, ytrain = MLDatasets.MNIST.traindata(Float32)

    # Normalize to [-1, 1]
    @. xtrain = 2 * xtrain - 1

    # One noise column per image; the third dimension indexes the images here.
    noise = randn(Float32, args.latent_dim, size(xtrain, 3))

    # Add the singleton channel dimension expected by `Conv`.
    xtrain = reshape(xtrain, 28, 28, 1, :)
    return DataLoader((xtrain, noise); batchsize=args.batch_size, shuffle=true)
end

"""
    train(; kws...)

Train the toy GAN on MNIST. Keyword arguments are forwarded to [`Args_Gan`](@ref).

Every mini-batch performs one generator update followed by one discriminator
update. The losses of the last mini-batch are logged at the end of each epoch.
"""
function train(; kws...)
    args = Args_Gan(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

    @info "Training start"

    # load MNIST images
    trainData = get_data(args)

    # initialize generator and discriminator
    generator = Generator(args.latent_dim)
    discriminator = Discriminator()

    ps_gen = Flux.params(generator)
    ps_disc = Flux.params(discriminator)

    # Optimizers
    opt_disc = ADAM(args.η_dic)
    opt_gen = ADAM(args.η_gen)

    # Training
    for epoch in 1:args.epoch
        @info "epoch $(epoch)"
        # Float32 to match the type of the losses computed below.
        loss_g, loss_d = 0.0f0, 0.0f0

        for (true_in, rand_in) in trainData
            # Generator step: differentiate only with respect to generator parameters.
            loss_g, back_g = Flux.pullback(ps_gen) do
                fakeImg = generator(rand_in)
                fakeOutput = discriminator(fakeImg)
                loss_gen(fakeOutput)
            end
            Flux.update!(opt_gen, ps_gen, back_g(1.0f0))

            # Discriminator step: differentiate only with respect to discriminator parameters.
            loss_d, back_d = Flux.pullback(ps_disc) do
                fakeImg = generator(rand_in)
                loss_disc(discriminator(true_in), discriminator(fakeImg))
            end
            Flux.update!(opt_disc, ps_disc, back_d(1.0f0))
        end

        @info "epoch $(epoch) finished" generator_loss = loss_g discriminator_loss = loss_d
        println("\n\n")
    end
end

# First run: triggers compilation so it does not distort the timings below.
train(epoch=1)

# Time one epoch for each BLAS thread count from 8 down to 1.
for nblas in 8:-1:1
    BLAS.set_num_threads(nblas)
    @info "$(now())||| BLAS threads is $(BLAS.get_num_threads())"
    @info "$(now())||| Threads.nthreads is $(Threads.nthreads())"
    @time train(epoch=1)
end
Editor is loading...