PS/Segment Tree

세그먼트 트리 기본 개념 with 백준 2042번: 구간 합 구하기 (Java)

닻과매 2022. 6. 21. 21:34

https://www.acmicpc.net/problem/2042

 

2042번: 구간 합 구하기

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄

www.acmicpc.net

문제

어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

 


 

코테를 생각한다면, 세그먼트 트리 공부할 시간은 아니긴 하다. 다만, PS에서 워낙 많이 쓰는 자료구조다보니 궁금해서 공부해보았다.

 

풀이

대략적인 개요, 구현 방법, 사용 방법 등은 바킹독님 PPT 자료로 대체한다(또한, 크로커즈님 블로그 설명도 매우 자세하여 공부하기 좋다.). 위 PPT에서 팬윅 트리는 건너뛰고, segment tree의 개념, add, update 함수만 보면 된다.

 

세부적인 부분만 살짝 부연설명하자면,

0) segment tree는 완전 이진 트리로 구현이 되기에, Segment Tree class를 직접 구현하거나 하는 대신 배열로 선언하여, 루트 노드의 index를 1, 왼쪽 자식 노드로 갈 때는 2*idx, 오른쪽 자식 노드로 갈 때는 2*idx+1하여 찾는다.

 

1) segment tree 배열의 크기: N개의 노드를 저장하는 segment tree의 height는 Math.ceil(log(N)/log2)가 되며, 이 때 총 노드는 2^(height+1)개가 된다. 이를 구해도 되며, 그냥 4*N으로 널널하게 잡아도 된다.

cf) 문제의 예시를 보다보면 '2*N'으로 잡아도 될 거 같은데, 2*N으로 잡으면 안 된다. -> 반례: N==6인 경우를 생각해보자.

 

2) 처음 배열을 만드는 과정은 바킹독님 말에 따르면 'prefix sum으로 구하면 O(N), update로 구하면 O(NlogN)인데 사실 PS에서 O(N)과 O(NlogN) 차이가 그리 크지 않으니 이미 구현한 update를 써라'고 한다. 그래야겠다.

 

3) 바킹독님 PPT에서 update함수 구현이 살짝 틀렸다. seg[idx]에 diff를 더하는 과정이 if문 전에 있어야 할 듯...? (조심스러움)

 

 

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import java.io.*;
import java.util.*;
 
public class Main {
    
    static long[] seg;
    
    public static void main(String[] args) throws IOException {
        StringBuilder sb = new StringBuilder();
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());
        // 1부터 시작, 크기는 여유있게 4*N
        seg = new long[4*N];
        long[] arr = new long[N+1];
        
        // init
        for (int i = 1; i <= N; i++) {
            arr[i] = Long.parseLong(br.readLine());
            update(11, N, i, arr[i]);
        }
        
        for (int i = 0; i < K+M; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());
            if (a == 1) {
                update(11, N, b, c-arr[b]);
                arr[b] = c;
            }
            else sb.append(sum(11, N, b, c)+"\n");
        }
        System.out.println(sb.toString());
    }
    
    // 현재 idx는 st부터 en까지의 합을 담고 있으며, i번째 원소를 diff만큼 더하고 싶다
    static void update(int idx, int st, int en, int i, long diff) {
        seg[idx] += diff;
        if (st == en) return;
 
        // 홀수일 경우 왼쪽 노드에 더 붙도록 구현됨
        if (i <= (st+en)/2) update(2*idx, st, (st+en)/2, i, diff);
        else update(2*idx+1, (st+en)/2 + 1, en, i, diff);
    }
    
    // 현재 idx는 st부터 en까지의 합을 담고 있으며, l부터 r까지의 합을 구하고 싶다.
    static long sum(int idx, int st, int en, int l, long r) {
        // 더하는 구간을 벗어났을 경우
        if (l > en || r < st) return 0;
        // 더하는 구간이 현재 구간에 완전 포함되는 경우
        if (l <= st && en <= r) return seg[idx];
        // 그 외의 경우: 일단 반으로 쪼개기
        return sum(2*idx, st, (st+en)/2, l, r) + sum(2*idx+1, (st+en)/2+1, en, l, r);
    }
}
 
cs