PS/DP

백준 1149번: RGB거리 (Python)

닻과매 2021. 10. 2. 10:43

문제

RGB거리에는 집이 N개 있다. 거리는 선분으로 나타낼 수 있고, 1번 집부터 N번 집이 순서대로 있다.

집은 빨강, 초록, 파랑 중 하나의 색으로 칠해야 한다. 각각의 집을 빨강, 초록, 파랑으로 칠하는 비용이 주어졌을 때, 아래 규칙을 만족하면서 모든 집을 칠하는 비용의 최솟값을 구해보자.

  • 1번 집의 색은 2번 집의 색과 같지 않아야 한다.
  • N번 집의 색은 N-1번 집의 색과 같지 않아야 한다.
  • i(2 ≤ i ≤ N-1)번 집의 색은 i-1번, i+1번 집의 색과 같지 않아야 한다.

입력

첫째 줄에 집의 수 N(2 ≤ N ≤ 1,000)이 주어진다. 둘째 줄부터 N개의 줄에는 각 집을 빨강, 초록, 파랑으로 칠하는 비용이 1번 집부터 한 줄에 하나씩 주어진다. 집을 칠하는 비용은 1,000보다 작거나 같은 자연수이다.

출력

첫째 줄에 모든 집을 칠하는 비용의 최솟값을 출력한다.

 


 

문제에서는 1번째부터 N번째 집이라고 부르지만, index랑 맞추기 위해서 0번째부터 N-1번째 집이라고 부르자(혹은 3*(N+1) 사이즈로 행렬을 2개 만들고, j=0인 index를 사용하지 않아도 된다.).

다음과 같은 3*N 행렬 2개를 정의하자:

 

  1. price[i][j]: j번째 집을 i번째 색으로 칠하는 비용(i=0: 'red', i=1: 'green', i=2: 'blue'라 하자)
  2. dp[i][j]: j번째 집을 i번째 색으로 칠했을 때, j+1번째 집부터 N-1번째 집까지 칠하는 최소 비용

 

그러면,

  1.  dp[i][N-1]: 'N-1번째 집을 i번째로 칠한 후, N번째 집을...' 집은 N-1번까지 있으니, 칠할 집이 없다. 정의상 0이 된다.
  2.  j가 N-1보다 작은 경우: j번째 집을 0(=red) 색으로 칠한 경우, j+1번째 집은 0번째 색으로 칠할 수 없으며, 1 or 2 색으로만 칠해야한다. 그러면, dp[0][j] = min(dp[1][j+1] + price[1][j+1], dp[2][j+1] + price[2][j+1])이 된다. dp[1][j], dp[2][j]의 경우도 마찬가지로 식을 세울 수 있다. 이렇게 j=0인 경우까지 최소값을 구할 수 있다.
  3. 전체를 칠하는 최소값은
    1. 처음을 빨간색으로 칠하는 비용 + 처음을 빨간색으로 칠한 경우 나머지를 칠하는 최소 비용 
    2. 처음을 초록색으로 칠하는 비용 + 처음을 초록색으로 칠한 경우 나머지를 칠하는 최소 비용 
    3. 처음을 파란색으로 칠하는 비용 + 처음을 파란색으로 칠한 경우 나머지를 칠하는 최소 비용

      중 최소값이 된다.

 


 

개인적인 피드백

  1. Top-down 방식으로 코드를 짜니 RecursionError이 발생했다. 이 경우, sys.setrecursionlimit(10000)를 통해 recursion depth를 늘려주자. 다만 알고리즘 자체가 시간 내에 굴러갈 수 있다는 확신이 있을때만 해야하지 않을까..싶다.
  2. '1번째 집부터 N번째 집'이라는 표현은 python의 list index랑 안 맞기에, 머리 속으로 '0번째 집부터 N-1번째 집...'이라 생각하든, N+1 길이의 list를 만들고 i번째 집이랑 i번째 index랑 대응시키든 하자. 안 그러니 머리 속이 너무 복잡해지더라.
  3. 단어 드래그 하고 Ctrl + Shift + L 누르면 해당 단어 한 번에 수정 가능(vscode).

 

Bottom-up 방식

import sys

N = int(sys.stdin.readline())
price = [[],[],[]]
for j in range(N):
    r, g, b = map(int, sys.stdin.readline().split())
    price[0].append(r)
    price[1].append(g)
    price[2].append(b)

dp = [[0]*N for _ in range(3)]
for j in range(N-2, -1, -1):
    dp[0][j] = min(price[1][j+1] + dp[1][j+1], price[2][j+1] + dp[2][j+1])
    dp[1][j] = min(price[0][j+1] + dp[0][j+1], price[2][j+1] + dp[2][j+1])
    dp[2][j] = min(price[0][j+1] + dp[0][j+1], price[1][j+1] + dp[1][j+1])

print(min(price[0][0]+dp[0][0], price[1][0]+dp[1][0], price[2][0]+dp[2][0]))

 

*Top-down 방식으로도 작성하였으나, 문제 푸는데 급급하여 코드가 더럽다. 나중에 수정이 필요함.

Top-down 방식

import sys
sys.setrecursionlimit(10000)
from typing import List


N = int(sys.stdin.readline())
red_list = []
green_list = []
blue_list = []
for i in range(N):
    r, g, b = map(int, sys.stdin.readline().split())
    red_list.append(r)
    green_list.append(g)
    blue_list.append(b)
memo_red = [0]*N
memo_green = [0]*N
memo_blue = [0]*N

def RGB_distance(num: int, prev: str)-> int:
    if num == N:
        return 0

    if not prev: # 처음
        memo_red[0] = RGB_distance(num+1, 'r')
        memo_green[0] = RGB_distance(num+1, 'g')
        memo_blue[0] = RGB_distance(num+1, 'b')
        return(min(red_list[0]+memo_red[0], green_list[0]+memo_green[0], blue_list[0]+memo_blue[0]))

    if prev == 'r':
        if memo_red[num-1]:
            return memo_red[num-1]
        memo_green[num] = RGB_distance(num+1, 'g')
        memo_blue[num] = RGB_distance(num+1, 'b')
        return(min(green_list[num]+memo_green[num], blue_list[num]+memo_blue[num]))

    if prev == 'g':
        if memo_green[num-1]:
            return memo_green[num-1]
        memo_red[num] = RGB_distance(num+1, 'r')
        memo_blue[num] = RGB_distance(num+1, 'b')
        return(min(red_list[num]+memo_red[num], blue_list[num]+memo_blue[num]))
    
    if prev == 'b':
        if memo_blue[num-1]:
            return memo_blue[num-1]
        memo_red[num] = RGB_distance(num+1, 'r')
        memo_green[num] = RGB_distance(num+1, 'g')
        return(min(red_list[num]+memo_red[num], green_list[num]+memo_green[num]))
    
print(RGB_distance(0, ""))

 

 

2021.11.10 revisted

다시 보니 되게 쉬운 문제네? 위의 top-down 코드는 왜이리 더럽냐;;

import sys


N = int(sys.stdin.readline())
dp = []
for _ in range(N):
    dp.append(list(map(int, sys.stdin.readline().split())))

for i in range(1, N):
    dp[i][0] += min(dp[i-1][1], dp[i-1][2])
    dp[i][1] += min(dp[i-1][0], dp[i-1][2])
    dp[i][2] += min(dp[i-1][0], dp[i-1][1])

print(min(dp[N-1]))