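# Untitled — Julia snippet from pastecode.io (author: unknown, 2.3 kB).
# Demo: a graph autoencoder (GAE) built with Flux + GraphNeuralNetworks.jl that learns
# to reconstruct the adjacency matrix of a small 18-node graph.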
using Flux, GraphNeuralNetworks, Statistics
using LinearAlgebra: I
using SparseArrays: sparse
import Random
# Seed the global RNG so the random node features (and anything else drawn from the
# global RNG afterwards) are reproducible between runs.
Random.seed!(42)

# ---- demo data ---------------------------------------------------------------
# 18×18 symmetric adjacency matrix with ones on the diagonal, i.e. every node is
# also listed as connected to itself. Converted to a sparse matrix.
adj = sparse(
    [1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0
     1  1  1  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0
     0  1  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0
     0  0  1  1  1  0  0  0  0  0  0  0  0  0  0  0  0  0
     0  0  0  1  1  1  0  0  0  0  0  0  0  0  0  0  0  0
     0  0  0  0  1  1  1  0  0  0  1  0  0  0  0  0  0  0
     0  1  0  0  0  1  1  1  0  0  0  0  0  0  0  0  0  0
     0  0  0  0  0  0  1  1  1  1  0  0  0  0  0  0  0  0
     0  0  0  0  0  0  0  1  1  0  0  0  0  0  0  0  0  0
     0  0  0  0  0  0  0  1  0  1  1  0  0  0  0  0  0  0
     0  0  0  0  0  1  0  0  0  1  1  1  0  0  0  0  0  0
     0  0  0  0  0  0  0  0  0  0  1  1  1  0  0  0  0  0
     0  0  0  0  0  0  0  0  0  0  0  1  1  1  0  0  0  0
     0  0  0  0  0  0  0  0  0  0  0  0  1  1  1  0  1  0
     0  0  0  0  0  0  0  0  0  0  0  0  0  1  1  1  0  0
     0  0  0  0  0  0  0  0  0  0  0  0  0  1  1  1  0  0
     0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  1  1
     0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  1]
)

# Random Float32 node features: 33 features (rows) for each of the 18 nodes (columns).
nodefeat = randn(Float32, 33, 18)

# Wrap the graph; the features are attached as node data under the key `x`.
data = GNNGraph(adj; ndata = (; x = nodefeat))

"""
    GAE(enc, dec)

Graph autoencoder made of an encoder `enc` and a decoder `dec`.

`enc` is called as `enc(g, x)` on a graph and its node features and returns node
embeddings. `dec` is called on those embeddings and returns a reconstructed adjacency
(scores for node pairs).

The fields are type-parametrized so the struct is concretely typed and calls to
`m.enc` / `m.dec` can be specialized. The fields were previously untyped, i.e. `Any`.
"""
struct GAE{E,D}
    enc::E
    dec::D
end
# Register the fields with Flux so their parameters are collected and trained.
Flux.@functor GAE

"""
    (m::GAE)(g::GNNGraph)

Encode `g` using its node data `x` (`g.x`), then decode the embeddings.

Return a named tuple `(emb = embeddings, adj = decoder output)`.
"""
function (m::GAE)(g::GNNGraph)
    emb = m.enc(g, g.x)
    recon = m.dec(emb)
    return (emb = emb, adj = recon)
end

# Encoder: Dense (33 input features -> 64, relu), one graph-attention layer
# (64 -> 32, relu), then a final Dense (32 -> 32) producing the node embeddings.
enc = GNNChain(
    Dense(33 => 64, relu),
    GATConv(64 => 32, relu),
    Dense(32 => 32),
)

# Inner-product decoder: score σ(zᵢ ⋅ zⱼ) for every node pair, computed from the
# embedding matrix (one column per node) as σ.(zᵀz).
# σ is a scalar activation and must be broadcast (`σ.`); applying it to a whole
# matrix throws. The former trailing `relu` was also applied unbroadcast, and it
# would be the identity on σ's (0, 1) output anyway, so it is omitted.
dec = Chain(
    z -> σ.(z' * z),
)

model = GAE(enc, dec)

p = model(data)
# Compare the graph reconstructed by the untrained model (an edge wherever the decoded
# score exceeds 0.5) with the true adjacency matrix. `display` makes both appear when
# the file is run as a script; bare expressions are only echoed in the REPL.
display(sparse(p.adj .> 0.5))
display(adjacency_matrix(data))
"""
    loss(model, data, pos_g=nothing, neg_g=nothing)

Binary cross-entropy reconstruction loss for a graph autoencoder.

The loss pushes the decoded probability `model(data).adj[src, dst]` towards 1 on
positive edges and towards 0 on negative (absent) edges.

# Arguments
- `model`: callable; `model(data).adj` must be a matrix of edge probabilities.
- `data`: the graph.
- `pos_g`: E×2 integer matrix whose rows are `[src dst]` positive edges. Defaults to
  every edge of `data` (self-loops included when the graph has them).
- `neg_g`: same format, for negative edges. Defaults to the edges of
  `negative_sample(data)`.

Self-loops are already part of the positive edges, so the diagonal is trained towards 1
directly. No identity matrix is added; that used to push diagonal "probabilities" above 1.
"""
function loss(model, data, pos_g=nothing, neg_g=nothing)
    if isnothing(pos_g)
        pos_g = hcat(edge_index(data)...)
    end
    if isnothing(neg_g)
        neg_g = hcat(edge_index(negative_sample(data))...)
    end
    probs = model(data).adj
    EPS = eps(Float32)  # keeps log away from log(0)
    p_pos = probs[_edge_cartesian(pos_g)]
    p_neg = probs[_edge_cartesian(neg_g)]
    # Mean of per-edge log-likelihoods (BCE), not log of the mean probability.
    return -mean(log.(p_pos .+ EPS)) - mean(log.(1 .- p_neg .+ EPS))
end

# Convert an E×2 `[src dst]` edge matrix into (row, col) CartesianIndex values.
# Indexing a matrix with a plain integer matrix would be *linear* indexing and would
# read the wrong entries.
_edge_cartesian(e::AbstractMatrix{<:Integer}) = CartesianIndex.(e[:, 1], e[:, 2])
# ---- training ----------------------------------------------------------------
# Implicit-parameter Flux API: gradients are taken w.r.t. the arrays collected in `ps`.
ps = Flux.params(model)
opt = ADAM(1e-3)
# `train!` iterates over this one-element collection and splats the tuple,
# i.e. it calls `loss(model, data)` once per epoch.
batch = [(model, data)]
for epoch in 1:100
    Flux.train!(loss, ps, batch, opt)
    # Report the loss after the update. `loss` calls `negative_sample` again here,
    # since no negative edges are passed.
    println(loss(model, data))
end
# Compare the graph reconstructed by the trained model (an edge wherever the decoded
# score exceeds 0.5) with the true adjacency matrix. `display` makes both appear when
# the file is run as a script; bare expressions are only echoed in the REPL.
p = model(data)
display(sparse(p.adj .> 0.5))
display(adjacency_matrix(data))
# End of script.