백준 문제풀이

백준 13544 - 수열과 쿼리 3 [Python]

Vermeil 2021. 7. 1. 22:15

https://www.acmicpc.net/problem/13544

 

13544번: 수열과 쿼리 3

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. i j k: Ai, Ai+1, ..., Aj로 이루어진 부분 수열 중에서 k보다 큰 원소의 개수를 출력한다.

www.acmicpc.net

정렬된 배열 \(A\)가 있다고 하자. 이 \(A\) 내에서 어떤 수 \(k\)보다 큰 수의 개수를 \(O(N)\)보다 빠른 시간에 구하려면, 이진 탐색을 사용하여 \(O(lgN)\)에 구할 수 있다.

 

그런데 이 문제는 배열이 정렬되어 있지 않고, 그리고 구간을 나누어 쿼리에 답하라고 한다.

 

만약, 배열을 잘 잘라서 정렬된 상태로 가지고 있으면, 그 잘려진 배열 각각에서 쿼리 구간의 범위에 포함되는 부분만 가져와, k보다 큰 수를 (앞서 설명한 이진 탐색으로) 구하여 다 더하는 방식으로 답을 구할 수 있지 않을까?

 

말이 좀 복잡하긴 한데, 대체 어떻게 해먹는 문제인지 보기 위해 알고리즘 분류를 누르면, 머지 소트 트리라는걸 쓰라고 한다.

 

머지 소트 트리는 이름 그대로 머지소트가 가미된 세그트리이다.(?) 세그트리에 배열을 넣는 방식으로 구현된다.

세그트리를 짤 줄 안다면, 쉽게 구현할 수 있다(아마도)

 

https://justicehui.github.io/medium-algorithm/2020/02/25/merge-sort-tree/

이 글의 그림을 통해 쉽게 이해할 수 있다

 

 

이제 이걸 어떻게 써먹을까?

구간 \(A_{i} ~ A_{j})를 포함하는 부분만 골라서, upper bound 등의 이진 탐색을 이용해 k보다 큰 수를 빠르게 찾을 수 있다!

이게 끝이다. 너무 간단하다

 

 

- 파이썬은 merge 함수를 직접 구현해야 한다. 화이팅

- 채점하는데 9분이나 걸렸다. 파이썬 너무느려

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys
from bisect import bisect_right
input = sys.stdin.readline
 
seg = [[] for _ in range(2020202)]
 
def merge(l, r):
    lx = 0
    rx = 0
    retarr = []
    while lx < len(l) and rx < len(r):
        if l[lx] < r[rx]:
            retarr.append(l[lx])
            lx += 1
        else:
            retarr.append(r[rx])
            rx += 1
    if rx == len(r):
        for i in range(lx, len(l)):
            retarr.append(l[i])
    else:
        for i in range(rx, len(r)):
            retarr.append(r[i])
    return retarr
 
def init(x, s, e):
    if s == e:
        seg[x] = [a[s - 1]]
        return seg[x]
    mid = (s + e) // 2
    seg[x] = merge(init(x * 2, s, mid), init(x * 2 + 1, mid + 1, e))
    return seg[x]
 
def getAns(x, l, r, s, e, k):
    if l <= s and e <= r:
        return len(seg[x]) - bisect_right(seg[x], k)
    if e < l or r < s:
        return 0
    mid = (s + e) // 2
    return getAns(x * 2, l, r, s, mid, k) + getAns(x * 2 + 1, l, r, mid + 1, e, k)
 
= int(input())
= list(map(int, input().split()))
init(11, n)
 
= int(input())
ans = 0
for qu in range(q):
    i, j, k = map(int, input().split())
    i ^= ans
    j ^= ans
    k ^= ans
    ans = getAns(1, i, j, 1, n, k)
    print(ans)
 
cs