PS/Divide and Conquer

백준 10830번: 행렬 제곱 (Python & JAVA)

닻과매 2022. 2. 15. 17:02

https://www.acmicpc.net/problem/10830

 

10830번: 행렬 제곱

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

www.acmicpc.net

문제

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

입력

첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤  5, 1 ≤ B ≤ 100,000,000,000)

둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

출력

첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.

 


 

내 풀이 (Python 풀이)

dp라는 배열에 주어진 매트릭스를 1제곱, 2제곱, 4제곱, ... 한 값을 넣어준 후, B를 2진수의 관점으로 보면서 비트체크를 하여 해당 비트가 1이면 결과 매트릭스(=result, 초기값은 항등행렬 I)에 해당 비트 index에 해당하는 행렬을 곱해주었다. 지금 보니까 참 개고생하면서 풀었다.

 

정석적인 풀이 (JAVA 풀이)

분할정복으로 푼다. matrix 곱셈 구현만 해야한다는 점만 제외하면 매우 정석적인 분할정복 문제. 다시 보니까 그래도 몇개월동안 실력이 늘긴 늘었다는 생각이 든다.

 

배울 점

1제곱할 때도 각각의 원소에 1000으로 나눈 값을 출력해야 한다. 즉, input 받을 때부터 1000으로 나눠서 저장해야 한다. 많이들 실수다는 점이 그나마 위안이 된다...가 아니라 문제를 잘 읽는 습관을 들이자. 풀 수 있을 거 같으면 흥분해서 들이박으니 실수가 잦다.

 

코드

풀이 1. Python 풀이

import sys, math
from typing import List

N, power = map(int, sys.stdin.readline().split())
matrix = []

for _ in range(N):
    matrix.append(list(map(int, sys.stdin.readline().split())))

for i in range(N):
    for j in range(N):
        matrix[i][j] %= 1000

def multiply(matrix1, matrix2):
    result = [[0]*N for _ in range(N)]
    for i in range(N):
        for j in range(N):
            for k in range(N):
                result[i][j] += matrix1[i][k] * matrix2[k][j]
            result[i][j] = result[i][j] % 1000
    return result

dp = [0] * (math.ceil(math.log(power, 2)) + 1)

dp[0] = matrix
for i in range(1, len(dp)):
    dp[i] = multiply(dp[i-1], dp[i-1])

bin_power = bin(power)[2:][::-1]
result = [[0]*N for _ in range(N)]
for i in range(N):
    for j in range(N):
        if i == j:
            result[i][j] = 1

for idx, letter in enumerate(bin_power):
    if letter == "1":
        result = multiply(result, dp[idx])

for i in range(N):
    print(" ".join(map(str, result[i])))

 

풀이 2. JAVA 풀이

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;

public class Main {
	static int N;
	
	public static void main(String[] args) throws IOException{
		BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
		StringBuilder sb = new StringBuilder();
		StringTokenizer st = new StringTokenizer(br.readLine());
		N = Integer.parseInt(st.nextToken());
		long B = Long.parseLong(st.nextToken());
		int[][] matrix = new int[N][N];
		for (int i = 0; i < N; i++) {
			st = new StringTokenizer(br.readLine());
			for (int j = 0; j < N; j++) {
				matrix[i][j] = Integer.parseInt(st.nextToken()) % 1000;
			}
		}
		int[][] result = matrixSquare(matrix, B);
		for (int i = 0; i < N; i++) {
			for (int j = 0; j < N; j++) {
				sb.append(result[i][j] + " ");
			}
			sb.append("\n");
		}
		
		sb.setLength(sb.length()-1);
		System.out.println(sb.toString());
		
	}
	
	static int[][] matrixMultiply(int[][] matrix1, int[][] matrix2){
		int[][] newMatrix = new int[N][N];
		
		for (int i = 0; i < N; i++) {
			for (int j = 0; j < N; j++) {
				for (int k = 0; k < N; k++) {	// 곱하고 다 더한 후 1000으로 나눠도 됨! 5000000만 이하임
					newMatrix[i][j] += matrix1[i][k] * matrix2[k][j];
				}
				newMatrix[i][j] %= 1000;
			}
		}
		
		return newMatrix;
	}
	
	static int[][] matrixSquare(int[][] matrix, long B){
		if (B==1) return matrix;
		
		if (B % 2 == 1) {
			int[][] tempMatrix = matrixSquare(matrix, B/2);
			return matrixMultiply(matrixMultiply(tempMatrix, tempMatrix), matrix);
		}
		else {
			int[][] tempMatrix = matrixSquare(matrix, B/2);
			return matrixMultiply(tempMatrix, tempMatrix);
		}
	}

}