알고리즘/Floyd-Warshall

Floyd-Warshall 플로이드 워셜 문제

코딩균 2021. 7. 30. 16:47

문제

1번부터 N번까지의 노드가 있는데 중간에 소개팅녀를 만나는 노드를 거쳐야 한다

 

따라서 n번 노드를 가야한다면 중간에 소개팅녀가 있는 k번 노드를 지나야 한다는 것

 

각 노드 사이는 양방향으로 갈 수 있고 1의 cost를 가진다

 

 

플로이드 워셜 알고리즘

모든 지점에서 다른 모든 지점까지의 최단 경로 구하는 알고리즘

(cf. 다익스트라 - 한 지점에서 다른 특정 지점까지의 최단 경로를 구하는 알고리즘)

 

2차원 배열을 사용하여 각 노드에서 다른 노드로 가는 모든 최단 거리를 구해야함

-> O(N^2)의 시간 복잡도가 필요

 

방문판매원이 1번 노드에서 4번 노드까지 최단거리로 방문해야 한다면

  • 1번에서 출발하여 2번 노드를 경유해서 4번 노드로 가는 경우
  • 1번에서 출발하여 3번 노드를 경유해서 4번 노드로 가는 경우

두가지를 모두 고려하여 최단 거리를 구해주어야 함

이때, graph[1][3]은 1에서 3으로 가는 최단의 거리(cost)이기 때문에 

  • graph[1][2] + graph[2][4]
  • graph[1][3] + graph[3][4]

이러한 형식으로 최단 거리를 표현할 수 있다.

경유를 하는 node들은 for문을 통해서 해결한다고 가정한다면 점화식은

graph[i][j] = min(graph[i][j], graph[i][k] + graph[k][j]
(1<=k<=n)

graph는 각 노드에서 각 노드로 가는 가장 짧은 거리라고 한다

via는 각 노드에서 각 노드로 가장 짧은 거리로 가기 위해서 거쳐야 하는 중간 노드

 

import sys
input = sys.stdin.readline
INF = int(1e9)

n, m = map(int, input().split())

graph = [[INF]*(n+1) for i in range(0, n+1)]

via = [[0]*(n+1) for i in range(0, n+1)]

for i in range(1, n+1):
    for j in range(1, n+1):
        if i==j:
            graph[i][j] = 0
            via[i][j] = i

for _ in range(0, m):
    a, b = map(int, input().split())
    graph[a][b] = 1
    graph[b][a] = 1
    via[b][a] = a
    via[a][b] = b

# k = 중간 방문 소개팅녀 회사 / x = 최종 계약 회사 / 판매원은 1번 회사
x, k = map(int, input().split())

for k in range(1, n+1):
    for i in range(1, n+1):
        for j in range(1, n+1):
            if graph[i][k]+graph[k][j]<graph[i][j]:
                graph[i][j] = graph[i][k]+graph[k][j]
                via[i][j] = k

경유할 경우에 가장 짧은 거리가 나올 때마다 그 때의 노드를 저장한 값이 via이기 때문에 

역추적 하듯이 경로를 알아내면 된다

 

예를 들어 1->4번이면 1->4 direct로 가는 경우가 가장 짧으면 via[1][4]=4이겠지만

중간에 거쳐가는 node가 있다면 v[1][4] != 4일 것이다

따라서 아닌 목적지 노드가 via의 값이 아닌 경우에는 v[1][via]값으로 하여서 다시 살펴본다

결국 트리형식으로 분할해나가며 경로를 list에 append하는 방식

 

path = list()

def split(s, e):
    
    if(via[s][e]==e):
        path.append(s)
        path.append(e)
    else:
        split(s, via[s][e])
        path.pop()
        split(via[s][e], e)


print(graph[1][k] + graph[k][x])

print(via)
print(1, end=" ")
split(1, k)
print()
split(k, x)

print(path)