# Untitled
# unknown
# julia
# 2 years ago
# 2.3 kB
# 12
# Indexable
using LinearAlgebra, Random, SparseArrays, Statistics
using Flux, GraphNeuralNetworks
Random.seed!(42)  # fix the global RNG so the random draws below are reproducible

# Demo graph: 18 nodes given as a symmetric 0/1 adjacency matrix. Every diagonal
# entry is 1, i.e. each node carries a self-loop.
adj = sparse([1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
              1 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
              0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
              0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0
              0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0
              0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 0 0 0
              0 1 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0
              0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0
              0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0
              0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0
              0 0 0 0 0 1 0 0 0 1 1 1 0 0 0 0 0 0
              0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0
              0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0
              0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 0
              0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
              0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0
              0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 1
              0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1])

# Random node features: one 33-dimensional Float32 column per node. The column
# count is taken from `adj` so features and graph cannot drift apart if the demo
# graph is edited.
nodefeat = randn(Float32, 33, size(adj, 1))
# Wrap the adjacency matrix as a GNNGraph, storing the features as node data `x`.
data = GNNGraph(adj; ndata = (x = nodefeat,))
"""
    GAE(enc, dec)

Graph auto-encoder. `enc` is called as `enc(g, x)` on a graph and its node
features to produce node embeddings; `dec` is called on those embeddings to
produce the reconstructed node-by-node score matrix.

The fields are type-parameterised so the concrete encoder/decoder types are
visible to the compiler; untyped fields would be stored as `Any`.
"""
struct GAE{E,D}
    enc::E
    dec::D
end
Flux.@functor GAE
"""
    (ae::GAE)(graph::GNNGraph)

Run the auto-encoder on `graph`, feeding it the node features `graph.x`.
Return a named tuple `(emb = …, adj = …)`: the encoder output, and the decoder's
reconstruction computed from that output.
"""
function (ae::GAE)(graph::GNNGraph)
    z = ae.enc(graph, graph.x)
    reconstruction = ae.dec(z)
    return (; emb = z, adj = reconstruction)
end
# Encoder: 33 input features -> 64 (dense) -> 32 (graph attention) -> 32-dim embedding.
enc = let nin = 33, nhidden = 64, nemb = 32
    GNNChain(
        Dense(nin => nhidden, relu),
        GATConv(nhidden => nemb, relu),
        Dense(nemb => nemb),
    )
end
# Inner-product decoder: for embeddings stored one column per node, entry (i, j)
# of the result is σ(zᵢ ⋅ zⱼ).
# σ is a scalar activation and must be broadcast (`σ.`); calling it on the whole
# matrix errors. The former trailing `relu` layer is dropped: it would hit the same
# error, and it is the identity on σ's non-negative output anyway.
dec = Chain(
    z -> σ.(z' * z),
)
model = GAE(enc, dec)
p = model(data)

# Compare the thresholded reconstruction with the true adjacency before training.
# `display` is needed: bare top-level expressions print nothing when this file is
# run as a script.
display(sparse(p.adj .> 0.5))
display(adjacency_matrix(data))
"""
    _edge_cartesian(edges)

Convert an `E×2` integer matrix whose rows are `(source, target)` node pairs into
a vector of `CartesianIndex`es, so that `A[_edge_cartesian(edges)]` reads the
entries `A[source, target]` of a node-by-node matrix `A`.
"""
_edge_cartesian(edges::AbstractMatrix{<:Integer}) =
    CartesianIndex.(view(edges, :, 1), view(edges, :, 2))

"""
    loss(model, data, pos_g=nothing, neg_g=nothing)

Link-reconstruction loss of the graph auto-encoder `model` on the graph `data`.

The score matrix is `model(data).adj + I`. The identity is added because nodes are
treated as self-connected; every node of the demo graph has a self-loop. With
`p_pos`/`p_neg` the scores at the positive/negative edge positions and
`ε = eps(Float32)`, the loss is

    -log(mean(p_pos) + ε) - log(mean(1 .- p_neg) + ε)

# Arguments
- `pos_g`: `E×2` integer matrix of `(source, target)` rows for the positive edges;
  defaults to every edge of `data`.
- `neg_g`: same layout for the negative edges; defaults to a fresh
  `negative_sample(data)` on every call.
"""
function loss(model, data, pos_g=nothing, neg_g=nothing)
    if isnothing(pos_g)
        pos_g = hcat(edge_index(data)...)
    end
    if isnothing(neg_g)
        neg_g = hcat(edge_index(negative_sample(data))...)
    end
    scores = model(data).adj + I
    # Index with (row, col) pairs. Indexing with the raw E×2 matrix would be
    # *linear* indexing and pick unrelated entries of `scores`.
    p_pos = scores[_edge_cartesian(pos_g)]
    p_neg = scores[_edge_cartesian(neg_g)]
    EPS = eps(Float32)
    return -log(mean(p_pos) + EPS) - log(mean(1 .- p_neg) + EPS)
end
# Training: Adam on the link-reconstruction loss, one full-graph step per epoch.
# After each step the loss is logged; it is re-evaluated with a freshly sampled
# set of negative edges, so the logged value is somewhat noisy.
ps = Flux.params(model)
opt = Adam(0.001)  # `ADAM` is the deprecated spelling of the same optimiser
for epoch in 1:100
    Flux.train!(loss, ps, [(model, data)], opt)
    @info "epoch $epoch" train_loss = loss(model, data)
end
# Compare the thresholded reconstruction with the true adjacency after training.
# `display` makes both matrices show up when this file is run as a script.
p = model(data)
display(sparse(p.adj .> 0.5))
display(adjacency_matrix(data))
# Editor is loading...
# Leave a Comment