prim

 avatar
unknown
plain_text
2 years ago
3.9 kB
4
Indexable
import java.io.*;
import java.util.*;
import java.util.stream.*;
import static java.util.stream.Collectors.toList;

class Pair implements Comparable<Pair> {
    public int vertex;
    public int weight;

    Pair(int vertex, int weight) {
        this.vertex = vertex;
        this.weight = weight;
    }
    
    @Override
    public int compareTo(Pair b) {
        return (b.weight < this.weight) ? 1 : -1;
    }
}

/**
 * Undirected weighted graph stored as adjacency lists, with a Prim's-algorithm MST routine.
 * Vertices are identified by 0-based indices in the order they were added.
 */
class Graph {
    /** {@code map.get(u)} holds one {@link Pair} per edge incident to vertex {@code u}. */
    private List<List<Pair>> map = new ArrayList<>();

    /**
     * Appends {@code n} new vertices with no edges; their indices follow any existing ones.
     *
     * @param n number of vertices to add
     */
    public void addVertices(Integer n) {
        for (int i = 0; i < n; i++) {
            map.add(new ArrayList<>());
        }
    }

    /**
     * Adds an undirected edge by recording it in both endpoints' adjacency lists.
     *
     * @param v1 0-based index of one endpoint (must already exist)
     * @param v2 0-based index of the other endpoint (must already exist)
     * @param weight edge weight
     */
    public void addEdge(Integer v1, Integer v2, int weight) {
        map.get(v1).add(new Pair(v2, weight));
        map.get(v2).add(new Pair(v1, weight));
    }

    /**
     * Computes the total weight of a minimum spanning tree using priority-queue Prim's algorithm,
     * starting from vertex {@code v}. Only vertices reachable from {@code v} are included.
     *
     * @param v 0-based start vertex
     * @param n number of vertices; expected to match the number added via {@link #addVertices}
     * @return sum of the weights of the edges chosen for the spanning tree
     * @throws ArithmeticException if that sum does not fit in an {@code int}
     */
    public int prim(int v, int n) {
        // Order by edge weight explicitly so correctness does not depend on Pair.compareTo.
        PriorityQueue<Pair> queue =
                new PriorityQueue<>(Comparator.comparingInt((Pair p) -> p.weight));
        boolean[] inMST = new boolean[n];
        // Cheapest known edge into each vertex. long + Long.MAX_VALUE as "infinity" so that an
        // edge of weight Integer.MAX_VALUE still compares as an improvement and is not skipped.
        long[] keys = new long[n];
        Arrays.fill(keys, Long.MAX_VALUE);
        // Accumulate in long; overflow is reported on return instead of silently wrapping.
        long total = 0;

        queue.add(new Pair(v, 0));
        keys[v] = 0;

        while (!queue.isEmpty()) {
            Pair cur = queue.poll();

            // A vertex may be queued several times; only its first (cheapest) pop counts.
            if (inMST[cur.vertex]) {
                continue;
            }

            inMST[cur.vertex] = true;
            total += cur.weight;

            // Queue every neighbour not yet in the tree that this vertex reaches more cheaply.
            for (Pair adj : map.get(cur.vertex)) {
                if (!inMST[adj.vertex] && adj.weight < keys[adj.vertex]) {
                    keys[adj.vertex] = adj.weight;
                    queue.add(adj);
                }
            }
        }

        return Math.toIntExact(total);
    }
}

class Result {
    /**
     * Builds an undirected graph from 1-based edge triples and returns the weight of its
     * minimum spanning tree as computed by {@link Graph#prim}.
     *
     * @param n     number of vertices, labelled 1..n in {@code edges}
     * @param edges edges as {@code [from, to, weight]} lists with 1-based endpoints
     * @param s     1-based start vertex
     * @return total spanning-tree weight reported by {@code Graph.prim}
     */
    public static int prims(int n, List<List<Integer>> edges, int s) {
        Graph g = new Graph();
        g.addVertices(n);

        // Callers use 1-based labels; Graph works with 0-based indices.
        edges.forEach(e -> g.addEdge(e.get(0) - 1, e.get(1) - 1, e.get(2)));

        int startIndex = s - 1;
        return g.prim(startIndex, n);
    }
}

public class Solution {
    public static void main(String[] args) throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

        String[] firstMultipleInput = bufferedReader.readLine().replaceAll("\\s+$", "").split(" ");

        int n = Integer.parseInt(firstMultipleInput[0]);

        int m = Integer.parseInt(firstMultipleInput[1]);

        List<List<Integer>> edges = new ArrayList<>();

        IntStream.range(0, m).forEach(i -> {
            try {
                edges.add(
                        Stream.of(bufferedReader.readLine().replaceAll("\\s+$", "").split(" "))
                                .map(Integer::parseInt)
                                .collect(toList()));
            } catch (IOException ex) {
                throw new RuntimeException(ex);
            }
        });

        int start = Integer.parseInt(bufferedReader.readLine().trim());

        int result = Result.prims(n, edges, start);

        bufferedWriter.write(String.valueOf(result));
        bufferedWriter.newLine();

        bufferedReader.close();
        bufferedWriter.close();
    }
}
Editor is loading...