알고리즘

Segment Tree Beats // 백준 17474 - 수열과 쿼리 26 [C++]

Vermeil 2022. 5. 29. 19:04

https://www.acmicpc.net/problem/17474

 

17474번: 수열과 쿼리 26

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오.  1 L R X: 모든 L ≤ i ≤ R에 대해서 Ai = min(Ai, X) 를 적용한다.  2 L R: max(AL, AL+1, ..., AR)을 출력한다. 3

www.acmicpc.net

[알고리즘 분류]

Segment tree beats

 

 

 

greedev에게 세그비츠가 웰논인지 물어봤는데, 다음과 같은 답변을 받았다.

그냥 다 웰논이라고 하지 이럴거면

(수쿼 25~30은 모두 세그비츠를 사용한다고 한다.)

 

나는 정말 무식하지만, 최대한 열심히 설명해보도록 하겠다. 나보다 설명을 몇십배는 잘 해둔 글이 수두룩하기 때문에, 나의 글로 이해하기가 힘들다면 매우 높은 확률로 내가 글을 못 써서 그런 것이다

 

공부하고 싶다면 삼성 소프트웨어 멤버십 글을 참고하였으면 좋겠다.

 

 

 

Segment Tree Beats는, lazy조건을 약화하고, 리턴조건을 강화하는 기술이다. 정말정말 어렵다. 레이지 세그 최종진화

 

문제의 업데이트 쿼리를 보자.

\(l \leq i \leq r\)인 \(i\)에 대해, \(A_i = min(A_i, X)\)를 적용하는 쿼리이다. 쉬운 중단조건부터 건드려보자.

 

어떤 구간 \([l, r]\) 내의 최댓값이 \(X\)보다 크지 않다면, 이 구간 내에서 갱신을 해도 수열의 값들은 변하지 않는다. 이를 토대로, 중단조건을 \(r < s\) || \(e < l\) || \(max(A_{l, ... , r}) \leq X\)로 수정하는 것을 생각해볼 수 있다. 이제 갱신조건을 건드려보자.

 

 

 

우리는 구간의 합을 구해야 하므로, '구간 내 수들이 모두 같을 때' 값에 변화를 주도록 하면 쉽게 구간 합을 구할 수 있다.

\([3, 5, 5, 5, 10, 10]\)에서 구간 전체에 [1번 쿼리: 4]를 주었다고 하면,

\([3, 5 - 1, 5 - 1, 5 - 1, 10 - 6, 10 - 6]\)이 된다. 깔끔하다!

 

그러나 [100000, 1, 100000, 1, 100000, ... , 1]과 같은 구간에서 \((1, L, R, 99999)\)라는 쿼리가 들어온다면, 그리고 99998, 99997, ... 이런식으로 쿼리를 준다면, \(O(QNlgN)\)이라는 정신나간 시간복잡도가 나오고 말 것이다. 이런 불상사를 막기 위해서 다른 방법을 생각해보자.

 

 

 

두 번째로 큰 수를 활용하는 방법을 생각해야 할 것 같다. 구간에서 가장 큰 값을 max, 두 번째로 큰 값을 max2라고 하자.

 

\(max2 < k < max\)인 경우, 구간의 합은 \((max - k) * max\_cnt\)만큼 감소한다. 전파도 빠르게 된다.

 

\(k \leq max2\)인 경우는 어떨까? 이 경우에서 갱신을 한다면, 서로 다른 수의 개수가 적어도 하나 이상 감소한다.

위는 max, 아래는 max2이다.

위와 같은 상황에서, \((1, 1, 2, 5)\)라는 쿼리가 들어왔다고 해보자.

서로 다른 두 수가 하나로 합쳐졌다. 이런 이유로 \(k \leq max2\)인 경우는 많아야 N번 정도만 일어나게 되고, (인접한 두 수가 다른 수로 바뀔 수도 있지만, 그래도 대략 \(2Q\) 정도밖에 되지 않는다) 따라서 \((N+Q)lgN\)의 시간복잡도가 나오게 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
#include <iostream>
#include <algorithm>
 
using namespace std;
#define endl '\n'
typedef long long ll;
 
 
struct Node{
    ll max, cnt, max2, sum;
};
 
int N, M;
int a[1010101];
Node seg[4040404];
 
 
Node merge(Node l, Node r){
    if (l.max == r.max) return {l.max, l.cnt + r.cnt, max(l.max2, r.max2), l.sum + r.sum};
    if (l.max < r.max) return {r.max, r.cnt, max(l.max, r.max2), l.sum + r.sum};
    return {l.max, l.cnt, max(l.max2, r.max), l.sum + r.sum};
}
 
Node init(int x, int s, int e){
    if (s == e) return seg[x] = {a[s], 1-1, a[s]};
    int m = (s + e) / 2;
    return seg[x] = merge(init(x * 2, s, m), init(x * 2 + 1, m + 1, e));
}
 
void lazyProp(int x, int s, int e){
    if (s == e) return;
 
    for (int i = x * 2; i <= x * 2 + 1; i++){
        if (seg[x].max < seg[i].max){
            seg[i].sum -= seg[x].max * seg[i].cnt;
            seg[i].max = seg[x].max;
        }
    }
}
 
void update(int x, int s, int e, int l, int r, int k){
    lazyProp(x, s, e);
    if (r < s || e < l || seg[x].max <= k) return;
    if (l <= s && e <= r && seg[x].max2 < k){
        seg[x].sum -= seg[x].max * seg[x].cnt;
        seg[x].max = k;
        lazyProp(x, s, e);
        return;
    }
 
    int m = (s + e) / 2;
    update(x * 2, s, m, l, r, k);
    update(x * 2 + 1, m + 1, e, l, r, k);
 
    seg[x] = merge(seg[x * 2], seg[x * 2 + 1]);
}
 
ll getMax(int x, int s, int e, int l, int r){
    lazyProp(x, s, e);
    if (r < s || e < l) return 0;
    if (l <= s && e <= r) return seg[x].max;
 
    int m = (s + e) / 2;
    return max(getMax(x * 2, s, m, l, r), getMax(x * 2 + 1, m + 1, e, l, r));
}
 
ll getSum(int x, int s, int e, int l, int r){
    lazyProp(x, s, e);
    if (r < s || e < l) return 0;
    if (l <= s && e <= r) return seg[x].sum;
 
    int m = (s + e) / 2;
    return getSum(x * 2, s, m, l, r) + getSum(x * 2 + 1, m + 1, e, l, r);
}
 
 
int main(){
    ios_base::sync_with_stdio(false); cin.tie(0);
    cin >> N;
    for (int i = 1; i <= N; i++){
        cin >> a[i];
    }
    init(11, N);
    cin >> M;
    for (int qr = 0; qr < M; qr++){
        int Q, i, j, k;
        cin >> Q >> i >> j;
        if (Q == 1){
            cin >> k;
            update(11, N, i, j, k);
        }
        else if (Q == 2){
            cout << getMax(11, N, i, j) << endl;
        }
        else{
            cout << getSum(11, N, i, j) << endl;
        }
    }
    return 0;
}
cs

파이썬으로는 풀 수 없다ㅠ

 

 

너무 어려운 주제라서 설명을 너무 못했다. 오류도 엄청 많을 것 같다..