# Untitled
#
# Paste metadata (pastecode.io): mail@pastecode.io avatar / unknown /
# plain_text / a year ago / 24 kB / 1 / Indexable / Never
from typing import Optional, List, Tuple, Dict
from itertools import combinations

import torch
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform
from typing import Any
from scipy.spatial import ConvexHull
import numpy as np
from torch_geometric.transforms import RadiusGraph
from torch_geometric.utils import remove_self_loops

class CellularComplex:
    """
    Cellular complex lifted from a graph.

    0-cells are the graph nodes, 1-cells are the directed edges of
    ``graph.to_directed()`` and 2-cells ("rings") are simple cycles with more
    than two nodes, each kept in a single orientation. The object stores the
    ``[2, K]`` index tensors used for 0->0, 0->1, 1->1, 1->2 and 2->2
    communication, plus lookup dictionaries mapping cell ids back to the
    nodes/edges that form them.
    """

    def __init__(self,
                 adj: dict,
                 nodes_00: torch.Tensor,
                 edges_11: torch.Tensor,
                 coboundary_01: torch.Tensor,
                 coboundary_12: torch.Tensor,
                 cell_22: torch.Tensor,
                 edge_lookup: dict,
                 cycle_lookup: dict,
                 edge2id: dict,
                 cycle_edges_lookup: dict,
                 cycle_edges_bidirected_lookup: dict,
                 cycle_present: int,
                 cycle_total: int):
        """
        Store a pre-computed cellular complex (normally built by ``from_nx_graph``).

        Args:
            adj: maps "adj_00", "adj_01", "adj_11", "adj_12", "adj_22" to the
                corresponding index tensors below.
            nodes_00: node -> node pairs, one column per directed edge.
            edges_11: edge -> edge pairs of distinct edges on the same ring.
            coboundary_01: node -> edge incidence (both endpoints of every edge).
            coboundary_12: edge -> ring incidence.
            cell_22: ring -> ring pairs of distinct rings sharing an edge.
            edge_lookup: edge id -> [source node, target node].
            cycle_lookup: ring id -> node ids in cycle order.
            edge2id: (source node, target node) -> edge id.
            cycle_edges_lookup: ring id -> edge ids walked in the stored orientation.
            cycle_edges_bidirected_lookup: ring id -> edge ids of the ring in
                both directions.
            cycle_present: flag, 1 if the input graph already contained a cycle.
            cycle_total: number of rings.
        """
        self.adj = adj
        self.coboundary_01 = coboundary_01
        self.coboundary_12 = coboundary_12
        self.nodes_00 = nodes_00
        self.edges_11 = edges_11
        self.cell_22 = cell_22
        self.edge_lookup = edge_lookup
        self.cycle_lookup = cycle_lookup
        self.edge2id = edge2id
        self.cycle_edges_lookup = cycle_edges_lookup
        self.cycle_edges_bidirected_lookup = cycle_edges_bidirected_lookup
        self.cycle_present = cycle_present
        self.cycle_total = cycle_total

    @staticmethod
    def _cycle_key(cycle: List[int]) -> Tuple[int, ...]:
        """Return a key shared by every rotation and both orientations of ``cycle``."""
        n = len(cycle)
        start = cycle.index(min(cycle))
        forward = tuple(cycle[(start + k) % n] for k in range(n))
        backward = tuple(cycle[(start - k) % n] for k in range(n))
        return min(forward, backward)

    @classmethod
    def _unique_cycles(cls, cycles: List[List[int]]) -> List[List[int]]:
        """Keep one orientation per cycle (the last one seen).

        Distinct cycles that merely visit the same set of nodes in a different
        order stay separate.
        """
        return list({cls._cycle_key(cycle): cycle for cycle in cycles}.values())

    @classmethod
    def from_nx_graph(cls, graph: nx.Graph, data: Data):
        """
        Lift ``graph`` to a cellular complex.

        Finds all the adjacencies for 0->0, 0->1, 1->1, 1->2 and 2->2
        communication and builds the edge/cycle lookup dictionaries needed
        later when calculating the invariants. For cycles, only one orientation
        is kept.

        If ``graph`` has no cycle of length > 2, the closest not-yet-connected
        pair of nodes (by ``data.pos``) is connected repeatedly until a cycle
        appears. This mutates ``data.edge_index`` in place. If no cycle can be
        formed, the complex is built without 2-cells. That happens when every
        node pair is already connected, e.g. with fewer than three nodes.

        Args:
            graph: undirected graph; presumably its nodes are
                ``0 .. data.x.size(0) - 1`` (the rebuild below assumes so).
            data: provides ``pos``, ``edge_index`` and ``x``.

        Returns:
            CellularComplex
        """
        digraph = graph.to_directed()
        cycle_present = 0

        cycles = [cycle for cycle in nx.simple_cycles(digraph) if len(cycle) > 2]

        if cycles:
            cycle_present = 1
            # each undirected cycle is found once per orientation: keep one
            cycles = cls._unique_cycles(cycles)
        else:
            # no cycles: connect the two closest unconnected nodes until a cycle forms
            dist_matrix = torch.cdist(data.pos, data.pos)
            dist_matrix.fill_diagonal_(float('inf'))
            src, dst = data.edge_index
            dist_matrix[src, dst] = float('inf')
            dist_matrix[dst, src] = float('inf')

            while dist_matrix.numel() > 0:
                min_val = torch.min(dist_matrix)
                if torch.isinf(min_val):
                    # every pair is already connected: no cycle can ever form.
                    # Stop instead of adding self-loops forever.
                    break

                rows, cols = torch.where(dist_matrix == min_val)
                node1, node2 = int(rows[0]), int(cols[0])

                # never pick this pair again
                dist_matrix[node1, node2] = float('inf')
                dist_matrix[node2, node1] = float('inf')

                new_edges = torch.tensor([[node1, node2], [node2, node1]], dtype=torch.long)
                data.edge_index = torch.cat([data.edge_index, new_edges], dim=1)

                graph = nx.Graph()
                graph.add_nodes_from(range(data.x.size(0)))
                graph.add_edges_from(data.edge_index.T.tolist())
                digraph = graph.to_directed()

                cycles = [cycle for cycle in nx.simple_cycles(digraph) if len(cycle) > 2]
                if cycles:
                    cycles = cls._unique_cycles(cycles)
                    break

        cycle_total = len(cycles)

        # 0->0 and 0->1 communication: one column per directed edge
        edge_lookup = {}  # edge id -> [source, target]
        edge2id = {}  # (source, target) -> edge id
        send_00, rec_00 = [], []
        send_01, rec_01 = [], []
        for edge_id, (u, v) in enumerate(digraph.edges):
            edge_lookup[edge_id] = [u, v]
            edge2id[u, v] = edge_id
            send_00.append(u)
            rec_00.append(v)
            send_01.extend([u, v])
            rec_01.extend([edge_id, edge_id])

        def to_edge_set(cycle):
            """Edge ids traversed when walking ``cycle`` in its stored order."""
            edges = []
            for node, next_node in zip(cycle, cycle[1:] + cycle[:1]):
                assert digraph.has_edge(node, next_node)
                edges.append(edge2id[node, next_node])
            return edges

        def to_edge_set_both_directions(cycle):
            """Edge ids of ``cycle`` in both directions (where the edge exists)."""
            edges = []
            for node, next_node in zip(cycle, cycle[1:] + cycle[:1]):
                if digraph.has_edge(node, next_node):
                    edges.append(edge2id[node, next_node])
                if digraph.has_edge(next_node, node):
                    edges.append(edge2id[next_node, node])
            return edges

        send_11, rec_11 = [], []  # edge -> edge over upper adjacency (same ring)
        send_12, rec_12 = [], []  # edge -> ring
        cycle_lookup = {}  # ring id -> node ids
        cycle_edges_lookup = {}  # ring id -> edge ids (stored orientation)
        cycle_edges_bidirected_lookup = {}  # ring id -> edge ids (both directions)

        for cycle_id, cycle in enumerate(cycles):
            cycle_lookup[cycle_id] = cycle
            cycle_edges_bidirected_lookup[cycle_id] = to_edge_set_both_directions(cycle)
            ring_edges = to_edge_set(cycle)
            cycle_edges_lookup[cycle_id] = ring_edges

            # 1->1: every ordered pair of distinct edges of the same ring
            for i in ring_edges:
                for j in ring_edges:
                    if i != j:
                        send_11.append(i)
                        rec_11.append(j)

            # 1->2: each ring edge is incident to the ring
            for edge_id in ring_edges:
                send_12.append(edge_id)
                rec_12.append(cycle_id)

        # 2->2: two rings are adjacent when they share an (undirected) edge.
        # Each ring keeps one arbitrary orientation. Comparing the
        # one-orientation edge ids would miss rings that walk the shared edge
        # in opposite directions, so compare the bidirected edge sets.
        ring_edge_sets = {cid: set(edges) for cid, edges in cycle_edges_bidirected_lookup.items()}
        send_22, rec_22 = [], []
        for send_cell, send_edges in ring_edge_sets.items():
            for rec_cell, rec_edges in ring_edge_sets.items():
                if send_cell != rec_cell and send_edges & rec_edges:
                    send_22.append(send_cell)
                    rec_22.append(rec_cell)

        nodes_00 = torch.tensor([send_00, rec_00], dtype=torch.long)  # node -> node
        coboundary_01 = torch.tensor([send_01, rec_01], dtype=torch.long)  # node -> edge
        edges_11 = torch.tensor([send_11, rec_11], dtype=torch.long)  # edge -> edge of same ring
        coboundary_12 = torch.tensor([send_12, rec_12], dtype=torch.long)  # edge -> ring
        cell_22 = torch.tensor([send_22, rec_22], dtype=torch.long)  # [2, 0] when no rings share an edge
        cycle_present = torch.tensor(cycle_present)

        adj = {
            "adj_00": nodes_00,
            "adj_01": coboundary_01,
            "adj_11": edges_11,
            "adj_12": coboundary_12,
            "adj_22": cell_22,
        }

        return cls(adj=adj,
                   nodes_00=nodes_00,
                   edges_11=edges_11,
                   coboundary_01=coboundary_01,
                   coboundary_12=coboundary_12,
                   cell_22=cell_22,
                   edge_lookup=edge_lookup,
                   cycle_lookup=cycle_lookup,
                   edge2id=edge2id,
                   cycle_edges_lookup=cycle_edges_lookup,
                   cycle_edges_bidirected_lookup=cycle_edges_bidirected_lookup,
                   cycle_present=cycle_present,
                   cycle_total=cycle_total)



class CellularComplexData(Data):
    """
    ``Data`` object holding a graph lifted to a cellular complex.

    Besides the original graph attributes, it carries:
    - per-rank features: ``x_0`` for nodes, ``x_1`` for edges, ``x_2`` for rings;
    - the ``adj_ij`` index tensors of a ``CellularComplex``;
    - the geometric invariants computed for each adjacency.

    ``__cat_dim__`` and ``__inc__`` tell the batching logic how to
    concatenate and offset the ``adj_ij`` tensors.
    """

    @classmethod
    def take_feat(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
        """
        Build the features of the 0-, 1- and 2-cells individually.

        Args:
            graph: unused; kept for a signature parallel to ``take_inv``.
            data: provides the node features ``x``.
            cc: complex whose ``nodes_00`` and ``cycle_lookup`` define the cells.

        Returns:
            dict with three entries:
            - "x_0": ``data.x`` unchanged.
            - "x_1": mean of each edge's two endpoint features.
            - "x_2": float32 mean of each ring's node features. When there are
              no rings it is an empty ``(0, *x.shape[1:])`` tensor.
        """
        x_0 = data.x

        # edges: average of the two endpoint features. Averaging needs a
        # floating dtype, so integer features are promoted to float32 (the
        # dtype already used for the ring features below).
        x_0_mean = x_0 if x_0.is_floating_point() else x_0.to(dtype=torch.float32)
        x_1 = x_0_mean[cc.nodes_00].mean(dim=0)

        # rings: average of the features of the ring's nodes
        ring_feats = [x_0[ring].to(dtype=torch.float32).mean(0) for ring in cc.cycle_lookup.values()]
        if ring_feats:
            x_2 = torch.stack(ring_feats)
        else:
            # follow the node-feature shape instead of a hard-coded feature count
            x_2 = torch.empty((0, *x_0.shape[1:]))

        return {"x_0": x_0, "x_1": x_1, "x_2": x_2}

    # The next functions are used to calculate the invariants for all the
    # different communications.

    @staticmethod
    def find_ring_length(positions, vertices):
        """
        Perimeter of the closed polygon through ``vertices``.

        Args:
            positions: coordinates aligned with ``vertices``, i.e.
                ``positions[k]`` is the position of ``vertices[k]``
                (e.g. ``pos[vertices]``, NOT the full position matrix).
            vertices: list of node ids in ring order.

        Returns:
            float: sum of consecutive distances including the closing segment.
        """
        vertex_to_position = dict(zip(vertices, positions))
        closed = vertices + vertices[:1]
        total_length = sum(
            (vertex_to_position[b] - vertex_to_position[a]).norm()
            for a, b in zip(closed, closed[1:])
        )
        return total_length.item()

    @staticmethod
    def find_area_triangle(positions):
        """Area of the triangle with the three corner points in ``positions``."""
        A, B, C = positions
        cross_product = np.cross(B - A, C - A)
        # |AB x AC| is the parallelogram area; the triangle is half of it
        return np.linalg.norm(cross_product) / 2

    @staticmethod
    def ritters_bounding_sphere(points):
        """
        Approximate bounding sphere of ``points`` (Ritter's algorithm).

        Returns:
            (center, radius)
        """
        # start from an arbitrary point and find the point P farthest from it
        distances = np.linalg.norm(points - points[0], axis=1)
        farthest_point = points[np.argmax(distances)]

        # find the point Q farthest from P
        distances = np.linalg.norm(points - farthest_point, axis=1)
        other_end_point = points[np.argmax(distances)]

        # initial sphere: centered at the midpoint of PQ, radius |PQ| / 2
        sphere_center = (farthest_point + other_end_point) / 2
        sphere_radius = np.linalg.norm(farthest_point - other_end_point) / 2

        # single pass: grow the sphere just enough to cover each outlier
        for point in points:
            distance = np.linalg.norm(point - sphere_center)
            if distance > sphere_radius:
                # new radius is the mean of the old radius and the distance;
                # move the center toward the point so the far side stays covered
                sphere_radius = (sphere_radius + distance) / 2
                sphere_center = sphere_center + (distance - sphere_radius) * (point - sphere_center) / distance

        return sphere_center, sphere_radius

    @staticmethod
    def calculate_angle(v1, v2):
        """Angle between vectors ``v1`` and ``v2`` in radians (cosine clipped to [-1, 1])."""
        cosine = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
        return np.arccos(np.clip(cosine, -1, 1))

    @staticmethod
    def find_best_planes(points):
        """
        Least-squares planes through ``points`` for three axis permutations.

        For each permutation ``(i, j, k)`` of the axes it fits
        ``p[k] = A * p[i] + B * p[j] + C``.

        Returns:
            list of ``[A, B, C]``, one per permutation in the order
            ``[0, 1, 2], [0, 2, 1], [1, 2, 0]``.
        """
        coeff_planes = []
        for i, j, k in [[0, 1, 2], [0, 2, 1], [1, 2, 0]]:
            design = np.c_[points[:, i], points[:, j], np.ones(points.shape[0])]
            # lstsq returns the coefficients in design-column order: slope_i, slope_j, intercept
            slope_i, slope_j, intercept = np.linalg.lstsq(design, points[:, k], rcond=None)[0]
            coeff_planes.append([slope_i, slope_j, intercept])
        return coeff_planes

    @staticmethod
    def dist_to_best_planes(point, planes):
        """
        Mean distance of ``point`` to the planes returned by ``find_best_planes``.

        Each plane ``p[k] = A * p[i] + B * p[j] + C`` is rewritten as
        ``A p[i] + B p[j] - p[k] + C = 0``, so the distance to it is
        ``|A p[i] + B p[j] - p[k] + C| / sqrt(A^2 + B^2 + 1)``.
        """
        axes = [[0, 1, 2], [0, 2, 1], [1, 2, 0]]
        distances = [
            np.abs(A * point[i] + B * point[j] - point[k] + C) / np.sqrt(A**2 + B**2 + 1)
            for (A, B, C), (i, j, k) in zip(planes, axes)
        ]
        return torch.tensor(distances).mean()

    @classmethod
    def take_inv(cls, graph: nx.Graph, data: Data, cc: CellularComplex):
        """
        Compute the geometric invariants attached to each adjacency.

        Args:
            graph: unused; kept for a signature parallel to ``take_feat``.
            data: provides the node positions ``pos``.
            cc: the lifted complex.

        Returns:
            dict of tensors, grouped by adjacency:
            - "dist_00": one entry per ``adj_00`` column.
            - "dist_01": one entry per ``adj_01`` column.
            - "dist_11", "angles" (degrees), "dist_to_mean_coord",
              "radius_per_edge": one entry per ``adj_11`` column.
            - "vol_12", "area_12": one entry per ``adj_12`` column.
            - Placeholders, always empty: "dist_to_sphere_center",
              "dist_to_best_planes", "radius_12".
            The ``adj_11``/``adj_12`` keys are also empty when there are no rings.
        """
        pos = data.pos
        inv_dict = {}

        # 0->0 and 0->1: edge length (per edge, then per node-edge incidence)
        send_00, rec_00 = cc.nodes_00
        dist_00 = torch.linalg.norm(pos[send_00] - pos[rec_00], dim=1)
        inv_dict["dist_00"] = dist_00
        inv_dict["dist_01"] = dist_00[cc.coboundary_01[1]]

        if len(cc.cycle_lookup) > 0:
            # 1->1: distance between the source nodes of the two edges
            send_11, rec_11 = cc.edges_11
            edge1_v = [cc.edge_lookup[i.item()][0] for i in send_11]
            edge2_v = [cc.edge_lookup[i.item()][0] for i in rec_11]
            inv_dict["dist_11"] = torch.linalg.norm(pos[edge1_v] - pos[edge2_v], dim=1)

            # per-ring geometry
            mean_coordinate_ring = []  # centre of mass of each ring
            sphere_radius = []  # max node distance from the centre of mass
            vol = []
            area = []
            for ring in cc.cycle_lookup.values():
                ring_pos = pos[torch.tensor(ring)]
                mean_coordinate = ring_pos.mean(dim=0)
                mean_coordinate_ring.append(mean_coordinate)

                if len(ring) > 3:
                    hull = ConvexHull(ring_pos, qhull_options='QJ')
                    vol.append(abs(hull.volume))
                    area.append(abs(hull.area))
                elif len(ring) > 2:
                    # three points span no volume; use floats so the dtype
                    # of vol_12 does not depend on the ring sizes
                    vol.append(0.0)
                    area.append(cls.find_area_triangle(ring_pos))
                else:
                    vol.append(0.0)
                    area.append(0.0)

                sphere_radius.append(torch.max(torch.norm(ring_pos - mean_coordinate, dim=1)))

            # 1->1 invariants: one entry per (edge1, edge2) pair. Ring-membership
            # sets are built once (same order as cycle_lookup).
            ring_edge_sets = [set(edges) for edges in cc.cycle_edges_bidirected_lookup.values()]
            angles = []
            dist_to_mean_coord = []
            radius_per_edge = []
            for edge1, edge2 in zip(send_11.tolist(), rec_11.tolist()):
                # pick the first ring that contains both edges
                ring_id = next(
                    k for k, ring_edges in enumerate(ring_edge_sets)
                    if edge1 in ring_edges and edge2 in ring_edges
                )

                edge1_vertices = cc.edge_lookup[edge1]
                edge2_vertices = cc.edge_lookup[edge2]

                dist_to_mean_coord.append(torch.norm(pos[edge1_vertices[0]] - mean_coordinate_ring[ring_id]))
                radius_per_edge.append(sphere_radius[ring_id])

                edge1_pos = pos[edge1_vertices]
                edge2_pos = pos[edge2_vertices]
                angle = cls.calculate_angle(edge1_pos[1] - edge1_pos[0], edge2_pos[1] - edge2_pos[0])
                angles.append(np.degrees(angle))

            inv_dict["angles"] = torch.tensor(angles)
            inv_dict["dist_to_sphere_center"] = torch.tensor([])  # placeholder, not computed
            inv_dict["dist_to_mean_coord"] = torch.tensor(dist_to_mean_coord)
            inv_dict["dist_to_best_planes"] = torch.tensor([])  # placeholder, not computed
            inv_dict["radius_per_edge"] = torch.tensor(radius_per_edge)

            # 1->2 invariants: one entry per (edge, ring) incidence
            ring_index = cc.coboundary_12[1].tolist()
            inv_dict["vol_12"] = torch.tensor([vol[r] for r in ring_index])
            inv_dict["area_12"] = torch.tensor([area[r] for r in ring_index])
            inv_dict["radius_12"] = torch.tensor([])  # placeholder, not computed

        # graphs without rings still get every key so batches stay uniform
        for k in ["dist_11", "angles", "dist_to_sphere_center", "dist_to_mean_coord",
                  "dist_to_best_planes", "radius_per_edge", "vol_12", "area_12", "radius_12"]:
            if k not in inv_dict:
                inv_dict[k] = torch.empty(0)
        return inv_dict

    @classmethod
    def from_data_cc_pair(cls, data: Data, cc: CellularComplex, x_dict: dict, inv_dict: dict):
        """
        Pack features, adjacencies and invariants into a new ``CellularComplexData``.

        Writes every entry of ``x_dict``, ``cc.adj`` and ``inv_dict`` onto
        ``data`` (mutating it). Adds "graph_len", "cycle_pres" and
        "cycle_total", and removes 'edge_attr', 'idx', 'name' and 'z'
        when present.
        """
        for k, v in x_dict.items():
            data[k] = v

        for k, v in cc.adj.items():
            data[k] = v

        for k, v in inv_dict.items():
            data[k] = v

        data["graph_len"] = data.x.size(0)
        data["cycle_pres"] = cc.cycle_present
        data["cycle_total"] = cc.cycle_total

        for att in ['edge_attr', 'idx', 'name', 'z']:
            if hasattr(data, att):
                data.pop(att)

        return cls(**data.to_dict())

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """Concatenate index tensors ('adj*', 'edge_index') along dim 1, everything else along dim 0."""
        if 'adj' in key or "edge_index" in key:
            return 1
        return 0

    def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any:
        """
        Increment index tensors when batching.

        For ``adj_ij`` the sender row is offset by the number of ``i``-cells
        (``x_i.size(0)``) and the receiver row by the number of ``j``-cells.
        This is why 0->0 communication is offset differently from 1->2.
        ``edge_index`` rows are both offset by the number of nodes.
        """
        if 'adj' in key:
            # key is "adj_ij": characters 4 and 5 are the source/target ranks
            i, j = key[4], key[5]
            return torch.tensor([[getattr(self, f'x_{i}').size(0)], [getattr(self, f'x_{j}').size(0)]])
        elif "edge_index" in key:
            return torch.tensor([[getattr(self, 'x_0').size(0)], [getattr(self, 'x_0').size(0)]])
        else:
            return super().__inc__(key, value, *args, **kwargs)



class LiftGraphToCC(BaseTransform):
    """
    Transform lifting a graph ``Data`` object to ``CellularComplexData``.

    The existing ``edge_index`` is merged with a radius graph (cutoff 2, in
    the units of ``data.pos``) when the radius graph has more edges. The
    resulting graph is then lifted to a cellular complex, and its features
    and invariants are packed into the returned object.
    """

    def __call__(self, data: Data) -> CellularComplexData:
        # RadiusGraph overwrites edge_index / edge_attr on the object it is
        # given and returns that same object. Run it on a copy: otherwise
        # `conn is data`, the comparison below is always False, and the
        # original edges are silently replaced by the radius graph.
        conn = RadiusGraph(2)(data.clone())
        if data.edge_index.size(1) < conn.edge_index.size(1):
            combined_edge_index = torch.cat((data.edge_index, conn.edge_index), dim=1)
            # drop duplicate directed edges, then self-loops
            data.edge_index, _ = remove_self_loops(torch.unique(combined_edge_index, dim=1))

        graph = nx.Graph()
        graph.add_nodes_from(range(data.x.size(0)))
        graph.add_edges_from(data.edge_index.T.tolist())

        cc = CellularComplex.from_nx_graph(graph, data)
        x_dict = CellularComplexData.take_feat(graph, data, cc)
        inv_dict = CellularComplexData.take_inv(graph, data, cc)
        return CellularComplexData.from_data_cc_pair(data, cc, x_dict, inv_dict)