PS/Dijkstra

백준 1238번: 파티 (JAVA)

닻과매 2022. 2. 24. 21:01

https://www.acmicpc.net/problem/1238

 

1238번: 파티

첫째 줄에 N(1 ≤ N ≤ 1,000), M(1 ≤ M ≤ 10,000), X가 공백으로 구분되어 입력된다. 두 번째 줄부터 M+1번째 줄까지 i번째 도로의 시작점, 끝점, 그리고 이 도로를 지나는데 필요한 소요시간 Ti가 들어

www.acmicpc.net

문제

N개의 숫자로 구분된 각각의 마을에 한 명의 학생이 살고 있다.

어느 날 이 N명의 학생이 X (1 ≤ X ≤ N)번 마을에 모여서 파티를 벌이기로 했다. 이 마을 사이에는 총 M개의 단방향 도로들이 있고 i번째 길을 지나는데 Ti(1 ≤ Ti ≤ 100)의 시간을 소비한다.

각각의 학생들은 파티에 참석하기 위해 걸어가서 다시 그들의 마을로 돌아와야 한다. 하지만 이 학생들은 워낙 게을러서 최단 시간에 오고 가기를 원한다.

이 도로들은 단방향이기 때문에 아마 그들이 오고 가는 길이 다를지도 모른다. N명의 학생들 중 오고 가는데 가장 많은 시간을 소비하는 학생은 누구일지 구하여라.

입력

첫째 줄에 N(1 ≤ N ≤ 1,000), M(1 ≤ M ≤ 10,000), X가 공백으로 구분되어 입력된다. 두 번째 줄부터 M+1번째 줄까지 i번째 도로의 시작점, 끝점, 그리고 이 도로를 지나는데 필요한 소요시간 Ti가 들어온다. 시작점과 끝점이 같은 도로는 없으며, 시작점과 한 도시 A에서 다른 도시 B로 가는 도로의 개수는 최대 1개이다.

모든 학생들은 집에서 X에 갈수 있고, X에서 집으로 돌아올 수 있는 데이터만 입력으로 주어진다.

출력

첫 번째 줄에 N명의 학생들 중 오고 가는데 가장 오래 걸리는 학생의 소요시간을 출력한다.

 


 

풀이

처음 생각에는 X를 제외한 1부터 N까지의 점 N-1개를 다익스트라를 돌려 N까지 가는 최솟값을 구하고, 반대로 X에서 다익스트라를 돌려 X에서 다른 정점까지 도착하는 최솟값을 구하여 둘의 합이 최대가 되는 지점을 찾을까 했다. 그런데 이 알고리즘은 대략 O(N*MlogM)이고, 1억이 넘을 듯하며, '1번의 연산'이 우선순위 큐에 넣고 빼고하는 작업이기에 무조건 시간 초과가 날 거라고 생각했다.

그래서 생각한 것이, '1→X는 시작점과 끝점이 뒤집힌 그래프에서 X→1이랑 같지 않나?'였다. 백준 질문 게시판에서 아이디어에 대한 확신을 얻고(이러면 안되는데ㅋㅋ) 구현하였다.

 

내 실수

dist2 코드 복붙하면서 dist1으로 안 바꾼 곳이 있었다. 복붙하면서 '제대로 다 바꿔야 한다!'라고 생각까지 해놓고 틀렸다. 시간 초과보다 우선순위 큐에 자료 쌓이던 게 먼저 터져서 메모리 초과가 떠서, 'JAVA 얘 메모리 많이 잡아먹더니 될 문제도 안 되는구나'라는 쪽에 생각이 포커스가 가버렸다. 실제 기출 테스트케이스 넣어보면서 무한루프를 발견하고, 코드를 자세히 보면서 틀린 지점을 찾았다. 금붕어인가 싶다.

 

 

코드

더보기
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.StringTokenizer;

public class Main {
    static final int INF = 100_000_000; // 2*M*T보다 큰 적당한 수

    public static void main(String[] args) throws IOException{
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());
        int N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int X = Integer.parseInt(st.nextToken());

        ArrayList<ArrayList<int[]>> graph1 = new ArrayList<>();
        ArrayList<ArrayList<int[]>> graph2 = new ArrayList<>();
        for (int i = 0; i <= N; i++) {
            graph1.add(new ArrayList<>());
            graph2.add(new ArrayList<>());
        }

        int[] dist1 = new int[N+1];
        int[] dist2 = new int[N+1];
        Arrays.fill(dist1, INF);
        Arrays.fill(dist2, INF);

        for (int i = 0; i < M; i++) {
            st = new StringTokenizer(br.readLine());
            int from = Integer.parseInt(st.nextToken());
            int to = Integer.parseInt(st.nextToken());
            int weight = Integer.parseInt(st.nextToken());
            graph1.get(from).add(new int[] {to, weight});
            graph2.get(to).add(new int[] {from, weight});
        }

        PriorityQueue<int[]> pq = new PriorityQueue<>((o1, o2) -> {
            return Integer.compare(o1[1], o2[1]);
        });

        int count = 0;
        pq.offer(new int[] {X, 0});
        dist1[X] = 0;
        while (!pq.isEmpty()) {
            int[] temp = pq.poll();
            int from = temp[0];
            int weight = temp[1];
            if (dist1[from] != weight) continue;
            for (int[] next: graph1.get(from)) {
                if (dist1[next[0]] <= next[1] + dist1[from]) continue;
                dist1[next[0]] = next[1] + dist1[from];
                pq.offer(new int[] {next[0], dist1[next[0]]});
            }
        }

        pq.offer(new int[] {X, 0});
        dist2[X] = 0;
        while (!pq.isEmpty()) {
            int[] temp = pq.poll();
            int from = temp[0];
            int weight = temp[1];
            if (dist2[from] != weight) continue;
            for (int[] next: graph2.get(from)) {
                if (dist2[next[0]] <= next[1] + dist2[from]) continue;
                dist2[next[0]] = next[1] + dist2[from];
                pq.offer(new int[] {next[0], dist2[next[0]]});
            }
        }

        int ans = 0;
        for (int i = 1; i <= N; i++) {
            ans = Math.max(ans, dist1[i] + dist2[i]);
        }
        System.out.println(ans);
    }

}