# Untitled

unknown
plain_text
a year ago
24 kB
1
Indexable
Never
```python
from typing import Optional, List, Tuple, Dict
from itertools import combinations

import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from typing import Any
from scipy.spatial import ConvexHull
import numpy as np
from torch_geometric.utils import remove_self_loops

class CellularComplex:
def __init__(self,
             nodes_00: torch.Tensor,
             edges_11: torch.Tensor,
             coboundary_01: torch.Tensor,
             coboundary_12: torch.Tensor,
             cell_22: torch.Tensor,
             edge_lookup: dict,
             cycle_lookup: dict,
             edge2id: dict,
             cycle_edges_lookup: dict,
             cycle_edges_bidirected_lookup: dict,
             cycle_present: int,
             cycle_total: int ):
    """
    Container for a cellular complex; every argument is stored unchanged
    under an attribute of the same name.

    Args:
        nodes_00: node -> node adjacency.
        edges_11: edge -> edge adjacency.
        coboundary_01: node -> edge adjacency.
        coboundary_12: edge -> cycle adjacency.
        cell_22: cycle -> cycle adjacency.
        edge_lookup: edge id -> edge.
        cycle_lookup: cycle id -> cycle.
        edge2id: edge -> edge id.
        cycle_edges_lookup: cycle id -> edges of that cycle.
        cycle_edges_bidirected_lookup: cycle id -> edges of that cycle in
            both directions.
        cycle_present: cycle-presence indicator supplied by the caller.
        cycle_total: cycle count supplied by the caller.
    """
    # adjacency tensors between cells
    self.nodes_00 = nodes_00
    self.edges_11 = edges_11
    self.cell_22 = cell_22
    self.coboundary_01 = coboundary_01
    self.coboundary_12 = coboundary_12

    # id <-> cell lookup tables
    self.edge2id = edge2id
    self.edge_lookup = edge_lookup
    self.cycle_lookup = cycle_lookup
    self.cycle_edges_lookup = cycle_edges_lookup
    self.cycle_edges_bidirected_lookup = cycle_edges_bidirected_lookup

    # cycle statistics
    self.cycle_total = cycle_total
    self.cycle_present = cycle_present

@classmethod
def from_nx_graph(cls, graph: nx.Graph, data: Data):
    """
    Lift a graph to a cellular complex of nodes (0-cells), directed edges
    (1-cells) and rings (2-cells).

    Builds the adjacencies for 0->0, 0->1, 1->1, 1->2 and 2->2 communication
    and the lookup dictionaries that are needed when calculating the
    invariants. For cycles, only one orientation is kept.

    If the graph has no cycle, the two closest not-yet-connected vertices
    (by ``data.pos``) are connected repeatedly until a cycle appears or no
    candidate pair is left.

    Args:
        graph: undirected graph; it is not modified (a copy receives the
            artificial edges). Assumes node labels are the row indices of
            ``data.pos`` -- TODO confirm against the caller.
        data: object providing ``pos`` and ``edge_index``. In the no-cycle
            fallback ``data.edge_index`` is extended IN PLACE with the
            artificial edges (both directions).

    Returns:
        A new instance of ``cls`` holding the adjacency tensors (dtype long),
        the lookup dictionaries, ``cycle_present`` (tensor, 1 if the input
        graph already had a cycle, else 0) and ``cycle_total`` (number of
        rings in the complex).
    """

    def find_unique_cycles(directed_graph):
        # every undirected ring shows up once per orientation in the directed graph;
        # keying on the sorted node tuple keeps a single representative per node set
        found = [cycle for cycle in nx.simple_cycles(directed_graph) if len(cycle) > 2]
        return list({tuple(sorted(cycle)): cycle for cycle in found}.values())

    digraph = graph.to_directed()
    cycles = find_unique_cycles(digraph)
    cycle_present = 1 if cycles else 0

    # if there are no cycles, connect the two closest vertices until a cycle is formed
    if not cycles:
        graph = graph.copy()  # do not mutate the caller's graph

        dist_matrix = torch.cdist(data.pos, data.pos)
        dist_matrix.fill_diagonal_(float('inf'))
        row, col = data.edge_index
        dist_matrix[row, col] = float('inf')  # existing edges are not candidates
        dist_matrix[col, row] = float('inf')

        while not cycles and dist_matrix.numel() > 0:
            # Find the pair of nodes with the smallest distance
            min_val = torch.min(dist_matrix)
            if torch.isinf(min_val):
                # every pair is already connected (or excluded): no cycle can be created
                break
            candidates = torch.where(dist_matrix == min_val)
            node1, node2 = int(candidates[0][0]), int(candidates[1][0])

            # Set the distance between these nodes to infinity so they won't be selected again
            dist_matrix[node1, node2] = float('inf')
            dist_matrix[node2, node1] = float('inf')

            # Add an edge between these nodes (both to the tensor and to the graph)
            new_edges = torch.tensor([[node1, node2], [node2, node1]],
                                     dtype=torch.long,
                                     device=data.edge_index.device)
            data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)
            graph.add_edge(node1, node2)

            digraph = graph.to_directed()
            cycles = find_unique_cycles(digraph)

    cycle_total = len(cycles)

    edge_lookup = {}  # edge id -> edge (node pair)
    edge2id = {}      # edge (node pair) -> edge id
    send_00, rec_00 = [], []  # vertex -> vertex
    send_01, rec_01 = [], []  # vertex -> edge

    for edge_id, (source, target) in enumerate(digraph.edges):
        edge_lookup[edge_id] = [source, target]
        edge2id[source, target] = edge_id

        send_00.append(source)
        rec_00.append(target)
        # both end points of an edge send to that edge
        send_01.extend((source, target))
        rec_01.extend((edge_id, edge_id))

    # take nodes that form a cycle and return their edges
    def to_edge_set(cycle):
        edges = []
        n_edges = len(cycle)
        for node_id, node in enumerate(cycle):
            next_node = cycle[(node_id + 1) % n_edges]
            assert digraph.has_edge(node, next_node)
            edges.append(edge2id[node, next_node])
        return edges

    # same function as above but in both directions (needed for later)
    def to_edge_set_both_directions(cycle):
        edges = []
        n_edges = len(cycle)
        for node_id, node in enumerate(cycle):
            next_node = cycle[(node_id + 1) % n_edges]
            if digraph.has_edge(node, next_node):
                edges.append(edge2id[node, next_node])
            if digraph.has_edge(next_node, node):
                edges.append(edge2id[next_node, node])
        return edges

    send_11, rec_11 = [], []  # edge -> edge communication over upper adjacency
    send_12, rec_12 = [], []  # edge -> ring
    cycle_lookup = {}                   # cycle id -> nodes of the cycle
    cycle_edges_lookup = {}             # cycle id -> edge ids (one orientation)
    cycle_edges_bidirected_lookup = {}  # cycle id -> edge ids (both orientations)

    # iterate for each cycle to find all the adjacencies that are related
    for cycle_id, cycle_nodes in enumerate(cycles):
        cycle_lookup[cycle_id] = cycle_nodes
        cycle_edges_bidirected_lookup[cycle_id] = to_edge_set_both_directions(cycle_nodes)
        cycle_edges = to_edge_set(cycle_nodes)
        cycle_edges_lookup[cycle_id] = cycle_edges

        # 1->1 cell communication (edge index pairs that are part of the same ring)
        for i in cycle_edges:
            for j in cycle_edges:
                if i != j:
                    send_11.append(i)
                    rec_11.append(j)

        # 1->2 cell communication
        for edge_id in cycle_edges:
            send_12.append(edge_id)
            rec_12.append(cycle_id)

    # 2->2 cell communication: two rings are adjacent when they share an edge.
    # The comparison uses the bidirected edge sets: each ring is stored with an
    # arbitrary orientation, so a shared undirected edge can carry different
    # directed ids in the two single-orientation edge lists.
    send_22, rec_22 = [], []
    ring_edge_sets = {cid: set(edges) for cid, edges in cycle_edges_bidirected_lookup.items()}
    for send_cell, send_edges in ring_edge_sets.items():
        for rec_cell, rec_edges in ring_edge_sets.items():
            if send_cell != rec_cell and not send_edges.isdisjoint(rec_edges):
                send_22.append(send_cell)
                rec_22.append(rec_cell)

    cc = cls(
        nodes_00=torch.tensor([send_00, rec_00], dtype=torch.long),        # node -> node
        edges_11=torch.tensor([send_11, rec_11], dtype=torch.long),        # edge -> edge of same ring
        coboundary_01=torch.tensor([send_01, rec_01], dtype=torch.long),   # node -> edge
        coboundary_12=torch.tensor([send_12, rec_12], dtype=torch.long),   # edge -> cycle
        cell_22=torch.tensor([send_22, rec_22], dtype=torch.long),         # cycle -> cycle
        edge_lookup=edge_lookup,
        cycle_lookup=cycle_lookup,
        edge2id=edge2id,
        cycle_edges_lookup=cycle_edges_lookup,
        cycle_edges_bidirected_lookup=cycle_edges_bidirected_lookup,
        cycle_present=torch.tensor(cycle_present),
        cycle_total=cycle_total)

    return cc

class CellularComplexData(Data):

@classmethod
def take_feat(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
    """
    Take features of 0-1-2 cell individually.

    Args:
        graph: not used by this method.
        data: object providing the node features ``data.x``.
        cc: cellular complex providing ``nodes_00`` (used to index ``data.x``)
            and ``cycle_lookup`` (cycle id -> node indices).

    Returns:
        dict with
        ``"x_0"``: ``data.x`` unchanged,
        ``"x_1"``: per edge, the mean of the features of its two end nodes,
        ``"x_2"``: per cycle, the float32 mean of the features of its nodes
        (an empty tensor with the same trailing shape as ``data.x`` when
        there are no cycles).
    """

    # for nodes
    x_0 = data.x

    # for edges: torch.mean is not defined for integer tensors, so cast like the ring branch does
    node_feat = x_0 if torch.is_floating_point(x_0) else x_0.to(dtype=torch.float32)
    x_1 = torch.mean(node_feat[cc.nodes_00], dim=0)

    # for cycles
    ring_feats = [x_0[cycle].to(dtype=torch.float32).mean(0) for cycle in cc.cycle_lookup.values()]

    if ring_feats:
        x_2 = torch.stack(ring_feats)
    else:
        # keep the feature width of the input instead of a hard-coded 11
        x_2 = torch.empty((0, *x_0.shape[1:]))

    return {"x_0": x_0, "x_1": x_1, "x_2": x_2}

"""
The next functions are used to calculate the invariants for all the different communications.
"""

@staticmethod
def find_ring_length(positions, vertices):
    """
    Return the perimeter of a closed ring.

    Args:
        positions: coordinates paired element-wise with ``vertices``
            (``positions[k]`` belongs to ``vertices[k]``); each coordinate
            must support subtraction and ``.norm()``.
        vertices: hashable vertex ids in ring order.

    Returns:
        Sum of the distances between consecutive vertices, including the
        closing segment from the last vertex back to the first
        (``0`` for an empty ring).
    """
    vertices = list(vertices)

    # Create a dictionary mapping vertex to its position
    vertex_to_position = dict(zip(vertices, positions))

    # Make sure vertices are cyclic
    cyclic_vertices = vertices + vertices[:1]

    # Calculate total length
    total_length = sum(
        (vertex_to_position[following] - vertex_to_position[current]).norm()
        for current, following in zip(cyclic_vertices, cyclic_vertices[1:])
    )

    return total_length

@staticmethod
def find_area_triangle(positions):
    """
    Return the area of a triangle.

    Args:
        positions: the three corner coordinates of the triangle.

    Returns:
        Half the norm of the cross product of two edge vectors.
    """
    corner_a, corner_b, corner_c = positions

    # the cross product of two edge vectors spans a parallelogram
    # whose area is twice the area of the triangle
    spanned = np.cross(corner_b - corner_a, corner_c - corner_a)

    return np.linalg.norm(spanned) / 2

@staticmethod
def ritters_bounding_sphere(points):
    """
    Approximate bounding sphere of a point set (Ritter's algorithm).

    Args:
        points: array-like of shape (num_points, dim) with at least one point.

    Returns:
        Tuple ``(sphere_center, sphere_radius)``: the centre as a NumPy array
        of length ``dim`` and the radius as a float. Every input point lies
        inside the sphere (up to floating point error).
    """
    points = np.asarray(points, dtype=float)

    # Get a point from the set, this is inside the sphere
    seed_point = points[0]

    # Find the point P that is farthest away from the seed point
    farthest_point = points[np.argmax(np.linalg.norm(points - seed_point, axis=1))]

    # Find the point Q that is farthest away from P
    other_end_point = points[np.argmax(np.linalg.norm(points - farthest_point, axis=1))]

    # Start with the sphere that has the segment PQ as diameter
    sphere_center = (farthest_point + other_end_point) / 2
    sphere_radius = np.linalg.norm(farthest_point - other_end_point) / 2

    # Expand the sphere to include every point that is still outside
    for point in points:
        distance = np.linalg.norm(point - sphere_center)
        if distance > sphere_radius:
            # grow the radius half way to the outlier and shift the centre towards it,
            # so the far side of the old sphere stays covered
            # (distance > radius >= 0 here, so the division is safe)
            sphere_radius = (sphere_radius + distance) / 2
            sphere_center = sphere_center + (distance - sphere_radius) * (point - sphere_center) / distance

    return sphere_center, sphere_radius

@staticmethod
def calculate_angle(v1, v2):
    """
    Return the angle between two vectors in radians.

    Args:
        v1: first vector.
        v2: second vector.

    Returns:
        ``arccos`` of the normalised dot product; the cosine is clipped to
        [-1, 1] before ``arccos`` is taken.
    """
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    cosine = np.dot(v1, v2) / norm_product

    # rounding can push the cosine slightly outside the domain of arccos
    return np.arccos(np.clip(cosine, -1, 1))

@staticmethod
def find_best_planes(points):
    """
    Least-squares plane fits of a 3D point set, one per choice of dependent axis.

    For each axis triple ``(first, second, target)`` in
    ``[0, 1, 2], [0, 2, 1], [1, 2, 0]`` the plane
    ``target = A * first + B * second + C`` is fitted.

    Args:
        points: array-like of shape (num_points, 3).

    Returns:
        List of three ``[A, B, C]`` lists (slope along the first axis, slope
        along the second axis, intercept), in the axis order given above.
        This is the layout ``dist_to_best_planes`` consumes.
    """
    points = np.asarray(points, dtype=float)
    ones = np.ones(points.shape[0])
    coeff_planes = []

    #  each permutation of axes
    for first, second, target in [[0, 1, 2], [0, 2, 1], [1, 2, 0]]:
        design = np.c_[points[:, first], points[:, second], ones]
        # the solution follows the column order of the design matrix: [slope, slope, intercept]
        slope_first, slope_second, intercept = np.linalg.lstsq(design, points[:, target], rcond=None)[0]
        coeff_planes.append([slope_first, slope_second, intercept])

    return coeff_planes

@staticmethod
def dist_to_best_planes(point, planes):
    """
    Mean perpendicular distance from a point to the fitted planes.

    Args:
        point: array-like with three coordinates.
        planes: list of ``[A, B, C]`` coefficients describing
            ``target = A * first + B * second + C`` for the axis triples
            ``[0, 1, 2], [0, 2, 1], [1, 2, 0]`` (in that order).

    Returns:
        0-dim torch tensor holding the mean of the per-plane distances.
    """
    point = np.asarray(point, dtype=float)
    axes = [[0, 1, 2], [0, 2, 1], [1, 2, 0]]
    distances = []

    for (first, second, target), (A, B, C) in zip(axes, planes):
        # plane in implicit form: A*first + B*second - target + C = 0
        residual = A * point[first] + B * point[second] + C - point[target]
        distances.append(np.abs(residual) / np.sqrt(A**2 + B**2 + 1))

    distances = torch.tensor(distances)
    return distances.mean()

@classmethod
def take_inv(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
    """
    Assign to each adjacency of the cellular complex its geometric invariants.

    Args:
        graph: not used by this method.
        data: object providing the node coordinates ``data.pos``.
        cc: cellular complex providing the adjacencies (``nodes_00``,
            ``coboundary_01``, ``edges_11``, ``coboundary_12``) and the
            lookups (``edge_lookup``, ``cycle_lookup``,
            ``cycle_edges_bidirected_lookup``).

    Returns:
        dict of tensors:
        ``dist_00`` / ``dist_01``: end-point distance per 0->0 / 0->1 entry;
        ``dist_11``: distance between the first vertices of each edge pair;
        ``angles``: angle in degrees between the two edges of each pair;
        ``dist_to_mean_coord``: distance from the first vertex of the sending
        edge to the mean coordinate of the first ring containing both edges;
        ``vol_12`` / ``area_12``: volume / area of the ring per 1->2 entry.
        ``dist_to_sphere_center``, ``dist_to_best_planes``,
        ``radius_per_edge`` and ``radius_12`` are returned as empty tensors
        (their computation is currently disabled). All cycle-dependent
        entries are empty tensors when the complex has no cycle.
    """

    pos = data.pos
    inv_dict = {}

    # for 00 and 01
    send_00, rec_00 = cc.nodes_00
    dist_00 = torch.linalg.norm(pos[send_00] - pos[rec_00], dim=1)
    inv_dict["dist_00"] = dist_00
    inv_dict["dist_01"] = dist_00[cc.coboundary_01[1]]

    if cc.cycle_lookup:
        # for 11 and 12
        send_11, rec_11 = cc.edges_11
        edge1_v = [cc.edge_lookup[i.item()][0] for i in send_11]
        edge2_v = [cc.edge_lookup[i.item()][0] for i in rec_11]
        # distance for edge to edge communication
        inv_dict["dist_11"] = torch.linalg.norm(pos[edge1_v] - pos[edge2_v], dim=1)

        # NOTE(review): ring_length is computed but not exported in inv_dict -- confirm whether it should be
        ring_length = []
        vol = []
        area = []
        mean_coordinate_ring = []

        # Calculating area and volume
        for ring in cc.cycle_lookup.values():
            vertex_indices = torch.tensor(ring)
            ring_pos = pos[vertex_indices]

            # centre of mass
            mean_coordinate_ring.append(ring_pos.mean(dim=0))

            # pass the coordinates of the ring's own vertices, paired element-wise with `ring`
            ring_length.append(CellularComplexData.find_ring_length(ring_pos, ring))

            if len(vertex_indices) > 3:
                hull = ConvexHull(ring_pos, qhull_options='QJ')
                vol.append(abs(hull.volume))
                area.append(abs(hull.area))
            elif len(vertex_indices) > 2:
                vol.append(0.0)  # vol is 0 if there are 3 points
                area.append(CellularComplexData.find_area_triangle(ring_pos))
            else:
                vol.append(0.0)
                area.append(0.0)

        # edges of every ring (both directions), built once for the membership tests below
        ring_edge_sets = [set(edges) for edges in cc.cycle_edges_bidirected_lookup.values()]

        angles = []
        dist_to_mean_coord = []

        # Calculate invariants for edges
        for edge1, edge2 in zip(send_11.tolist(), rec_11.tolist()):
            # pick the first ring in which the two edges are both present
            ring_index = next(
                idx for idx, members in enumerate(ring_edge_sets)
                if edge1 in members and edge2 in members
            )

            edge1_vertices = cc.edge_lookup[edge1]
            edge2_vertices = cc.edge_lookup[edge2]

            dist_to_mean_coord.append(torch.norm(pos[edge1_vertices[0]] - mean_coordinate_ring[ring_index]))

            edge1_pos = pos[edge1_vertices]
            edge2_pos = pos[edge2_vertices]

            v1 = edge1_pos[1] - edge1_pos[0]
            v2 = edge2_pos[1] - edge2_pos[0]

            angle = CellularComplexData.calculate_angle(v1, v2)
            angles.append(np.degrees(angle))

        # invariants for 11 communication (same size with cc.edges_11 (adj 1->1 that share upper))
        inv_dict["angles"] = torch.tensor(angles)
        inv_dict["dist_to_mean_coord"] = torch.tensor(dist_to_mean_coord)
        # the bounding-sphere and best-plane invariants are disabled, so these stay empty
        inv_dict["dist_to_sphere_center"] = torch.tensor([])
        inv_dict["dist_to_best_planes"] = torch.tensor([])

        # for 1->2 communication (same size with cc.coboundary_12 (adj 1->2))
        ring_of_edge = cc.coboundary_12[1].tolist()
        inv_dict["vol_12"] = torch.tensor([vol[ring_index] for ring_index in ring_of_edge])
        inv_dict["area_12"] = torch.tensor([area[ring_index] for ring_index in ring_of_edge])

    # every expected key is present, even when it could not be computed
    for k in ["dist_11", "angles", "dist_to_sphere_center", "dist_to_mean_coord", "dist_to_best_planes", "radius_per_edge", "vol_12", "area_12", "radius_12"]:
        if k not in inv_dict:
            inv_dict[k] = torch.empty(0)
    return inv_dict

@classmethod
def from_data_cc_pair(cls, data: Data, cc: CellularComplex, x_dict:dict, inv_dict:dict):
    """
    Unpack everything so I can use it in the model later.

    Copies the cell features and invariants onto ``data``, adds the
    bookkeeping entries ``graph_len``, ``cycle_pres`` and ``cycle_total``,
    drops the attributes ``edge_attr``, ``idx``, ``name`` and ``z`` when
    present, and builds an instance of ``cls`` from the result.

    Args:
        data: source object; it is MODIFIED in place.
        cc: cellular complex providing ``cycle_present`` and ``cycle_total``.
        x_dict: key -> feature tensor, stored on ``data`` under the same key.
        inv_dict: key -> invariant tensor, stored on ``data`` under the same key.

    Returns:
        A new instance of ``cls`` built from all entries of ``data``.
    """

    for k, v in x_dict.items():
        data[k] = v

    # NOTE(review): the original had a stray, duplicated `data[k] = v` here; presumably
    # a loop attaching the adjacency tensors of `cc` was lost -- confirm against the
    # keys that __cat_dim__ / __inc__ expect.

    for k, v in inv_dict.items():
        data[k] = v

    data["graph_len"] = data.x.size(0)
    data["cycle_pres"] = cc.cycle_present
    data["cycle_total"] = cc.cycle_total

    for att in ['edge_attr', 'idx', 'name', 'z']:
        if hasattr(data, att):
            data.pop(att)

    # public API instead of reaching into the private `_mapping` of the items view
    return cls(**data.to_dict())

def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
    """
    Return the dimension along which the attribute ``key`` is concatenated.

    Keys containing ``'adj'`` or ``"edge_index"`` give 1, every other key
    gives 0. ``value`` and the extra arguments are not inspected.
    """
    index_like = any(marker in key for marker in ('adj', "edge_index"))
    return 1 if index_like else 0

def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
    """
    Return the offset added to the attribute ``key`` when graphs are batched.

    Adjacency keys get incremented in a different way from the default: an
    adjacency from i-cells to j-cells needs one offset per row.

    NOTE(review): the ``if`` header and both return statements of this method
    were missing from the source; they are reconstructed from the remaining
    fragments (``i, j = key[4], key[5]`` and the ``"edge_index"`` branch).
    Confirm against the original implementation.

    Returns:
        For keys containing ``'adj'`` (presumably named ``adj_<i><j>``): a
        (2, 1) tensor with the number of i-cells and of j-cells, read from
        ``x_<i>`` and ``x_<j>``.
        For keys containing ``"edge_index"``: the number of 0-cells.
        Otherwise: whatever the parent class returns.
    """
    if 'adj' in key:
        # characters 4 and 5 of the key are the cell dimensions of sender and receiver
        i, j = key[4], key[5]
        return torch.tensor([[getattr(self, f'x_{i}').size(0)],
                             [getattr(self, f'x_{j}').size(0)]])
    elif "edge_index" in key:
        # presumably node indices are shifted by the number of nodes -- TODO confirm
        return getattr(self, 'x_0').size(0)
    else:
        return super().__inc__(key, value, *args, **kwargs)

class LiftGraphToCC(BaseTransform):
def __call__(self, data: Data) -> CellularComplexData:
graph = nx.Graph()

# if you want to connect more