# A toy convolutional GAN on MNIST, used to check how training time varies with the BLAS thread count.
using MKL, LinearAlgebra
using Flux, Statistics, Random, MLDatasets, CUDA, Dates
using Flux.Data: DataLoader
using Flux.Losses: logitbinarycrossentropy
using Base: @kwdef

@kwdef struct Args_Gan
    latent_dim = 100    # size of the generator's noise input
    η_gen = 1e-3        # generator learning rate
    η_disc = 1e-3       # discriminator learning rate
    epoch = 1           # number of training epochs
    batch_size = 128
    seed = 1234         # RNG seed (used only if > 0)
end

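# Discriminator: two strided 4x4 convolutions with leaky ReLU (28x28 -> 14x14 -> 7x7),
# dropout for regularisation, then a dense layer producing one real/fake logit per image.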
function Discriminator()
    Chain(
        Conv((4,4), 1 => 6, x -> leakyrelu(x,0.2f0) ; stride = 2, pad = 1),
        Dropout(0.25),
        Conv((4,4), 6 => 12, x -> leakyrelu(x,0.2f0) ; stride = 2, pad = 1),
        Dropout(0.25),
        x -> reshape(x, 7*7*12,:),
        Dense(7*7*12,1)
    )
end

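# Generator: projects the latent vector to a 7x7x25 feature map, then upsamples with
# transposed convolutions (7x7 -> 7x7 -> 14x14 -> 28x28) and squashes the output to [-1, 1] with tanh.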
function Generator(latent_dim::Int)
    Chain(
        Dense(latent_dim, 7*7*25),
        BatchNorm(7*7*25,relu),
        x -> reshape(x, 7,7,25,:),
        ConvTranspose((5,5), 25 => 12; stride = 1, pad = 2),
        BatchNorm(12,relu),
        ConvTranspose((4, 4), 12 => 4; stride = 2, pad = 1),
        BatchNorm(4, relu),
        ConvTranspose((4, 4), 4 => 1; stride = 2, pad = 1),
        x -> tanh.(x)
    )
end

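# GAN losses (binary cross-entropy on logits): the discriminator is pushed to score real
# images as 1 and generated images as 0; the generator is pushed to make the discriminator
# score its fakes as 1.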
function loss_disc(realOutput, fakeOutput)
    loss_real = logitbinarycrossentropy(realOutput, 1)
    loss_fake = logitbinarycrossentropy(fakeOutput, 0)

    loss_real + loss_fake
end

loss_gen(fakeResult) = logitbinarycrossentropy( fakeResult, 1 )

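# Load MNIST, rescale the pixels to [-1, 1] to match the generator's tanh output, and pair
# each training image with a latent noise vector so the DataLoader yields (images, noise) batches.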
function get_data(args)
    xtrain, ytrain =  MLDatasets.MNIST.traindata(Float32)
    # Normalize to [-1, 1]
    @. xtrain =  2 * xtrain - 1
    noise  = randn(Float32,args.latent_dim,size(xtrain,3))
    xtrain = reshape(xtrain, 28,28,1,:)

    DataLoader((xtrain,noise), batchsize=args.batch_size, shuffle=true)
end

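# Train the GAN for `args.epoch` epochs: for every batch the generator is updated first
# (to fool the discriminator), then the discriminator is updated on real and generated images.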
function train(;kws...)
    args = Args_Gan(; kws...)
    args.seed > 0 && Random.seed!(args.seed)
    @info "Training start"
 
    # load MNIST images
    trainData = get_data(args)

    # initialize the generator and discriminator
    generator = Generator(args.latent_dim)
    discriminator = Discriminator()

    ps_gen = Flux.params(generator)
    ps_disc = Flux.params(discriminator)

    # Optimizers
    opt_disc = ADAM(args.η_disc)
    opt_gen = ADAM(args.η_gen)

    # Training
    for epoch in 1:args.epoch
        @info "epoch $(epoch)"
        loss1, loss2 = 0.0, 0.0
        for (true_in, rand_in) in trainData
            loss1, back = Flux.pullback(ps_gen) do
                fakeImg =  generator(rand_in)
                fakeOutput = discriminator(fakeImg)
                loss_gen(fakeOutput)
            end
            grad = back(1f0)
            Flux.update!(opt_gen, ps_gen, grad)

            loss2, back = Flux.pullback(ps_disc) do
                fakeImg =  generator(rand_in)
                loss_disc(discriminator(true_in), discriminator(fakeImg))
            end
            grad = back(1f0)
            Flux.update!(opt_disc, ps_disc, grad)
        end

        println( "\n\n" ) 
    end
end 

train(epoch=1)   # first run (warm-up, so the @time measurements below are not dominated by compilation)

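# Time one training epoch for each BLAS thread count from 8 down to 1. The model stays on
# the CPU here (nothing is moved to the GPU), so the BLAS thread count can influence the timings.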
for i in reverse(1:1:8)
    BLAS.set_num_threads(i)
    @info "$(now())||| BLAS threads is $(BLAS.get_num_threads())"
    @info "$(now())||| Threads.nthreads is $(Threads.nthreads())"
    @time train(epoch=1)
end