본문 바로가기
Algorithm/Python

백준 2042번 구간 합 구하기(BIT)

by Shark_상어 2023. 3. 30.
728x90

어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

 

2. 코드

import sys
input = sys.stdin.readline

def prefix_sum(i):
    result = 0
    while i:
        result += tree[i]
        i -= (i & -i)
    return result

def update(i, dif):
    while i <= n:
        tree[i] += dif
        i += (i & -i)

def interval_sum(start, end):
    return prefix_sum(end) - prefix_sum(start - 1)

n, m, k = map(int, input().split())

arr = [0] * (n + 1)
tree = [0] * (n + 1)


for i in range(1, n + 1):
    x = int(input())
    arr[i] = x
    update(i, x)

for i in range(m + k):
    a, b, c = map(int, input().split())

    if a == 1:
        update(b, c - arr[b])
        arr[b] = c
    else:
        print(interval_sum(b, c))

 

3. 문제해결법

1. 이 문제는 세그먼트 트리를 이용해서 풀 수 있는 문제이다.

2. 하지만 필자는 세그먼트 트리 보다 코드가 더 간결 하고 공간복잡도에 관련하여 조금더 가벼운 BIT(바이너리 인덱스 트리)를 활용 했다.

3. BIT 활용 방식은 아래와 같다.

# 트리 구조 만들기: 0이 아닌 마지막 비트 = 내가 저장하고 있는 값들의 개수
# ex) n = 16 일때 마지막 비트는 16이며, 1 ~ 16 모든 데이터들의 정보를 나타 낸다.
# ex) n = 8 일때 마지막 비트는 8이며, 1 ~ 8 모든 데이터들의 정보를 나타 낸다.
# ex) n = 7 일때 마지막 비트는 1이며, 자기 자신 정보만 담는다.

# 특정 값을 변경 할때: 0이 아닌 마지막 비트 만큼 더하면서 구간들의 값을 변경 ex) = 3rd

방식을 이용 할 경우 시간 복잡도는 O(log N) 을 보장 한다.

 

4. prefix_sum 함수

i번째 수 까지의 누적 합을 계산 하는  함수이며

0이 아닌 마지막 비트 만큼 빼가면서 이동한다.

 

5.update 함수

i번째 수를 dif 만큼 더 하는 함수

 

6.interval_sum 함수

start 부터 end 까지의 구간 합을 계산 하는 함수이다.

 

7. a == 1 일때

update(b, c - arr[b]) # 바뀐 크기(dif) 만큼 적용

 

8. a != 1 일때

inverval_sum(b, c) 이용 하게 되면 된다.

728x90