# Paste-site metadata captured with this file (not part of the program):
# Untitled | unknown | plain_text | 2 years ago | 24 kB | 8 | Indexable
from typing import Optional, List, Tuple, Dict
from itertools import combinations
import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from typing import Any
from scipy.spatial import ConvexHull
import numpy as np
from torch_geometric.transforms import RadiusGraph
from torch_geometric.utils import remove_self_loops
class CellularComplex:
    """Cellular complex lifted from a graph.

    0-cells are the graph nodes, 1-cells are the directed edges of
    ``graph.to_directed()`` (each undirected edge appears once per direction)
    and 2-cells are simple cycles with more than two nodes, keeping a single
    cycle per node set.

    Attributes:
        adj: dict with keys ``"adj_00"``, ``"adj_01"``, ``"adj_11"``,
            ``"adj_12"``, ``"adj_22"`` pointing at the same tensors as the
            attributes below.
        nodes_00: ``[2, E]`` long tensor, sender/receiver node of every edge.
        edges_11: ``[2, K]`` long tensor, ordered pairs of distinct edge ids
            lying on a common cycle.
        coboundary_01: ``[2, 2E]`` long tensor of (node id, edge id) pairs.
        coboundary_12: ``[2, K]`` long tensor of (edge id, cycle id) pairs.
        cell_22: ``[2, K]`` long tensor of pairs of distinct cycles that share
            at least one directed edge (``[2, 0]`` when there is none).
        edge_lookup: edge id -> ``[u, v]``.
        cycle_lookup: cycle id -> list of node ids.
        edge2id: ``(u, v)`` -> edge id.
        cycle_edges_lookup: cycle id -> edge ids following the cycle order.
        cycle_edges_bidirected_lookup: cycle id -> edge ids of the cycle in
            both directions.
        cycle_present: 1 if the input graph already had a cycle, else 0
            (a 0-d tensor when built by ``from_nx_graph``).
        cycle_total: number of 2-cells.
    """

    def __init__(self,
                 adj: dict,
                 nodes_00: torch.Tensor,
                 edges_11: torch.Tensor,
                 coboundary_01: torch.Tensor,
                 coboundary_12: torch.Tensor,
                 cell_22: torch.Tensor,
                 edge_lookup: dict,
                 cycle_lookup: dict,
                 edge2id: dict,
                 cycle_edges_lookup: dict,
                 cycle_edges_bidirected_lookup: dict,
                 cycle_present: int,
                 cycle_total: int):
        """Store the precomputed adjacencies and lookup tables as-is."""
        self.adj = adj
        self.coboundary_01 = coboundary_01
        self.coboundary_12 = coboundary_12
        self.nodes_00 = nodes_00
        self.edges_11 = edges_11
        self.cell_22 = cell_22
        self.edge_lookup = edge_lookup
        self.cycle_lookup = cycle_lookup
        self.edge2id = edge2id
        self.cycle_edges_lookup = cycle_edges_lookup
        self.cycle_edges_bidirected_lookup = cycle_edges_bidirected_lookup
        self.cycle_present = cycle_present
        self.cycle_total = cycle_total

    @staticmethod
    def _find_cycles(digraph) -> List[List[int]]:
        """Simple cycles of ``digraph`` with more than two nodes.

        The length filter drops the ``u -> v -> u`` 2-cycles that
        ``to_directed`` creates for every undirected edge.
        """
        return [cycle for cycle in nx.simple_cycles(digraph) if len(cycle) > 2]

    @staticmethod
    def _dedupe_cycles(cycles: List[List[int]]) -> List[List[int]]:
        """Keep one cycle per node set (drops the reverse orientation).

        Cycles with the same node set collapse onto one entry: the position of
        the first occurrence is kept, holding the last cycle seen.
        """
        return list({tuple(sorted(cycle)): cycle for cycle in cycles}.values())

    @classmethod
    def _connect_until_cycle(cls, digraph, data: Data) -> Tuple[object, List[List[int]]]:
        """Add edges between the closest non-adjacent nodes until a cycle forms.

        ``data.edge_index`` is extended in place with both directions of every
        added edge, and the directed graph is rebuilt from it.

        Returns:
            ``(digraph, cycles)``. ``cycles`` is empty when no candidate pair is
            left (e.g. fewer than three nodes), instead of looping forever.
        """
        dist_matrix = torch.cdist(data.pos, data.pos)
        dist_matrix.fill_diagonal_(float('inf'))
        # Already-connected pairs are never candidates.
        for edge in data.edge_index.t():
            dist_matrix[edge[0], edge[1]] = float('inf')
            dist_matrix[edge[1], edge[0]] = float('inf')
        while True:
            if dist_matrix.numel() == 0:
                return digraph, []
            min_val = torch.min(dist_matrix)
            if not torch.isfinite(min_val):
                # Every pair is adjacent (or positions are NaN): give up.
                return digraph, []
            min_index = torch.where(dist_matrix == min_val)
            node1, node2 = min_index[0][0], min_index[1][0]
            # Never pick this pair again.
            dist_matrix[node1, node2] = float('inf')
            dist_matrix[node2, node1] = float('inf')
            new_edges = torch.tensor([[node1, node2], [node2, node1]], dtype=torch.long)
            data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)
            graph_new = nx.Graph()
            graph_new.add_nodes_from(range(data.x.size(0)))
            graph_new.add_edges_from(data.edge_index.T.tolist())
            digraph = graph_new.to_directed()
            cycles = cls._find_cycles(digraph)
            if cycles:
                return digraph, cls._dedupe_cycles(cycles)

    @classmethod
    def from_nx_graph(cls, graph: nx.Graph, data: Data):
        """Build all 0->0, 0->1, 1->1, 1->2 and 2->2 adjacencies of ``graph``.

        Also builds the edge/cycle lookup dictionaries needed later for the
        invariants; only one orientation is kept per cycle.

        If ``graph`` has no cycle, the closest non-adjacent node pairs (by
        ``data.pos``) are connected until one appears. This mutates
        ``data.edge_index``. If no cycle can be formed, the complex has no
        2-cells.

        Args:
            graph: undirected input graph.
            data: PyG data; ``pos``, ``x`` and ``edge_index`` are read only when
                the fallback above is needed.

        Returns:
            A ``CellularComplex``.
        """
        digraph = graph.to_directed()
        cycle_present = 0
        cycles = cls._find_cycles(digraph)
        if cycles:
            cycle_present = 1
            cycles = cls._dedupe_cycles(cycles)
        else:
            digraph, cycles = cls._connect_until_cycle(digraph, data)
        cycle_total = len(cycles)

        # 0->0 and 0->1: one entry per directed edge.
        edge_lookup = {}
        edge2id = {}
        send_00, rec_00 = [], []
        send_01, rec_01 = [], []
        for edge_id, (u, v) in enumerate(digraph.edges):
            edge_lookup[edge_id] = [u, v]
            edge2id[u, v] = edge_id
            send_00.append(u)
            rec_00.append(v)
            # Both endpoints are incident to the edge.
            send_01.extend([u, v])
            rec_01.extend([edge_id, edge_id])

        def to_edge_set(cycle):
            """Edge ids along ``cycle`` in its own orientation."""
            edges = []
            for node, next_node in zip(cycle, cycle[1:] + cycle[:1]):
                assert digraph.has_edge(node, next_node)
                edges.append(edge2id[node, next_node])
            return edges

        def to_edge_set_both_directions(cycle):
            """Edge ids along ``cycle`` in both directions (forward first)."""
            edges = []
            for node, next_node in zip(cycle, cycle[1:] + cycle[:1]):
                for a, b in ((node, next_node), (next_node, node)):
                    if digraph.has_edge(a, b):
                        edges.append(edge2id[a, b])
            return edges

        # 1->1 and 1->2: per cycle.
        send_11, rec_11 = [], []
        send_12, rec_12 = [], []
        cycle_lookup = {}
        cycle_edges_lookup = {}
        cycle_edges_bidirected_lookup = {}
        for cycle_id, cycle in enumerate(cycles):
            cycle_lookup[cycle_id] = cycle
            cycle_edges_bidirected_lookup[cycle_id] = to_edge_set_both_directions(cycle)
            cycle_edges = to_edge_set(cycle)
            cycle_edges_lookup[cycle_id] = cycle_edges
            # Every ordered pair of distinct edges on the same ring.
            for i in cycle_edges:
                for j in cycle_edges:
                    if i != j:
                        send_11.append(i)
                        rec_11.append(j)
            for edge_id in cycle_edges:
                send_12.append(edge_id)
                rec_12.append(cycle_id)

        # 2->2: rings sharing at least one (oriented) edge.
        send_22, rec_22 = [], []
        ring_edge_sets = {cid: set(edges) for cid, edges in cycle_edges_lookup.items()}
        for send_cell, send_edges in ring_edge_sets.items():
            for rec_cell, rec_edges in ring_edge_sets.items():
                if send_cell != rec_cell and send_edges & rec_edges:
                    send_22.append(send_cell)
                    rec_22.append(rec_cell)

        nodes_00 = torch.tensor([send_00, rec_00], dtype=torch.long)
        edges_11 = torch.tensor([send_11, rec_11], dtype=torch.long)
        coboundary_01 = torch.tensor([send_01, rec_01], dtype=torch.long)
        coboundary_12 = torch.tensor([send_12, rec_12], dtype=torch.long)
        cell_22 = torch.tensor([send_22, rec_22], dtype=torch.long)  # [2, 0] when no pair
        adj = {
            "adj_00": nodes_00,
            "adj_01": coboundary_01,
            "adj_11": edges_11,
            "adj_12": coboundary_12,
            "adj_22": cell_22,
        }
        return cls(adj=adj,
                   nodes_00=nodes_00,
                   edges_11=edges_11,
                   coboundary_01=coboundary_01,
                   coboundary_12=coboundary_12,
                   cell_22=cell_22,
                   edge_lookup=edge_lookup,
                   cycle_lookup=cycle_lookup,
                   edge2id=edge2id,
                   cycle_edges_lookup=cycle_edges_lookup,
                   cycle_edges_bidirected_lookup=cycle_edges_bidirected_lookup,
                   cycle_present=torch.tensor(cycle_present),
                   cycle_total=cycle_total)
class CellularComplexData(Data):
    """PyG ``Data`` holding a lifted cellular complex.

    On top of the original graph attributes it stores the per-rank features
    (``x_0``, ``x_1``, ``x_2``), the ``adj_ij`` index tensors and the
    invariants from ``take_inv``. ``__cat_dim__``/``__inc__`` tell the PyG
    collater how to batch the ``adj_ij`` tensors.
    """

    @classmethod
    def take_feat(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
        """Features of the 0-, 1- and 2-cells.

        ``x_0`` is ``data.x``. An edge feature is the mean of its two endpoint
        features. A ring feature is the float32 mean of its node features.

        Returns:
            ``{"x_0": ..., "x_1": ..., "x_2": ...}``. With no ring, ``x_2`` has
            zero rows and the same trailing shape as ``data.x``.
        """
        x_0 = data.x
        # torch.mean rejects integer tensors; promote them as x_2 does below.
        x_mean_src = x_0 if (x_0.is_floating_point() or x_0.is_complex()) else x_0.to(torch.float32)
        # Indexing with the [2, E] nodes_00 stacks sender and receiver features.
        x_1 = torch.mean(x_mean_src[cc.nodes_00], dim=0)
        x_2 = [x_0[cycle].to(dtype=torch.float32).mean(0) for cycle in cc.cycle_lookup.values()]
        if x_2:
            x_2 = torch.stack(x_2)
        else:
            # Previously hard-coded to 11 columns; follow the actual feature size.
            x_2 = torch.empty((0, *x_0.shape[1:]))
        return {"x_0": x_0, "x_1": x_1, "x_2": x_2}

    # The following helpers compute the geometric invariants used by the
    # different communications.

    @staticmethod
    def find_ring_length(positions, vertices):
        """Perimeter of the closed polygon visiting ``vertices`` in order.

        ``positions[k]`` must be the coordinate of ``vertices[k]``, i.e. the
        positions of the ring's own vertices (not the full position matrix).
        Returns a Python float.
        """
        vertex_to_position = dict(zip(vertices, positions))
        cyclic_vertices = vertices + vertices[:1]  # close the polygon
        total_length = sum(
            (vertex_to_position[cyclic_vertices[k + 1]] - vertex_to_position[cyclic_vertices[k]]).norm()
            for k in range(len(vertices))
        )
        return total_length.item()

    @staticmethod
    def find_area_triangle(positions):
        """Area of the triangle with the three vertices in ``positions``."""
        A, B, C = positions
        cross_product = np.cross(B - A, C - A)
        # |AB x AC| is the parallelogram area; the triangle is half of it.
        return np.linalg.norm(cross_product) / 2

    @staticmethod
    def ritters_bounding_sphere(points):
        """Approximate bounding sphere of ``points`` (Ritter's algorithm).

        Returns:
            ``(center, radius)``.
        """
        start = points[0]
        # P: farthest from an arbitrary point; Q: farthest from P.
        farthest_point = points[np.argmax(np.linalg.norm(points - start, axis=1))]
        other_end_point = points[np.argmax(np.linalg.norm(points - farthest_point, axis=1))]
        sphere_center = (farthest_point + other_end_point) / 2
        sphere_radius = np.linalg.norm(farthest_point - other_end_point) / 2
        # Grow the sphere just enough to include each outlier.
        for point in points:
            distance = np.linalg.norm(point - sphere_center)
            if distance > sphere_radius:
                sphere_radius = (sphere_radius + distance) / 2
                sphere_center = sphere_center + (distance - sphere_radius) * (point - sphere_center) / distance
        return sphere_center, sphere_radius

    @staticmethod
    def calculate_angle(v1, v2):
        """Angle in radians between ``v1`` and ``v2`` (cosine clipped to [-1, 1])."""
        cosine = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
        return np.arccos(np.clip(cosine, -1, 1))

    @staticmethod
    def find_best_planes(points):
        """Least-squares plane fits ``c = k_a * a + k_b * b + c0``.

        One fit per axis permutation ``[0, 1, 2]``, ``[0, 2, 1]``,
        ``[1, 2, 0]``, where ``a``, ``b``, ``c`` are the columns it selects.

        Returns:
            List of ``[c0, k_b, k_a]`` (intercept first) per permutation; the
            layout ``dist_to_best_planes`` expects.
        """
        coeff_planes = []
        for i_a, i_b, i_c in [[0, 1, 2], [0, 2, 1], [1, 2, 0]]:
            design = np.c_[points[:, i_a], points[:, i_b], np.ones(points.shape[0])]
            k_a, k_b, c0 = np.linalg.lstsq(design, points[:, i_c], rcond=None)[0]
            coeff_planes.append([c0, k_b, k_a])
        return coeff_planes

    @staticmethod
    def dist_to_best_planes(point, planes):
        """Mean distance from ``point`` to the planes from ``find_best_planes``.

        A plane ``[c0, k_b, k_a]`` is ``k_a*a + k_b*b - c + c0 = 0``, so its
        distance to a point is
        ``|k_a*a + k_b*b - c + c0| / sqrt(k_a**2 + k_b**2 + 1)``. The previous
        version used the intercept as the ``a`` slope and ignored ``k_a``.
        """
        distances = []
        axes = [[0, 1, 2], [0, 2, 1], [1, 2, 0]]
        for (i_a, i_b, i_c), (c0, k_b, k_a) in zip(axes, planes):
            residual = k_a * point[i_a] + k_b * point[i_b] - point[i_c] + c0
            distances.append(np.abs(residual) / np.sqrt(k_a ** 2 + k_b ** 2 + 1))
        return torch.tensor(distances).mean()

    @classmethod
    def _ring_invariants(cls, pos, cc: CellularComplex):
        """Invariants of the 1->1 and 1->2 adjacencies (needs >= 1 ring).

        Returns a dict, in this order: ``dist_11``, ``angles``,
        ``dist_to_sphere_center``, ``dist_to_mean_coord``,
        ``dist_to_best_planes``, ``radius_per_edge``, ``vol_12``, ``area_12``,
        ``radius_12``. The sphere-centre, best-plane and ``radius_12``
        invariants are disabled and returned as empty tensors.
        """
        send_11, rec_11 = cc.edges_11
        send_11, rec_11 = send_11.tolist(), rec_11.tolist()
        # 1->1 distance: between the source vertices of the two edges.
        src_1 = [cc.edge_lookup[e][0] for e in send_11]
        src_2 = [cc.edge_lookup[e][0] for e in rec_11]
        dist_11 = torch.linalg.norm(pos[src_1] - pos[src_2], dim=1)

        # Per-ring geometry: centroid, max centroid distance, hull volume/area.
        ring_centers, ring_radii, vol, area = [], [], [], []
        for ring in cc.cycle_lookup.values():
            ring_pos = pos[torch.tensor(ring)]
            center = ring_pos.mean(dim=0)
            ring_centers.append(center)
            if len(ring) > 3:
                # 'QJ' = joggled input (Qhull option), so coplanar rings still get a hull.
                hull = ConvexHull(ring_pos, qhull_options='QJ')
                vol.append(abs(hull.volume))
                area.append(abs(hull.area))
            elif len(ring) > 2:
                # Float zero keeps vol_12 a float tensor even for all-triangle graphs.
                vol.append(0.0)
                area.append(cls.find_area_triangle(ring_pos))
            else:
                vol.append(0.0)
                area.append(0.0)
            ring_radii.append(torch.max(torch.norm(ring_pos - center, dim=1)))

        # Per 1->1 pair: use the first ring containing both edges (either direction).
        ring_edge_sets = [set(edges) for edges in cc.cycle_edges_bidirected_lookup.values()]
        angles, dist_to_mean_coord, radius_per_edge = [], [], []
        for edge1, edge2 in zip(send_11, rec_11):
            ring_id = next(k for k, ring_edges in enumerate(ring_edge_sets)
                           if edge1 in ring_edges and edge2 in ring_edges)
            edge1_vertices = cc.edge_lookup[edge1]
            edge2_vertices = cc.edge_lookup[edge2]
            dist_to_mean_coord.append(torch.norm(pos[edge1_vertices[0]] - ring_centers[ring_id]))
            radius_per_edge.append(ring_radii[ring_id])
            edge1_pos = pos[edge1_vertices]
            edge2_pos = pos[edge2_vertices]
            angle = cls.calculate_angle(edge1_pos[1] - edge1_pos[0], edge2_pos[1] - edge2_pos[0])
            angles.append(np.degrees(angle))

        # Per 1->2 pair: invariants of the target ring.
        rings_12 = cc.coboundary_12[1].tolist()
        return {
            "dist_11": dist_11,
            "angles": torch.tensor(angles),
            "dist_to_sphere_center": torch.empty(0),
            "dist_to_mean_coord": torch.tensor(dist_to_mean_coord),
            "dist_to_best_planes": torch.empty(0),
            "radius_per_edge": torch.tensor(radius_per_edge),
            "vol_12": torch.tensor([vol[r] for r in rings_12]),
            "area_12": torch.tensor([area[r] for r in rings_12]),
            "radius_12": torch.empty(0),
        }

    @classmethod
    def take_inv(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
        """Geometric invariants for every adjacency, aligned with ``cc``'s tensors.

        ``dist_00``/``dist_01`` are edge lengths per 0->0 / 0->1 pair. The
        ring-based invariants (see ``_ring_invariants``) are empty tensors
        when the complex has no ring.
        """
        pos = data.pos
        send_00, rec_00 = cc.nodes_00
        dist_00 = torch.linalg.norm(pos[send_00] - pos[rec_00], dim=1)
        # Each 0->1 pair inherits the length of the edge it points to.
        inv_dict = {"dist_00": dist_00, "dist_01": dist_00[cc.coboundary_01[1]]}
        if cc.cycle_lookup:
            inv_dict.update(cls._ring_invariants(pos, cc))
        for k in ["dist_11", "angles", "dist_to_sphere_center", "dist_to_mean_coord",
                  "dist_to_best_planes", "radius_per_edge", "vol_12", "area_12", "radius_12"]:
            if k not in inv_dict:
                inv_dict[k] = torch.empty(0)
        return inv_dict

    @classmethod
    def from_data_cc_pair(cls, data: Data, cc: CellularComplex, x_dict: dict, inv_dict: dict):
        """Copy features, adjacencies and invariants onto a new instance.

        ``data`` is modified in place: the new keys are written to it and
        ``edge_attr``, ``idx``, ``name`` and ``z`` are removed when present.
        """
        for source in (x_dict, cc.adj, inv_dict):
            for k, v in source.items():
                data[k] = v
        data["graph_len"] = data.x.size(0)
        data["cycle_pres"] = cc.cycle_present
        data["cycle_total"] = cc.cycle_total
        for att in ['edge_attr', 'idx', 'name', 'z']:
            if hasattr(data, att):
                data.pop(att)
        mapping = data.items()._mapping
        return cls(**mapping)

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Concatenate ``adj_*``/``edge_index`` along columns, all else along dim 0."""
        if 'adj' in key or "edge_index" in key:
            return 1
        return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Per-row batching offsets.

        For ``adj_ij`` the sender row is shifted by the number of i-cells
        (``x_i.size(0)``) and the receiver row by the number of j-cells;
        ``edge_index`` shifts both rows by the number of nodes.
        """
        if 'adj' in key:
            i, j = key[4], key[5]  # "adj_ij"
            return torch.tensor([[getattr(self, f'x_{i}').size(0)], [getattr(self, f'x_{j}').size(0)]])
        if "edge_index" in key:
            num_nodes = getattr(self, 'x_0').size(0)
            return torch.tensor([[num_nodes], [num_nodes]])
        return super().__inc__(key, value, *args, **kwargs)
class LiftGraphToCC(BaseTransform):
    """Transform lifting a graph with node positions to ``CellularComplexData``.

    The graph's edges are merged with a radius graph (``radius``, default 2)
    whenever the radius graph has more edges. Self-loops are then removed, the
    cellular complex is built, and its features and invariants are attached.

    Args:
        radius: cutoff passed to ``RadiusGraph``.
    """

    def __init__(self, radius: float = 2):
        super().__init__()
        self.radius = radius

    def __call__(self, data: Data) -> CellularComplexData:
        # Older PyG releases let RadiusGraph overwrite edge_index/edge_attr of
        # its input; run it on a copy so the original edges survive the merge.
        conn = RadiusGraph(self.radius)(data.clone())
        if data.edge_index.size(1) < conn.edge_index.size(1):
            combined_edge_index = torch.cat((data.edge_index, conn.edge_index), dim=1)
            unique_edges = set(tuple(e.tolist()) for e in combined_edge_index.T)
            data.edge_index = torch.tensor(list(unique_edges), dtype=torch.long).T
        data.edge_index, _ = remove_self_loops(data.edge_index)
        graph = nx.Graph()
        graph.add_nodes_from(range(data.x.size(0)))
        graph.add_edges_from(data.edge_index.T.tolist())
        cc = CellularComplex.from_nx_graph(graph, data)
        x_dict = CellularComplexData.take_feat(graph, data, cc)
        inv_dict = CellularComplexData.take_inv(graph, data, cc)
        return CellularComplexData.from_data_cc_pair(data, cc, x_dict, inv_dict)