# Paste metadata: Untitled | unknown | julia | 3 years ago | 3.1 kB | 43 | No Index
# A toy GAN on MNIST for checking how training time changes with the BLAS thread count.
using MKL, LinearAlgebra
using Flux, Statistics, Random, MLDatasets, CUDA, Dates
using Flux.Data: DataLoader
using Flux.Losses: logitbinarycrossentropy
using Base: @kwdef
"""
    Args_Gan(; latent_dim = 100, η_gen = 1e-3, η_dic = 1e-3,
               epoch = 1, batch_size = 128, seed = 1234)

Hyper-parameters for the MNIST GAN benchmark; build with keywords to override defaults.

- `latent_dim`: length of the noise vector fed to the generator.
- `η_gen`, `η_dic`: ADAM learning rates of the generator and the discriminator.
- `epoch`: number of passes over the training set.
- `batch_size`: images (and noise vectors) per mini-batch.
- `seed`: RNG seed; a value `<= 0` leaves the global RNG unseeded.
"""
@kwdef struct Args_Gan
    # Concrete field types: untyped fields would be `Any`, making every `args.x` read
    # type-unstable. Keyword values are converted, so `epoch = 2` etc. still work.
    latent_dim::Int = 100
    η_gen::Float64 = 1e-3
    η_dic::Float64 = 1e-3
    epoch::Int = 1
    batch_size::Int = 128
    seed::Int = 1234
end
"""
    Discriminator()

Build the image discriminator.

- Two 4×4, stride-2, pad-1 convolutions (1 → 6 → 12 channels), each using a leaky ReLU
  with slope 0.2 and each followed by 25 % dropout. For 28×28 inputs this gives 7×7 maps.
- The 7·7·12 features are flattened and a dense layer emits one unnormalised logit per image.
"""
function Discriminator()
    lrelu(v) = leakyrelu(v, 0.2f0)
    to_features(v) = reshape(v, 7 * 7 * 12, :)
    return Chain(
        Conv((4, 4), 1 => 6, lrelu; stride = 2, pad = 1),
        Dropout(0.25),
        Conv((4, 4), 6 => 12, lrelu; stride = 2, pad = 1),
        Dropout(0.25),
        to_features,
        Dense(7 * 7 * 12, 1),
    )
end
"""
    Generator(latent_dim::Int)

Build the generator, which maps `latent_dim`-long noise columns to 28×28×1 images in (-1, 1).

1. A dense layer projects the noise to 7·7·25 values (batch-norm + ReLU).
2. The result is reshaped into 7×7×25 feature maps.
3. Three transposed convolutions upsample the maps spatially 7 → 7 → 14 → 28. Channels go
   25 → 12 → 4 → 1, with batch-norm + ReLU between the convolutions.
4. A final element-wise `tanh` squashes the output.
"""
function Generator(latent_dim::Int)
    seed_shape = (7, 7, 25)
    seed_len = prod(seed_shape)
    to_maps(v) = reshape(v, seed_shape..., :)
    squash(v) = tanh.(v)
    return Chain(
        Dense(latent_dim, seed_len),
        BatchNorm(seed_len, relu),
        to_maps,
        ConvTranspose((5, 5), 25 => 12; stride = 1, pad = 2),
        BatchNorm(12, relu),
        ConvTranspose((4, 4), 12 => 4; stride = 2, pad = 1),
        BatchNorm(4, relu),
        ConvTranspose((4, 4), 4 => 1; stride = 2, pad = 1),
        squash,
    )
end
"""
    loss_disc(trueInput, fakeInput)

Compute the discriminator loss from raw logits.

The loss is the logit binary cross-entropy of the real-image logits `trueInput` against
label 1, plus that of the generated-image logits `fakeInput` against label 0.
"""
loss_disc(trueInput, fakeInput) =
    logitbinarycrossentropy(trueInput, 1) + logitbinarycrossentropy(fakeInput, 0)
"""
    loss_gen(fakeResult)

Compute the generator loss from raw logits.

The loss is the logit binary cross-entropy of the discriminator's logits on generated images
(`fakeResult`) against label 1. The generator therefore scores best when those samples are
classified as real.
"""
function loss_gen(fakeResult)
    return logitbinarycrossentropy(fakeResult, 1)
end
"""
    get_data(args)

Load the MNIST training images and return a shuffled `DataLoader` of
`(images, noise)` mini-batches of size `args.batch_size`.

- Pixels are rescaled with `2x - 1`, mapping [0, 1] to [-1, 1] (assumes the loader returns
  values in [0, 1]). Images are then reshaped to 28×28×1×N.
- Each image is paired with a `Float32` Gaussian noise vector of length `args.latent_dim`.
  The noise is drawn once here, not per batch.
"""
function get_data(args)
    images, _ = MLDatasets.MNIST.traindata(Float32)  # labels are not needed
    images .= 2 .* images .- 1
    latent = randn(Float32, args.latent_dim, size(images, 3))
    images = reshape(images, 28, 28, 1, :)
    # NOTE(review): because `latent` is fixed, every epoch reuses the same noise per image;
    # GANs usually resample noise each step — confirm this is intended for the benchmark.
    return DataLoader((images, latent); batchsize = args.batch_size, shuffle = true)
end
"""
    train(; kws...)

Train the MNIST GAN with `args = Args_Gan(; kws...)`. Any `Args_Gan` field can be
overridden by keyword, e.g. `train(epoch = 2)`.

For every mini-batch of `(real images, noise)`:
1. One ADAM step on the generator parameters, minimising `loss_gen`.
2. One ADAM step on the discriminator parameters, minimising `loss_disc`.

The generator and discriminator losses of the last batch of each epoch are logged with
`@info`.

Returns the trained `(generator, discriminator)` pair.
"""
function train(; kws...)
    args = Args_Gan(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    @info "Training start"

    # load MNIST images, each paired with a latent noise vector
    trainData = get_data(args)

    # initialize generator and discriminator
    # NOTE(review): CUDA is imported but nothing is moved to the GPU; training runs on CPU.
    generator = Generator(args.latent_dim)
    discriminator = Discriminator()
    ps_gen = Flux.params(generator)
    ps_disc = Flux.params(discriminator)

    # Optimizers
    opt_disc = ADAM(args.η_dic)
    opt_gen = ADAM(args.η_gen)

    for epoch in 1:args.epoch
        @info "epoch $(epoch)"
        # Float32 sentinels keep these locals a single type (the losses are Float32);
        # NaN is what gets reported if the loader yields no batches.
        loss_g = loss_d = NaN32
        for (true_in, rand_in) in trainData
            # Generator step: differentiate and update only the generator's parameters.
            loss_g, back_gen = Flux.pullback(ps_gen) do
                loss_gen(discriminator(generator(rand_in)))
            end
            Flux.update!(opt_gen, ps_gen, back_gen(1f0))

            # Discriminator step: the generator is re-run with its just-updated weights,
            # but only the discriminator's parameters are differentiated and updated.
            loss_d, back_disc = Flux.pullback(ps_disc) do
                fakeImg = generator(rand_in)
                loss_disc(discriminator(true_in), discriminator(fakeImg))
            end
            Flux.update!(opt_disc, ps_disc, back_disc(1f0))
        end
        @info "epoch $(epoch) finished" gen_loss = loss_g disc_loss = loss_d
        println("\n\n")
    end
    return generator, discriminator
end
# Warm-up run: the first call pays Julia's JIT compilation cost, so it is kept out of the
# timed runs below.
train(epoch=1)

# Time one training epoch for each BLAS thread count, from 8 down to 1.
for nblas in 8:-1:1
    BLAS.set_num_threads(nblas)
    @info "$(now())||| BLAS threads is $(BLAS.get_num_threads())"
    @info "$(now())||| Threads.nthreads is $(Threads.nthreads())"
    @time train(epoch=1)
end