PS/Binary Search Tree

백준 1539번: 이진 검색 트리 (JAVA)

닻과매 2022. 5. 3. 22:23

https://www.acmicpc.net/problem/1539

 

1539번: 이진 검색 트리

P는 크기가 N인 배열이다. P에는 0보다 크거나 같고, N-1보다 작거나 같은 정수가 중복 없이 채워져 있다. 이진 검색 트리는 루트가 있는 이진 트리로, 각각의 노드에 정수 값이 저장되어 있는 트

www.acmicpc.net

문제

P는 크기가 N인 배열이다. P에는 0보다 크거나 같고, N-1보다 작거나 같은 정수가 중복 없이 채워져 있다. 이진 검색 트리는 루트가 있는 이진 트리로, 각각의 노드에 정수 값이 저장되어 있는 트리이다. 이진 검색 트리를 P배열을 이용해서 만드는 법은 다음과 같다. 일단 root를 만들고 거기에 P[0]의 값을 넣은 후에 다음과 같은 과정을 거친다.

for (int i=1; i<=n-1; i++) {
    insert(root, P[i]);
}

여기서 insert함수는 다음과 같다.

void insert(Vertex V, int X) {
    if (x < V에 저장되어 있는 수) {
        if (V가 왼쪽 자식이 있으면) {
            insert(V의 왼쪽 자식, X);
        } else {
            V의 왼쪽 자식을 새로 만들고, 그 곳에 X를 저장함
        }
    } else {
        if (V가 오른쪽 자식이 있으면) {
            insert(V의 오른쪽 자식, X);
        } else {
            V의 오른쪽 자식을 새로 만들고, 그 곳에 X를 저장함
        }
    }
}

N과, 배열 P에 있는 수가 주어졌을 때, P로 이진 검색 트리를 만들었을 때, 모든 노드의 높이의 합을 출력하는 프로그램을 작성하시오. 트리의 높이는 루트에서 부터의 거리 + 1이다.

 

입력

첫째 줄에 N이 주어진다. N은 250,000보다 작거나 같은 자연수이다. 둘째 줄부터 N개의 줄에 P[0]부터 P[N-1]의 원소가 한 줄에 하나씩 들어온다.

출력

주어진 P배열로 이진 검색 트리를 만들었을 때, 높이의 합을 출력한다. 이 값은 2^63보다 작다.

 


 

풀이

 

cf) 안 되는 풀이: 브루트포스

가장 먼저 떠오르는 방법이라면, 일일히 다 세는 방법일 것이다. 하지만, 이진 검색 트리이지만 균형을 맞추는 과정이 없기에 최악의 경우 O(N^2) 시간복잡도를 가진다. 따라서 시간초과가 뜬다.

 

 

되는 풀이

한 마디로, (현재 원소가 갖게 되는 높이) = max((현재 tree에 들어있는 원소 중 현재 원소보다 작으면서 가장 큰 원소가 갖는 높이), (현재 tree에 들어있는 원소 중 현재 원소보다 크면서 가장 작은 원소가 갖는 높이)) + 1이 된다. 만약 현재 원소보다 크거나/작은 원소가 없다면 위 식에서 0이라고 생각하면 된다(실제 구현에서는 null이 나오기에 처리 필요).

 

상세한 설명은

https://m.blog.naver.com/PostView.naver?isHttpsRedirect=true&blogId=occidere&logNo=221133866451

참고.

 

 

코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.TreeSet;

public class Main {

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());
        // 노드별로 높이를 기록할 배열
        int[] len = new int[N];
        long ans = 0;
        TreeSet<Integer> tree = new TreeSet<>();

        for (int i = 0; i < N; i++) {
            int num = Integer.parseInt(br.readLine());
            if (tree.higher(num) == null) {
                if (tree.lower(num) == null) {
                    len[num] = 1;
                } else {
                    len[num] = len[tree.lower(num)]+1;
                }
            } else {
                if (tree.lower(num) == null) {
                    len[num] = len[tree.higher(num)] + 1;
                } else {
                    len[num] = Math.max(len[tree.higher(num)], len[tree.lower(num)]) + 1;
                }
            }
            ans += len[num];
            tree.add(num);
        }

        System.out.println(ans);
    }

}