세그먼트 트리(Segment Tree)란?
- 여러 개의 데이터가 존재할 때 특정 구간의 합(최솟값, 최댓값, 곱 등)을 구하는 데 사용하는 자료구조이다.
- 트리 종류 중에 하나로 이진 트리의 형태이며, 특정 구간의 합을 가장 빠르게 구할 수 있다는 장점이 있다. (O(logN))
아래 예제를 통해 세그먼트 트리(Segment Tree)를 왜 사용하는지 알아보자.
(Ex)
위와 같은 배열이 있다고 하자. 이 때 데이터의 개수는 10개로 인덱스는 0부터 9까지 차례대로 1~10의 원소가 삽입되어 있다. 만약 인덱스 2부터 8까지 데이터를 더하려면 어떻게 할까?
위 그림처럼 2~8 범위의 원소를 하나씩 다 더하면 된다. 결과는 42이다. 이러한 방식으로 다른 특정 구간의 합을 구한다고 고려했을 때 앞에서 하나씩 더하므로 데이터의 개수가 N이면 시간 복잡도는 O(N)이다. 따라서 이러한 방식을 이용하면 구간의 합을 구하는 속도가 너무 느리기 때문에 더 좋은 방법이 필요하다.
→ 더 좋은 방법이 바로 세그먼트 트리(Segment Tree)를 사용하는 것이다!
세그먼트 트리 구현 과정 (Python)
세그먼티 트리 초기화 및 생성 함수
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
tree = [0] * (len(arr) * 4)
예제로 arr리스트와 세그먼트 트리인 tree를 생성하였다. (Python에서는 리스트로 세그먼트 트리를 생성한다. C++에서는 보통 vector를 이용)
이 때 tree의 크기 할당에 대해 의문이 생길 수 있다. 세그먼트 트리의 크기는 배열(arr)의 개수가 N개일 때, N보다 큰 가장 가까운 N의 제곱수를 구한 뒤에 그것의 2배를 하여 미리 세그먼트 트리의 크기를 만들어놓어야 한다. 위 예제의 경우 N이 10으로 가장 가까운 제곱수는 4^2=16으로 16*2=32개의 크기가 필요하다(넉넉하게 세그먼트 트리를 생성하는 것이다). 그래서 실제로는 데이터의 개수 N에 4를 곱한 크기만큼 미리 세그먼트 트리의 크기를 할당한다.
세그먼트 트리의 공간을 할당하였다. 이제 세그먼트 트리의 구조를 알아본 후 세그먼트 트리에 원소를 삽입해보자. 위 예제를 기준으로 아래 그림을 보면 된다.
먼저 루트 노드부터 보자. 세그먼트 트리의 루트 노드에는 0~9(index)까지의 arr 구간 합의 값이 삽입되고 루트 노드의 번호는 1번이 된다.
다음으로 루트 노드의 자식 노드를 보자.
왼쪽 자식 노드의 번호는 2번이 되며, 0~4까지의 arr 구간 합의 값이 삽입된다.
오른쪽 자식 노드의 번호는 3번이 되며, 5~9까지의 arr 구간 합의 값이 삽입된다.
이러한 방식으로 세그먼트 트리의 원소를 채워주게 된다.
이때 보통 리스트의 인덱스는 0번부터 시작하는데, 세그먼트 트리는 1번부터 시작하는지 의문이 생길 수 있다. 세그먼트 트리의 인덱스가 1번부터 시작하는 이유는 재귀적으로 편하게 세그먼트 트리를 생성하기 위해서이다. 1부터 시작하게 되면 2를 곱했을 때는 왼쪽 자식 노드를 가리키고, 2를 곱하고 1을 더하면 오른쪽 자식 노드를 가리키게 되어 효과적이기 때문이다.
# <세그먼트 트리를 배열의 각 구간 합으로 채워주기>
# start : 배열의 시작 인덱스, end : 배열의 마지막 인덱스
# index : 세그먼트 트리의 인덱스 (무조건 1부터 시작)
def init(start, end, index):
# 가장 끝에 도달했으면 arr 삽입
if start == end:
tree[index] = arr[start]
return tree[index]
mid = (start + end) // 2
# 좌측 노드와 우측 노드를 채워주면서 부모 노드의 값도 채워준다.
tree[index] = init(start, mid, index * 2) + init(mid + 1, end, index * 2 + 1)
return tree[index]
그림과 코드로 이해가 안된다면 직접 손으로 디버깅하여 살펴보는 것을 추천한다.
※ 세그먼트 트리는 매 노드가 이미 구간의 합을 가지고 있는 형태가 된다. 그리고 추가로 세그먼트 트리의 인덱스(index)와 구간의 합은 별개의 값이므로 헷갈리지 않도록 주의하자!
세그먼트 트리로 구간 합 구하는 함수
세그먼트 트리는 트리 구조이기 때문에 데이터를 탐색하는데 있어 O(logN)의 시간 복잡도를 가진다. 따라서 구간 합을 항상 O(logN)의 시간에 구할 수 있다.
예를 들어 6~9의 범위의 구간 합을 구한다고 하자. 그러면 아래의 그림처럼 3개의 빨간색 노드의 합만 구해주면 된다.
구하고자 하는 6~9의 범위의 구간 합은 7 + 8 + 9 + 10 = 34이다. 전에 세그먼트 트리에서 인덱스 7의 원소값은 19, 인덱스 13의 원소값은 8, 인덱스 25의 원소값은 7이었다. 즉, 구하고자 하는 값은 7 + 8 + 19 = 34가 되는 것이다.
구간의 합 구하는 함수 또한 재귀적으로 구현한다. 구간의 합은 '범위 안에 있는 경우'에 한해서만 더해주면 된다. 그 밖의 경우는 고려하지 않는다!
# <구간 합을 구하는 함수>
# start : 시작 인덱스, end : 마지막 인덱스
# left, right : 구간 합을 구하고자 하는 범위
def interval_sum(start, end, index, left, right):
# 범위 밖에 있는 경우
if left > end or right < start:
return 0
# 범위 안에 있는 경우
if left <= start and right >= end:
return tree[index]
# 그렇지 않다면 두 부분으로 나누어 합을 구하기
mid = (start + end) // 2
# start와 end가 변하면서 구간 합인 부분을 더해준다고 생각하면 된다.
return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid + 1, end, index * 2 + 1, left, right)
특정 원소의 값을 수정하는 함수
특정 원소를 수정하면 구간의 합들이 달라지고, 세그먼트 트리의 원소값들도 당연히 달라진다. 따라서 특정 원소의 값을 수정할 때는 해당 원소를 포함하고 있는 모든 구간 합 노드들을 갱신해주면 된다. 즉, 세그먼트 트리의 모든 노드를 변경하는 것이 아닌 해당 원소를 포함하고 있는 부분적인 노드들만 바꿔주면 되는 것이다.
예를 들어 인덱스 6의 arr값(arr[6])을 수정한다고 하면 다음과 같이 5개의 구간 합 노드를 모두 수정하면 된다.
이 함수 또한 재귀적으로 구현하며, 수정할 노드로는 '범위 안에 있는 경우'에 한해서만 수정해준다.
# <특정 원소의 값을 수정하는 함수>
# start : 시작 인덱스, end : 마지막 인덱스
# what : 구간 합을 수정하고자 하는 노드
# value : 수정할 값
def update(start, end, index, what, value):
# 범위 밖에 있는 경우
if what < start or what > end:
return
# 범위 안에 있으면 내려가면서 다른 원소도 갱신
tree[index] += value
if start == end:
return
mid = (start + end) // 2
update(start, mid, index * 2, what, value)
update(mid + 1, end, index * 2 + 1, what, value)
전체 코드 (Python)
# (Ex)
arr = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 실제로는 데이터의 개수 N에 4를 곱한 크기만큼 미리 세그먼트 트리의 공간을 할당한다.
tree = [0] * (len(arr) * 4)
# <세그먼트 트리를 배열의 각 구간 합으로 채워주기>
# start : 배열의 시작 인덱스, end : 배열의 마지막 인덱스
# index : 세그먼트 트리의 인덱스 (무조건 1부터 시작)
# 세그먼트 트리가 1부터 시작하는 이유는 2를 곱했을 때 왼쪽 자식노드를 가리키고
# 2를 곱하고 1을 더하면 오른쪽 자식노드를 가리키므로 효과적이기 때문에 이렇게 한다!
def init(start, end, index):
# 가장 끝에 도달했으면 arr 삽입
if start == end:
tree[index] = arr[start]
return tree[index]
mid = (start + end) // 2
# 좌측 노드와 우측 노드를 채워주면서 부모 노드의 값도 채워준다.
tree[index] = init(start, mid, index * 2) + init(mid + 1, end, index * 2 + 1)
return tree[index]
# <구간 합을 구하는 함수>
# start : 시작 인덱스, end : 마지막 인덱스
# left, right : 구간 합을 구하고자 하는 범위
def interval_sum(start, end, index, left, right):
# 범위 밖에 있는 경우
if left > end or right < start:
return 0
# 범위 안에 있는 경우
if left <= start and right >= end:
return tree[index]
# 그렇지 않다면 두 부분으로 나누어 합을 구하기
mid = (start + end) // 2
# start와 end가 변하면서 구간 합인 부분을 더해준다고 생각하면 된다.
return interval_sum(start, mid, index * 2, left, right) + interval_sum(mid + 1, end, index * 2 + 1, left, right)
# <특정 원소의 값을 수정하는 함수>
# 특정 원소를 수정하면 구간 합이 당연히 달라진다.
# 이때, 해당 원소를 포함하고 있는 모든 구간 합 노드들을 갱신해주면 된다.
# (즉, 전체가 아닌 부분적인 노드들만 바꿔주면 된다!)
# start : 시작 인덱스, end : 마지막 인덱스
# what : 구간 합을 수정하고자 하는 노드
# value : 수정할 값
def update(start, end, index, what, value):
# 범위 밖에 있는 경우
if what < start or what > end:
return
# 범위 안에 있으면 내려가면서 다른 원소도 갱신
tree[index] += value
if start == end:
return
mid = (start + end) // 2
update(start, mid, index * 2, what, value)
update(mid + 1, end, index * 2 + 1, what, value)
init(0, len(arr) - 1, 1)
print(interval_sum(0, len(arr) - 1, 1, 0, 9)) # 0부터 9까지의 구간 합 (1 + 2 + ... + 9 + 10)
print(interval_sum(0, len(arr) - 1, 1, 0, 2)) # 0부터 2까지의 구간 합 (1 + 2 + 3)
print(interval_sum(0, len(arr) - 1, 1, 6, 7)) # 0부터 2까지의 구간 합 (7 + 8)
# arr[0]을 +4만큼 수정
update(0, len(arr) - 1, 1, 0, 4)
print(interval_sum(0, len(arr) - 1, 1, 0, 2)) # 0부터 2까지의 구간 합 ((1 + 4) + 2 + 3)
# arr[9]를 -11만큼 수정
update(0, len(arr) - 1, 1, 9, -11)
print(interval_sum(0, len(arr) - 1, 1, 8, 9)) # 8부터 9까지의 구간 합 (9 + (10 - 11))
- 시간 복잡도는 O(logN)으로, 세그먼트 트리를 이용하면 기존의 구간 합을 계산할 때 훨씬 더 빠르게 구간 합을 구할 수 있다.
'Algorithm' 카테고리의 다른 글
백준 2573번 빙산(JAVA, Python) (0) | 2023.02.08 |
---|