PS/Tree

백준 1167번: 트리의 지름 (Java)

닻과매 2022. 6. 10. 17:34

https://www.acmicpc.net/problem/1167

 

1167번: 트리의 지름

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지

www.acmicpc.net

문제

트리의 지름이란, 트리에서 임의의 두 점 사이의 거리 중 가장 긴 것을 말한다. 트리의 지름을 구하는 프로그램을 작성하시오.

입력

트리가 입력으로 주어진다. 먼저 첫 번째 줄에서는 트리의 정점의 개수 V가 주어지고 (2 ≤ V ≤ 100,000)둘째 줄부터 V개의 줄에 걸쳐 간선의 정보가 다음과 같이 주어진다. 정점 번호는 1부터 V까지 매겨져 있다.

먼저 정점 번호가 주어지고, 이어서 연결된 간선의 정보를 의미하는 정수가 두 개씩 주어지는데, 하나는 정점번호, 다른 하나는 그 정점까지의 거리이다. 예를 들어 네 번째 줄의 경우 정점 3은 정점 1과 거리가 2인 간선으로 연결되어 있고, 정점 4와는 거리가 3인 간선으로 연결되어 있는 것을 보여준다. 각 줄의 마지막에는 -1이 입력으로 주어진다. 주어지는 거리는 모두 10,000 이하의 자연수이다.

출력

첫째 줄에 트리의 지름을 출력한다.

 


 

풀이

풀이 1. 내 풀이

모든 node에 대해 순회하면서 '해당 node를 root로 하는 subtree에서의 지름'을 구한다. 만약 해당 node의 자식이 0개이면 0일 것이고, 1이면 leaf node까지의 거리, 2 이상이면 '가장 긴 거리' + '두 번째로 긴 거리'가 될 것이다. 이 과정을 dfs 재귀 한 번에 볼 수 있다. 일반적으로 트리의 지름을 이렇게 구하진 않는 거 같던데, 내가 스스로 생각해낸 풀이라 그런지 마음에 든다. 다만, 특정 샘플에 대해서는 (상위 2개의 원소를 구하기 위해) 정렬하는 과정이 시간을 많이 잡아먹을지도?

 

풀이 2. 트리의 지름을 구하는 공식

임의의 node를 선택하여, 해당 node에서 가장 먼 node x를 선택한다. 그리고 x에서 가장 먼 점 y를 선택하면, 두 점 사이의 거리 d(x, y)가 트리의 지름이 된다.

증명: https://blog.myungwoo.kr/112를 참고. 매우 깔끔한 증명이다.

다만, 증명 과정 중 iii - a의 경우, 조금 자세히 설명을 적자면

d(t, y) > max(d(t, u), d(t, v))이면 모순

d(t, y) = max(d(t, u), d(t, v))인 경우

 1) d(t, u) = d(t, v)이면 d(y, z) = d(u, v)가 되면서 성립

 2) d(t, u) != d(t, v)이면 모순

으로 증명하면 된다.

 

코드

풀이 1

import java.io.*;
import java.util.*;

public class Main {
    static int N, ans = 0;
    static List<List<Node>> tree;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        N = Integer.parseInt(br.readLine());
        tree = new ArrayList<>();
        for (int i = 0; i <= N; i++) {
            tree.add(new ArrayList<Node>());
        }

        for (int i = 1; i <= N; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v;
            while ((v=Integer.parseInt(st.nextToken())) != -1) {
                int dist = Integer.parseInt(st.nextToken());
                tree.get(u).add(new Node(v, dist));
            }
        }

        dfs(1, 0);
        System.out.println(ans);
    }

    // dfs로 순회하면서 각 점을 지나는 가장 긴 경로의 길이를 구하는 method
    static int dfs(int cur, int prev) {
        int ret = 0;
        int temp = 0;
        List<Integer> list = new ArrayList<>();
        for (Node nxt: tree.get(cur)) {
            if (nxt.num == prev) continue;
            list.add(nxt.dist+dfs(nxt.num, cur));
        }
        Collections.sort(list);
        if (list.size() > 0) {
            ret = list.get(list.size()-1);
            temp = ret;
        } if (list.size() > 1) {
            temp += list.get(list.size()-2);
        }
        if (temp > ans) ans = temp;
        return ret;
    }

    // 점의 번호와 거리를 저장하는 class
    static class Node {
        int num;
        int dist;

        public Node(int num, int dist) {
            this.num = num;
            this.dist = dist;
        }
    }
}

 

풀이 2

import java.io.*;
import java.util.*;

public class Main {
    static BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
    static StringTokenizer st;
    static int n;
    static List<List<Node>> tree = new ArrayList<>();
    static boolean[] visited;
    static int val, maxDist = Integer.MIN_VALUE;

    public static void main(String[] args) throws IOException {
        n = Integer.parseInt(br.readLine());
        for (int i = 0; i <= n; i++) {
            tree.add(new ArrayList<>());
        }
        for (int i = 1; i <= n; i++) {
            st = new StringTokenizer(br.readLine());
            int start = Integer.parseInt(st.nextToken()), end;
            while ((end = Integer.parseInt(st.nextToken())) != -1) {
                tree.get(start).add(new Node(end, Integer.parseInt(st.nextToken())));
            }
        }
        visited = new boolean[n + 1];
        dfs(1, 0);
        visited = new boolean[n + 1];
        maxDist = Integer.MIN_VALUE;
        dfs(val, 0);
        System.out.println(maxDist);
    }

    private static void dfs(int cur, int dist) {
        visited[cur] = true;
        if (dist > maxDist) {
            val = cur;
            maxDist = dist;
        }
        for (Node nxt : tree.get(cur)) {
            if (nxt.dist != 0 && !visited[nxt.num]) {
                dfs(nxt.num, dist + nxt.dist);
            }
        }
    }

    static class Node {
        int num;
        int dist;

        public Node(int num, int dist) {
            this.num = num;
            this.dist = dist;
        }
    }
}