# Untitled
# unknown
# plain_text
# 2 years ago
# 24 kB
# 5
# Indexable
import copy
from itertools import combinations
from typing import Any, Dict, List, Optional, Tuple

import networkx as nx
import numpy as np
import torch
from scipy.spatial import ConvexHull
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform, RadiusGraph
from torch_geometric.utils import remove_self_loops

# Axis permutations used by the best-fit-plane helpers: for (a, b, c) the
# plane is fitted as c = a_coef * a + b_coef * b + intercept.
_PLANE_AXES = ((0, 1, 2), (0, 2, 1), (1, 2, 0))


class CellularComplex:
    """Adjacency structure of a 2-dimensional cellular complex built from a graph.

    0-cells are graph nodes, 1-cells are directed edges (every undirected
    edge appears in both directions) and 2-cells are cycles with more than
    two nodes.
    """

    def __init__(self,
                 adj: dict,
                 nodes_00: torch.Tensor,
                 edges_11: torch.Tensor,
                 coboundary_01: torch.Tensor,
                 coboundary_12: torch.Tensor,
                 cell_22: torch.Tensor,
                 edge_lookup: dict,
                 cycle_lookup: dict,
                 edge2id: dict,
                 cycle_edges_lookup: dict,
                 cycle_edges_bidirected_lookup: dict,
                 cycle_present: int,
                 cycle_total: int
                 ):
        """Store the precomputed adjacencies and lookup tables.

        Args:
            adj: ``{"adj_00", "adj_01", "adj_11", "adj_12", "adj_22"}`` -> the
                same tensors as the individual arguments below.
            nodes_00: (2, E) node -> node index pairs (one column per directed edge).
            edges_11: (2, K) edge -> edge pairs of edges lying on the same ring.
            coboundary_01: (2, 2E) node -> edge incidences.
            coboundary_12: (2, M) edge -> ring incidences.
            cell_22: (2, R) ring -> ring pairs of rings sharing an edge.
            edge_lookup: edge id -> ``[u, v]``.
            cycle_lookup: ring id -> list of node ids.
            edge2id: ``(u, v)`` -> edge id.
            cycle_edges_lookup: ring id -> edge ids along one orientation.
            cycle_edges_bidirected_lookup: ring id -> edge ids in both directions.
            cycle_present: 1 if the input graph already contained a ring, else 0.
            cycle_total: number of rings in the complex.
        """
        self.adj = adj
        self.coboundary_01 = coboundary_01
        self.coboundary_12 = coboundary_12
        self.nodes_00 = nodes_00
        self.edges_11 = edges_11
        self.cell_22 = cell_22
        self.edge_lookup = edge_lookup
        self.cycle_lookup = cycle_lookup
        self.edge2id = edge2id
        self.cycle_edges_lookup = cycle_edges_lookup
        self.cycle_edges_bidirected_lookup = cycle_edges_bidirected_lookup
        self.cycle_present = cycle_present
        self.cycle_total = cycle_total

    @staticmethod
    def _unique_cycles(cycles: List[List[int]]) -> List[List[int]]:
        """Keep one node list per undirected cycle (drop rotations/orientations).

        Two node lists are the same undirected cycle exactly when they use the
        same set of undirected edges. Keying on that edge set, instead of on
        the sorted node list, keeps genuinely different cycles that visit the
        same nodes (e.g. the three distinct 4-cycles of a K4). When duplicates
        occur the last one seen is kept, at the position of the first.
        """
        unique = {}
        for cycle in cycles:
            n_nodes = len(cycle)
            key = frozenset(
                frozenset((cycle[k], cycle[(k + 1) % n_nodes])) for k in range(n_nodes)
            )
            unique[key] = cycle
        return list(unique.values())

    @classmethod
    def _find_unique_cycles(cls, digraph: nx.DiGraph) -> List[List[int]]:
        """Return the de-duplicated simple cycles of ``digraph`` with more than 2 nodes."""
        cycles = [cycle for cycle in nx.simple_cycles(digraph) if len(cycle) > 2]
        return cls._unique_cycles(cycles)

    @classmethod
    def _connect_closest_until_cycle(cls, data: Data, digraph: nx.DiGraph):
        """Connect the closest non-adjacent node pairs until a ring appears.

        Mutates ``data.edge_index`` in place (each new edge is appended in both
        directions) so that the added edges are part of the returned complex.

        Returns:
            ``(digraph, cycles)``: the directed graph rebuilt from the updated
            ``data.edge_index`` and its de-duplicated cycles. If no unconnected
            pair is left (only possible with fewer than 3 nodes) no ring can
            ever form; ``cycles`` is then empty instead of looping forever.
        """
        inf = float('inf')
        dist_matrix = torch.cdist(data.pos, data.pos)
        dist_matrix.fill_diagonal_(inf)
        # Existing edges are not candidates.
        src, dst = data.edge_index
        dist_matrix[src, dst] = inf
        dist_matrix[dst, src] = inf

        while dist_matrix.numel() > 0:
            min_val = torch.min(dist_matrix)
            if torch.isinf(min_val):
                break  # every pair is already connected
            rows, cols = torch.where(dist_matrix == min_val)
            node1, node2 = rows[0], cols[0]
            # Never pick this pair again.
            dist_matrix[node1, node2] = inf
            dist_matrix[node2, node1] = inf

            new_edges = torch.tensor([[node1, node2], [node2, node1]], dtype=torch.long)
            data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)

            graph_new = nx.Graph()
            graph_new.add_nodes_from(range(data.x.size(0)))
            graph_new.add_edges_from(data.edge_index.T.tolist())
            digraph = graph_new.to_directed()
            cycles = cls._find_unique_cycles(digraph)
            if cycles:
                return digraph, cycles
        return digraph, []

    @classmethod
    def from_nx_graph(cls, graph: nx.Graph, data: Data):
        """Build the complex: 0->0, 0->1, 1->1, 1->2 and 2->2 adjacencies.

        Also builds the edge/cycle lookup dictionaries needed when computing
        the invariants. Each undirected cycle is kept in a single orientation.
        If the graph has no ring, the closest non-adjacent node pairs are
        connected (mutating ``data.edge_index``) until one appears.
        """
        digraph = graph.to_directed()
        cycle_present = 0
        cycles = cls._find_unique_cycles(digraph)
        if cycles:
            cycle_present += 1  # the input graph itself contains a ring
        else:
            digraph, cycles = cls._connect_closest_until_cycle(data, digraph)
        cycle_total = len(cycles)

        # 0-cells / 1-cells: every directed edge gets an id.
        edge_lookup = {}
        edge2id = {}
        send_00, rec_00 = [], []  # vertex -> vertex
        send_01, rec_01 = [], []  # vertex -> edge
        for edge_id, (u, v) in enumerate(digraph.edges):
            edge_lookup[edge_id] = [u, v]
            edge2id[u, v] = edge_id
            send_00.append(u)
            rec_00.append(v)
            send_01.extend([u, v])
            rec_01.extend([edge_id, edge_id])

        def to_edge_set(cycle):
            """Edge ids along ``cycle`` in its stored orientation."""
            edges = []
            n_edges = len(cycle)
            for node_id, node in enumerate(cycle):
                next_node = cycle[(node_id + 1) % n_edges]
                assert digraph.has_edge(node, next_node)
                edges.append(edge2id[node, next_node])
            return edges

        def to_edge_set_both_directions(cycle):
            """Edge ids along ``cycle`` in both directions (when present)."""
            edges = []
            n_edges = len(cycle)
            for node_id, node in enumerate(cycle):
                next_node = cycle[(node_id + 1) % n_edges]
                if digraph.has_edge(node, next_node):
                    edges.append(edge2id[node, next_node])
                if digraph.has_edge(next_node, node):
                    edges.append(edge2id[next_node, node])
            return edges

        send_11, rec_11 = [], []  # edge -> edge on the same ring
        send_12, rec_12 = [], []  # edge -> ring
        cycle_lookup = {}
        cycle_edges_lookup = {}
        cycle_edges_bidirected_lookup = {}
        for cycle_id, cycle in enumerate(cycles):
            cycle_lookup[cycle_id] = cycle
            cycle_edges_bidirected_lookup[cycle_id] = to_edge_set_both_directions(cycle)
            cycle_edges = to_edge_set(cycle)
            cycle_edges_lookup[cycle_id] = cycle_edges
            for i in cycle_edges:
                for j in cycle_edges:
                    if i != j:
                        send_11.append(i)
                        rec_11.append(j)
            send_12.extend(cycle_edges)
            rec_12.extend([cycle_id] * len(cycle_edges))

        # 2->2: rings sharing at least one undirected edge. The bidirected
        # sets are used because the stored orientation of each ring is
        # arbitrary, and a shared edge may be traversed in opposite directions.
        ring_edge_sets = {cid: set(edges) for cid, edges in cycle_edges_bidirected_lookup.items()}
        send_22, rec_22 = [], []
        for send_cell, send_edges in ring_edge_sets.items():
            for rec_cell, rec_edges in ring_edge_sets.items():
                if send_cell != rec_cell and not send_edges.isdisjoint(rec_edges):
                    send_22.append(send_cell)
                    rec_22.append(rec_cell)

        nodes_00 = torch.tensor([send_00, rec_00], dtype=torch.long)
        edges_11 = torch.tensor([send_11, rec_11], dtype=torch.long)
        coboundary_01 = torch.tensor([send_01, rec_01], dtype=torch.long)
        coboundary_12 = torch.tensor([send_12, rec_12], dtype=torch.long)
        cell_22 = torch.tensor([send_22, rec_22], dtype=torch.long)  # shape (2, 0) without ring pairs

        adj = {
            "adj_00": nodes_00,
            "adj_01": coboundary_01,
            "adj_11": edges_11,
            "adj_12": coboundary_12,
            "adj_22": cell_22,
        }
        return cls(adj=adj,
                   nodes_00=nodes_00,
                   edges_11=edges_11,
                   coboundary_01=coboundary_01,
                   coboundary_12=coboundary_12,
                   cell_22=cell_22,
                   edge_lookup=edge_lookup,
                   cycle_lookup=cycle_lookup,
                   edge2id=edge2id,
                   cycle_edges_lookup=cycle_edges_lookup,
                   cycle_edges_bidirected_lookup=cycle_edges_bidirected_lookup,
                   cycle_present=torch.tensor(cycle_present),
                   cycle_total=cycle_total)


class CellularComplexData(Data):
    """PyG ``Data`` carrying cell features, adjacencies and geometric invariants."""

    @classmethod
    def take_feat(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
        """Features of the 0-, 1- and 2-cells.

        ``x_0`` is ``data.x``. ``x_1`` is the mean of each edge's two endpoint
        features. ``x_2`` is the float32 mean of each ring's node features.
        Without rings, ``x_2`` is an empty tensor with ``x_0``'s trailing shape.
        """
        x_0 = data.x
        x_1 = torch.mean(x_0[cc.nodes_00], dim=0)
        ring_feats = [x_0[cycle].to(dtype=torch.float32).mean(0)
                      for cycle in cc.cycle_lookup.values()]
        if ring_feats:
            x_2 = torch.stack(ring_feats)
        else:
            x_2 = torch.empty((0, *x_0.shape[1:]))
        return {"x_0": x_0, "x_1": x_1, "x_2": x_2}

    # ------------------------------------------------------------------
    # Geometric helpers used to compute the invariants.
    # ------------------------------------------------------------------

    @staticmethod
    def find_ring_length(positions, vertices):
        """Perimeter of the closed ring visiting ``vertices`` in order.

        ``positions`` is the per-node position tensor (row ``v`` holds node
        ``v``'s coordinates). Returns a Python float; an empty ring has length 0.
        """
        ring = list(vertices)
        if not ring:
            return 0.0
        ring_pos = positions[ring]
        next_pos = torch.roll(ring_pos, shifts=-1, dims=0)  # closes the ring
        return (next_pos - ring_pos).norm(dim=1).sum().item()

    @staticmethod
    def find_area_triangle(positions):
        """Area of the triangle whose three corner positions are the rows of ``positions``."""
        A, B, C = positions
        cross_product = np.cross(B - A, C - A)
        # |AB x AC| is the parallelogram area; the triangle is half of it.
        return np.linalg.norm(cross_product) / 2

    @staticmethod
    def ritters_bounding_sphere(points):
        """Approximate bounding sphere (Ritter's algorithm); returns ``(center, radius)``.

        ``points`` is a 2-D array-like of shape (N, D).
        """
        # Start from the segment between two far-apart points.
        distances = np.linalg.norm(points - points[0], axis=1)
        farthest_point = points[np.argmax(distances)]
        distances = np.linalg.norm(points - farthest_point, axis=1)
        other_end_point = points[np.argmax(distances)]
        sphere_center = (farthest_point + other_end_point) / 2
        sphere_radius = np.linalg.norm(farthest_point - other_end_point) / 2
        # Grow the sphere to cover any point still outside it.
        for point in points:
            distance = np.linalg.norm(point - sphere_center)
            if distance > sphere_radius:
                sphere_radius = (sphere_radius + distance) / 2
                sphere_center = sphere_center + (distance - sphere_radius) * (point - sphere_center) / distance
        return sphere_center, sphere_radius

    @staticmethod
    def calculate_angle(v1, v2):
        """Angle in radians between vectors ``v1`` and ``v2`` (cosine clipped to [-1, 1])."""
        dot_product = np.dot(v1, v2)
        magnitude_v1 = np.linalg.norm(v1)
        magnitude_v2 = np.linalg.norm(v2)
        return np.arccos(np.clip(dot_product / (magnitude_v1 * magnitude_v2), -1, 1))

    @staticmethod
    def find_best_planes(points):
        """Least-squares planes ``c = a_coef*a + b_coef*b + intercept`` for three axis permutations.

        Returns a list of ``[a_coef, b_coef, intercept]``, one per entry of
        ``_PLANE_AXES``.
        """
        pts = np.asarray(points, dtype=float)
        coeff_planes = []
        for axes in _PLANE_AXES:
            a, b, c = (pts[:, k] for k in axes)
            design = np.c_[a, b, np.ones(pts.shape[0])]
            a_coef, b_coef, intercept = np.linalg.lstsq(design, c, rcond=None)[0]
            coeff_planes.append([a_coef, b_coef, intercept])
        return coeff_planes

    @staticmethod
    def dist_to_best_planes(point, planes):
        """Mean distance from ``point`` to the planes returned by ``find_best_planes``."""
        distances = []
        for axes, (a_coef, b_coef, intercept) in zip(_PLANE_AXES, planes):
            a, b, c = (float(point[k]) for k in axes)
            distances.append(abs(a_coef * a + b_coef * b + intercept - c)
                             / np.sqrt(a_coef ** 2 + b_coef ** 2 + 1))
        return torch.tensor(distances).mean()

    @classmethod
    def _ring_geometry(cls, pos, ring):
        """Return ``(centroid, radius, volume, area)`` of the ring with node ids ``ring``.

        ``radius`` is the largest node-to-centroid distance. Rings of more than
        3 nodes use their convex hull (joggled input, ``'QJ'``). A 3-node ring
        has zero volume and its triangle area. Smaller rings get zeros.
        """
        ring_pos = pos[torch.tensor(ring)]
        centroid = ring_pos.mean(dim=0)
        n_nodes = len(ring)
        if n_nodes > 3:
            hull = ConvexHull(ring_pos, qhull_options='QJ')
            volume, surface = abs(hull.volume), abs(hull.area)
        elif n_nodes > 2:
            volume, surface = 0, cls.find_area_triangle(ring_pos)
        else:
            volume, surface = 0, 0
        radius = torch.max(torch.norm(ring_pos - centroid, dim=1))
        return centroid, radius, volume, surface

    @classmethod
    def take_inv(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
        """Geometric invariants for every adjacency of ``cc``.

        Ring-dependent keys (``dist_11`` ... ``radius_12``) are empty tensors
        when the complex has no ring. ``dist_to_sphere_center``,
        ``dist_to_best_planes`` and ``radius_12`` are currently always empty
        placeholders.
        """
        pos = data.pos
        inv_dict = {}

        # 0->0: edge length; 0->1: the same length for each endpoint incidence.
        send_00, rec_00 = cc.nodes_00
        dist_00 = torch.linalg.norm(pos[send_00] - pos[rec_00], dim=1)
        inv_dict["dist_00"] = dist_00
        inv_dict["dist_01"] = dist_00[cc.coboundary_01[1]]

        if len(cc.cycle_lookup) > 0:
            # 1->1: distance between the source nodes of the two edges.
            send_11, rec_11 = cc.edges_11
            send_list, rec_list = send_11.tolist(), rec_11.tolist()
            edge1_v = [cc.edge_lookup[e][0] for e in send_list]
            edge2_v = [cc.edge_lookup[e][0] for e in rec_list]
            inv_dict["dist_11"] = torch.linalg.norm(pos[edge1_v] - pos[edge2_v], dim=1)

            centroids, radii, vol, area = [], [], [], []
            for ring in cc.cycle_lookup.values():
                centroid, radius, volume, surface = cls._ring_geometry(pos, ring)
                centroids.append(centroid)
                radii.append(radius)
                vol.append(volume)
                area.append(surface)

            # Per 1->1 pair: use the first ring (by id) that contains both edges.
            ring_edge_sets = [set(edges) for edges in cc.cycle_edges_bidirected_lookup.values()]
            angles, dist_to_mean_coord, radius_per_edge = [], [], []
            for edge1, edge2 in zip(send_list, rec_list):
                ring_idx = next(k for k, ring_edges in enumerate(ring_edge_sets)
                                if edge1 in ring_edges and edge2 in ring_edges)
                edge1_vertices = cc.edge_lookup[edge1]
                edge2_vertices = cc.edge_lookup[edge2]
                dist_to_mean_coord.append(torch.norm(pos[edge1_vertices[0]] - centroids[ring_idx]))
                radius_per_edge.append(radii[ring_idx])
                edge1_pos = pos[edge1_vertices]
                edge2_pos = pos[edge2_vertices]
                v1 = edge1_pos[1] - edge1_pos[0]
                v2 = edge2_pos[1] - edge2_pos[0]
                angles.append(np.degrees(cls.calculate_angle(v1, v2)))

            inv_dict["angles"] = torch.tensor(angles)
            inv_dict["dist_to_sphere_center"] = torch.tensor([])
            inv_dict["dist_to_mean_coord"] = torch.tensor(dist_to_mean_coord)
            inv_dict["dist_to_best_planes"] = torch.tensor([])
            inv_dict["radius_per_edge"] = torch.tensor(radius_per_edge)

            # 1->2: invariants of the ring each edge->ring incidence points to.
            incidence_rings = cc.coboundary_12[1].tolist()
            inv_dict["vol_12"] = torch.tensor([vol[r] for r in incidence_rings])
            inv_dict["area_12"] = torch.tensor([area[r] for r in incidence_rings])
            inv_dict["radius_12"] = torch.tensor([])

        for k in ["dist_11", "angles", "dist_to_sphere_center", "dist_to_mean_coord",
                  "dist_to_best_planes", "radius_per_edge", "vol_12", "area_12", "radius_12"]:
            if k not in inv_dict:
                inv_dict[k] = torch.empty(0)
        return inv_dict

    @classmethod
    def from_data_cc_pair(cls, data: Data, cc: CellularComplex, x_dict: dict, inv_dict: dict):
        """Copy features, adjacencies and invariants onto ``data``; return a ``CellularComplexData``.

        Also records ``graph_len`` (node count), ``cycle_pres`` and
        ``cycle_total``. Any ``edge_attr``/``idx``/``name``/``z`` attributes
        are removed first.
        """
        for k, v in x_dict.items():
            data[k] = v
        for k, v in cc.adj.items():
            data[k] = v
        for k, v in inv_dict.items():
            data[k] = v
        data["graph_len"] = data.x.size(0)
        data["cycle_pres"] = cc.cycle_present
        data["cycle_total"] = cc.cycle_total
        for att in ['edge_attr', 'idx', 'name', 'z']:
            if hasattr(data, att):
                data.pop(att)
        mapping = data.items()._mapping
        return cls(**mapping)

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Concatenate adjacency (``adj_*``) and ``edge_index`` tensors along dim 1, others along dim 0."""
        if 'adj' in key or "edge_index" in key:
            return 1
        return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Batching offsets: for ``adj_ij`` the sender row is shifted by the
        number of i-cells (``x_i``) and the receiver row by the number of
        j-cells (``x_j``). ``edge_index`` rows are both shifted by the node count.
        """
        if 'adj' in key:
            i, j = key[4], key[5]
            return torch.tensor([[getattr(self, f'x_{i}').size(0)], [getattr(self, f'x_{j}').size(0)]])
        elif "edge_index" in key:
            return torch.tensor([[getattr(self, 'x_0').size(0)], [getattr(self, 'x_0').size(0)]])
        else:
            return super().__inc__(key, value, *args, **kwargs)


class LiftGraphToCC(BaseTransform):
    """Transform a graph ``Data`` into a ``CellularComplexData``.

    Args:
        radius: radius of the ``RadiusGraph`` used to densify the input
            graph (default 2, the previously hard-coded value).
    """

    def __init__(self, radius: float = 2.0):
        self.radius = radius

    def __call__(self, data: Data) -> CellularComplexData:
        graph = nx.Graph()
        # A shallow copy is transformed so that a PyG version whose transforms
        # write in place cannot overwrite data.edge_index before the comparison.
        conn = RadiusGraph(self.radius)(copy.copy(data))
        if len(data.edge_index[0]) < len(conn.edge_index[0]):
            combined_edge_index = torch.cat((data.edge_index, conn.edge_index), dim=1)
            # De-duplicate (src, dst) columns.
            data.edge_index = torch.tensor(
                list(set(tuple(i.tolist()) for i in combined_edge_index.T)),
                dtype=torch.long,
            ).T
        data.edge_index, _ = remove_self_loops(data.edge_index)
        graph.add_nodes_from(range(data.x.size(0)))
        graph.add_edges_from(data.edge_index.T.tolist())

        cc = CellularComplex.from_nx_graph(graph, data)
        x_dict = CellularComplexData.take_feat(graph, data, cc)
        inv_dict = CellularComplexData.take_inv(graph, data, cc)
        return CellularComplexData.from_data_cc_pair(data, cc, x_dict, inv_dict)
# Editor is loading...