PS/Binary Search

백준 7453번: 합이 0인 네 정수 (JAVA)

닻과매 2022. 4. 13. 22:45

https://www.acmicpc.net/problem/7453

 

7453번: 합이 0인 네 정수

첫째 줄에 배열의 크기 n (1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄에는 A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 228이다.

www.acmicpc.net

문제

정수로 이루어진 크기가 같은 배열 A, B, C, D가 있다.

A[a], B[b], C[c], D[d]의 합이 0인 (a, b, c, d) 쌍의 개수를 구하는 프로그램을 작성하시오.

입력

첫째 줄에 배열의 크기 n (1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄에는 A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 228이다.

 

출력

합이 0이 되는 쌍의 개수를 출력한다.

 


 

풀이

기본적으로 모든 풀이가, A의 원소와 B의 원소의 합, 그리고 C의 원소와 D의 원소의 합을 따로 구해놓는 데에서부터 시작한다. 이후, -(A와 B의 합)이 C와 D의 원소의 합에 몇 개 있는지 세면 된다. 세는 방법에서 차이가 난다.

 

 

풀이 1. Map을 이용한 풀이 (O(N^2))

가장 Big-O time complexity가 짧은 풀이이다. C의 원소와 D의 원소의 합을 Counter 방식으로 기록하여, -(A원소와 B 원소의 합)이 Counter에 몇 개 있는지 찾는다. HashMap을 이용하였기에 원소를 찾는 작업이 O(1)일거라 생각했으나, 'map은 원래 많이 느리다고 한다.' 또한, 'hash function을 안다면 데이터셋을 저격하여 생성함으로써 search의 time complexity를 O(N)이 되게할 수 있다고 한다'. 그래서, 해보면 시간초과 뜬다: 예전에 한 번 Map 풀이를 저격한다고 테스트케이스를 추가했다고 한다.

참고 링크: (https://www.acmicpc.net/board/view/76438)

 

풀이 2. 이분탐색을 이용 (O(N^2logN))

C의 원소와 D의 원소의 합을 모은 배열 CD를 만들고, 정렬한다. -(A의 원소와 B의 원소의 합)이 CD에 몇 개 있는지 lower bound, upper bound를 통해 구한다. C++에는 기본으로 있다던데, JAVA는 없으니 구현하도록 하자.

다만 이 풀이는 이 문제를 풀기에 시간이 꽤 빡빡하다. 내 풀이는 11초 언저리가 걸렸는데, 이 정도의 풀이는 서버 상태에 따라 TLE이 뜨기도, 성공하기도 한다. 어떤 분은 똑같이 binary search를 이용하여 6초가 나왔는데, 내 코드에서 하나씩 바꿔가면서 그 분 코드랑 비슷하게 바꿔보는데도 시간이 안 맞더라. 

 

풀이 3. 투 포인터를 이용 (O(N^2logN))

A의 원소와 B의 원소의 합을 모은 배열 AB, C의 원소와 D의 원소의 합을 모은 배열 CD를 만들고, 정렬한다. 

left = 0, right = N*2 - 1에서 시작하여, AB[left] + CD[right]의 합에 따라 투 포인터 식으로 처리한다.

특이할 점으로는, 값이 0이 될 때, left에 해당 원소가 몇 개 있는지 세고('leftCount'), right에 해당 원소가 몇 개 있는지 센 후('rightCount'), 두 원소를 곱하여 정답에 더한다.

 

이렇게, 동등한 두 집합으로 나누어 시간복잡도를 '대략' root씌운 정도로 줄이는 알고리즘(본 문제에서는 브루트포스 O(N^4) -> O(N^2logN))을 'meet in the middle' 알고리즘이라고 부른다고 한다.

정답은 최대 16000000^2까지 가능하다(모든 배열이 0 4000개로 되어있다고 생각해보자.). 따라서 정답은 long으로 저장해야한다.

 

 

코드

풀이 1. HashMap(TLE)

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.HashMap;
import java.util.Map;
import java.util.StringTokenizer;

public class Main {

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int N = Integer.parseInt(br.readLine());
        int[] A = new int[N], B = new int[N], C = new int[N], D = new int[N];
        for (int i = 0; i < N; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            A[i] = Integer.parseInt(st.nextToken());
            B[i] = Integer.parseInt(st.nextToken());
            C[i] = Integer.parseInt(st.nextToken());
            D[i] = Integer.parseInt(st.nextToken());
        }

        Map<Integer, Integer> CD = new HashMap<>();
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (!CD.containsKey(C[i] + D[j])) {
                    CD.put(C[i] + D[j], 1);
                } else {
                    CD.put(C[i] + D[j], CD.get(C[i] + D[j])+1);
                }
            }
        }

        long ans = 0;
        for (int i = 0; i < N; i++) {
            for (int j = 0; j < N; j++) {
                if (CD.containsKey(-A[i]-B[j])) {
                    ans += CD.get(-A[i]-B[j]);
                }
            }
        }
        System.out.println(ans);
    }

}

 

풀이 2. 이분 탐색

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

public class Main {

	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		int N = Integer.parseInt(br.readLine());
		int[] A = new int[N], B = new int[N], C = new int[N], D = new int[N];
		for (int i = 0; i < N; i++) {
			StringTokenizer st = new StringTokenizer(br.readLine());
			A[i] = Integer.parseInt(st.nextToken());
			B[i] = Integer.parseInt(st.nextToken());
			C[i] = Integer.parseInt(st.nextToken());
			D[i] = Integer.parseInt(st.nextToken());
		}
		
		int[] CD = new int[N*N];
		int idx = 0;
		for (int i = 0; i < N; i++) {
			for (int j = 0; j < N; j++) {
				CD[idx++] = C[i] + D[j];
			}
		}
		
		Arrays.sort(CD);
		
		
		
		long ans = 0;
		for (int i = 0; i < N; i++) {
			for (int j = 0; j < N; j++) {
				int temp = A[i] + B[j];
				int upper = upperBound(-temp, CD);
				int lower = lowerBound(-temp, CD);
				ans += (upper - lower);
			}
		}
		System.out.println(ans);
	}
	
	
	static int upperBound(int key, int[] arr) {
		int start = 0, end = arr.length-1;
		while (start <= end) {
			int mid = (start + end)/2;
			if (arr[mid] > key) {
				end = mid - 1;
			} else {
				start = mid + 1;
			}
 		}
		return end;
	}
	
	static int lowerBound(int key, int[] arr) {
		int start = 0, end = arr.length-1;
		while (start <= end) {
			int mid = (start + end)/2;
			if (arr[mid] >= key) {
				end = mid - 1;
			} else {
				start = mid + 1;
			}
 		}
		return end;
	}
}

 

풀이 3. 투 포인터

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

public class Main {
	static int[] AB, CD;

	public static void main(String[] args) throws IOException {
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		int N = Integer.parseInt(br.readLine());
		int[] A = new int[N], B = new int[N], C = new int[N], D = new int[N];
		for (int i = 0; i < N; i++) {
			StringTokenizer st = new StringTokenizer(br.readLine());
			A[i] = Integer.parseInt(st.nextToken());
			B[i] = Integer.parseInt(st.nextToken());
			C[i] = Integer.parseInt(st.nextToken());
			D[i] = Integer.parseInt(st.nextToken());
		}
		
		AB = new int[N*N];
		CD = new int[N*N];
		int idx = 0;
		for (int i = 0; i < N; i++) {
			for (int j = 0; j < N; j++) {
				AB[idx] = A[i] + B[j];
				CD[idx++] = C[i] + D[j];
			}
		}
		
		Arrays.sort(AB);
		Arrays.sort(CD);
		
		long ans = 0;
		int left = 0, right = N*N-1;
		while (left < N*N && right >= 0) {
			if (AB[left] + CD[right] < 0) {
				left++;
			} else if (AB[left] + CD[right] > 0) {
				right--;
			} else {
				long leftCount = 1, rightCount = 1;
				while (left + 1 < N*N && (AB[left] == AB[left+1])) {
					leftCount++;
					left++;
				}
				while (right > 0 && (CD[right] == CD[right-1])) {
					rightCount++;
					right--;
				}
				ans += leftCount * rightCount;
				left++;
			}
		}
		
		System.out.println(ans);
	}
	
}

 

결과

HashMap을 이용한 풀이(=맨 아래 코드)는 '시간 초과 (5%)' 가 뜬다.

이분탐색을 이용한 풀이(=1622B 언저리의 코드)는 상황에 따라 '시간 초과'랑 '맞았습니다(11초 이상)'을 왔다갔다한다.

투 포인터를 이용한 풀이(=맨 위 코드)는 대략 4초가 걸린다.

 

제출 내역만 봐도 알 수 있듯이, 진짜 줄이려고 노력 많이 했다.. 같은 이분탐색인데, 저기 6124ms 뜨는 코드도 있는데 내 코드는 왜 이러지;;