Untitled
unknown
julia
a year ago
2.3 kB
11
Indexable
using Flux, GraphNeuralNetworks, Statistics
using LinearAlgebra, Random, SparseArrays

Random.seed!(42)

# Demo data: an 18-node graph given as a symmetric 0/1 adjacency matrix with ones on
# the diagonal (every node is connected to itself), plus 33 random features per node.
adj = sparse([
    1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    1 1 1 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0
    0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0
    0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0
    0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0
    0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 0 0 0
    0 1 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0
    0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0
    0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0
    0 0 0 0 0 0 0 1 0 1 1 0 0 0 0 0 0 0
    0 0 0 0 0 1 0 0 0 1 1 1 0 0 0 0 0 0
    0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0
    0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0
    0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 1 0
    0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0
    0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 0 0
    0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 1 1
    0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1
])
nodefeat = randn(Float32, 33, 18)  # features × nodes
data = GNNGraph(adj, ndata=(; x=nodefeat))

"""
    GAE(enc, dec)

Graph auto-encoder. `enc` is called as `enc(g, g.x)` to produce node embeddings and
`dec` is called on those embeddings to produce a reconstructed adjacency matrix.
"""
struct GAE{E,D}
    enc::E
    dec::D
end

Flux.@functor GAE

# Run the auto-encoder on `g`; returns the embeddings and the decoded adjacency matrix
# as the named tuple `(emb = ..., adj = ...)`.
function (m::GAE)(g::GNNGraph)
    emb = m.enc(g, g.x)
    adj = m.dec(emb)
    return (emb = emb, adj = adj)
end

enc = GNNChain(
    Dense(33 => 64, relu),
    GATConv(64 => 32, relu),
    Dense(32 => 32),
)

# Inner-product decoder: entry (i, j) is σ(zᵢ ⋅ zⱼ) for embedding columns zᵢ, zⱼ.
# The sigmoid is broadcast element-wise because NNlib activations are scalar functions.
# No `relu` follows it: sigmoid outputs are already positive, so it would be a no-op.
dec = Chain(
    x -> σ.(x' * x),
)

model = GAE(enc, dec)
p = model(data)

# compare the adj before training
sparse(p.adj .> 0.5)
adjacency_matrix(data)

"""
    edgepairs(g::GNNGraph)

Return the edges of `g` as an `E × 2` integer matrix, one `(source, target)` pair per
row.
"""
edgepairs(g::GNNGraph) = hcat(edge_index(g)...)

"""
    loss(model, data, pos_g=nothing, neg_g=nothing)

Link-prediction loss: `-log(mean(p⁺) + ϵ) - log(mean(1 - p⁻) + ϵ)`, where `p⁺` / `p⁻`
are the decoded probabilities at the positive / negative edges.

# Arguments
- `model`: callable returning a named tuple with an `adj` field when applied to `data`.
- `data`: the `GNNGraph` to reconstruct.
- `pos_g`: `E × 2` matrix of `(source, target)` pairs of existing edges; defaults to
  the edges of `data`.
- `neg_g`: `E × 2` matrix of `(source, target)` pairs of non-edges; defaults to a fresh
  `negative_sample(data)` on every call.
"""
function loss(model, data, pos_g=nothing, neg_g=nothing)
    pos_edges = isnothing(pos_g) ? edgepairs(data) : pos_g
    neg_edges = isnothing(neg_g) ? edgepairs(negative_sample(data)) : neg_g

    # add eye, because nodes are self-connected.
    # NOTE(review): `adj` above already has ones on its diagonal, so self-loops are also
    # among the positive edges and their entries exceed 1 here; confirm this is intended.
    probs = model(data).adj + I

    # Each row is one (i, j) pair, so index with CartesianIndex. Indexing with the raw
    # E × 2 matrix would treat every entry as a linear index into `probs`.
    pos_p = probs[CartesianIndex.(pos_edges[:, 1], pos_edges[:, 2])]
    neg_p = probs[CartesianIndex.(neg_edges[:, 1], neg_edges[:, 2])]

    EPS = eps(Float32)
    return -log(mean(pos_p) + EPS) - log(mean(1 .- neg_p) + EPS)
end

# training
ps = Flux.params(model)
opt = ADAM(0.001)
pos_edges = edgepairs(data)  # the positive edges never change, so extract them once

for i in 1:100
    # Sample negatives once per step so the printed loss is measured on the same
    # negative edges the step was taken on.
    neg_edges = edgepairs(negative_sample(data))
    Flux.train!(loss, ps, [(model, data, pos_edges, neg_edges)], opt)
    println(loss(model, data, pos_edges, neg_edges))
end

# compare the graphs after training.
p = model(data)
sparse(p.adj .> 0.5)
adjacency_matrix(data)
Editor is loading...
Leave a Comment