문제풀이/백준
2042 구간 합 구하기(백준) 문제풀이
soo-dal
2024. 2. 15. 07:37
문제
https://www.acmicpc.net/problem/2042
2042번: 구간 합 구하기
첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄
www.acmicpc.net
문제 접근 방식
배열의 구간 합을 구하는 문제이며, 배열 값이 계속 변경될 수 있다는 특징이 있다. 처음부터 끝까지 순회하며 합을 구한다면 O(N)이 걸리며, M 번 변경이 되기 때문에 그때마다 합을 다시 구해줘야한다. 전체적으로 O(NM)의 시간이 걸리기 때문에 시간 초과가 발생하며 다른 방법을 찾아야한다. 문제를 풀기 위해 세그먼트 트리라는 자료구조를 사용할 필요가 있으며 세그먼트 트리를 요약하면 아래와 같다.
세그먼트 트리(Segment Tree)
개념 - 구간합을 이분 트리로 구해놓고 원하는 구간을 빠르게 검색할 수 있는 자료구조
관련 함수
build - 구간 합을 저장하는 세그먼트 트리를 생성하는 함수
update - 특정 위치의 수를 수정할 시, 트리의 값을 갱신하는 함수
query - 특정 구간의 합을 반환하는 함수
특징
- 배열의 갱신 및 검색 시 각 O(logN)의 시간 복잡도를 가져도 갱신되는 구간합을 빠르게 구할 수 있다.
시간 복잡도
트리를 생성하기 위해 전체 노드 수 만큼 순회를 해야하며, 트리의 최대 노드 개수가 2^21개(= 약 2*1e6)이므로 트리를 생성하는데 무리가 없다. 또한, 갱신 및 검색을 할 때는 각각 O(logN)이 걸리며 갱신과 검색의 최대 횟수가 각각 1e4이므로 시간내에 동작한다.
코드
import sys
input=sys.stdin.readline
print=sys.stdout.write
class SegmentTree:
def __init__(self,nums):
self.nums=nums
# 트리의 최대 노드 개수를 구해서 배열로 선언
count=len(self.nums)
k,size=1,1
while True:
size=2**k
if size>=count:
break
k+=1
tree_size=2**(k+1)
self.trees=[0]*(tree_size+1)
# nums를 가지고 세그먼트 트리를 만들어준다.
self.build(0,len(self.nums)-1,1)
# 재귀함수를 통해 좌,우 부분의 합을 구하고, 해당 합을 합쳐서 현재 노드의 합을 구한다.
def build(self,start,end,node):
if start==end:
self.trees[node]=self.nums[start]
return self.trees[node]
mid=(start+end)//2
self.trees[node]=self.build(start,mid,node*2)+self.build(mid+1,end,node*2+1)
return self.trees[node]
# 재귀함수를 통해 해당하는 index로 이동하여 값을 갱신한다.
def update(self,start,end,node,index,value):
if index==start and index==end:
self.nums[index]=value
self.trees[node]=value
return self.trees[node]
elif not (start<=index<=end): # 구간내에 존재하지 않는다면 기존 트리 값을 반환
return self.trees[node]
# 중간 지점을 기준으로 좌, 우 부분에 update 함수를 사용하여 갱신할 부분이 있으면 갱신하고
# 좌,우 갱신 값을 통해 현재 노드의 값도 갱신한다.
mid=(start+end)//2
self.trees[node]=self.update(start,mid,node*2,index,value)+self.update(mid+1,end,node*2+1,index,value)
return self.trees[node]
# 구하고자 하는 범위(left,right)를 입력받아서 세그먼트 트리에서의 합을 구한다.
def query(self,start,end,node,left,right):
if right<start or left>end: # 범위 내에 존재하지 않는다면 0을 반환한다.
return 0
elif left <= start and end<=right: # 구하고자 하는 범위 내에 세그먼트가 포함된다면 해당 노드값을 반환한다.
return self.trees[node]
mid=(start+end)//2
# 좌측 부분과 우측 분분의 합을 합쳐서 총합을 구한다.
answer=self.query(start,mid,node*2,left,right) + self.query(mid+1,end,node*2+1,left,right)
return answer
if __name__=='__main__':
n,m,k=map(int,input().split())
nums=[int(input()) for _ in range(n)]
size=len(nums)
segment_tree=SegmentTree(nums)
for _ in range(m+k):
a,b,c=map(int,input().split())
if a==1:
segment_tree.update(0,size-1,1,b-1,c)
elif a==2:
print(str(segment_tree.query(0,n-1,1,b-1,c-1))+'\n')
회고
1. 세그먼트 트리를 통해 O(logN)으로 구간 합을 구할 수 있다.
2. 갱신이 발생했을 경우에도 O(logN)이 걸린다.
3. 전체적으로 봤을 때 단순 구간 합을 구하는 방식보다 효율적이다.