오히려 좋아..

상황이 나쁘게만 흘러가는 것 같을 때 외쳐보자.. .

궁금한 마음으로 포트폴리오 보기

Algorithm/기본 알고리즘 구현

세그먼트 트리 python

junha6316 2021. 7. 30. 20:27

배열의 부분합을 구할 떄 사용하는 자료구조

일반적인 for으로 구현하면 N이 리스트의 길이일떄  O(N)의 시간복잡도로 계산 가능하지만 세그먼트 트리를 사용하면 O(logN)으로 가능하다. 초기화 과정 O(logN) 출력 O(logN). 세그먼트 트리는 크게 두가지 함수가 존재한다. 하나는 주어진 정수 리스트를 바탕으로 세그먼트 트리를 만드는 init함수와 만든 세그먼트 트리를 이용해  특정 구간합을 반환하는 query함수이다. 구체적인 구현은 아래와 같다.

 

세그먼트 트리는 이진 트리로 루트, 왼쪽자식, 오른쪽 자식으로 구성되어있다.

 

한편 세그먼트 트리를 구성하는 leaf 노드를 제외한 모든 정점의 값은 해당 정점을 중심으로 왼쪽 자식과 오른쪽 자식으로 합이고  리프노드의 자리에는 주어진 정수 리스트의 값이 들어간다.

세그먼트 트리의 특정 정점의 값 (leaf 제외)  = 왼쪽 자식의 값 + 오른쪽 자식의 값

그렇기 떄문에 세그먼트 트리의 루트노드(가장 꼭대기 노드)에는 정수 리스트 원소의 총합이 들어가고 루트 노드의 왼쪽 자식은 전체 정수 리스트의 [0 ~ N//2]의 구간합 오른쪽 자식은 [N//2 + 1 ~ 마지막 값] 의 구간합이 들어간다. 이걸 이제 코드로 구현하면 된다. 코드는 아래와 같다. 

 

루트 노드 : 리스트 전체의 총 합

루트 노드의 왼쪽 자식 : 0번째 원소부터 ~ (N//2-1)원소까지의 합

루트 노드의 오른쪽 자식 : (N//2 + 1)번째 원소부터 N-1번쨰 원소의 합

num_list =[i for i in range(N+1)] #숫자를 담는 리스트
tree = [0] * (N * 4) #세그먼트가 가장 많이 분할 되면 전체 노드의 4배 완전 이진 트리를 생각 하면된다.

def init(start, end, here):
    """ 
    Segment Tree Initialization 
    start : 범위 시작 정점 인덱스
    end : 범위 마지막 정점 인덱스
    here : 현재 위치
    범위를 1/2으로 줄이며 재귀적으로 수행된다.
    """
    if start == end: #start와 end가 같으면 leaf노드
        tree[here] = num_list[start]
        return tree[here]
  
    mid = (start + end) // 2 #중간값을 계산한다.
    
    """
    세그먼트 현재 위치에서의 값 = (왼쪽 서브트리의 합) + (오른쪽 서브트리의 합) 
     
     	[3]
     [1]   [2]
     
    왼쪽 서브트리 범위   : start ~ mid 
    오른쪽 서브트리 범위 :  mid+1 ~ end 
    """
    
    tree[here] = init(start, mid, here * 2)  + init(mid + 1, end, here * 2 + 1)
    return tree[here]
    
def query(start, end, here, left, right):
    """
    구간 합을 찾는 함수
    start : 범위 시작 정점 인덱스
    end : 범위 마지막 정점 인덱스
    here : 현재 위치
    left : 찾아야하는 구간합 시작 인덱스
    right : 찾아야하는 구간합 마지막 인덱스
    """
    
    if left > end or right < start: return 0 #범위를 벗어나는 경우
    if left <= start and end<=right:return tree[here] #범위 내에 있는 경우
    mid = (start+end)//2
    #현재 위치의 합 = 왼쪽 범위의 구간합 + 오른쪽 범위의 구간 합
    sub_sum = query(start,mid, here*2,left, right) + query(mid+1, end, here*2+1, left, right)
    return sub_sum

init(1, N+1, 1) #tree[1]에 세그먼트 트리의 정점이 저장, 초기화
s, e = map(int, input().split())

print(query(1, N+1, 1, s, e)) #부분합 반환