백준 문제풀이

백준 16993 - 연속합과 쿼리 [Python]

Vermeil 2022. 4. 15. 19:17

https://www.acmicpc.net/problem/16993

 

16993번: 연속합과 쿼리

길이가 N인 수열 A1, A2, ..., AN이 주어진다. 이때, 다음 쿼리를 수행하는 프로그램을 작성하시오. i j : Ai, Ai+1, ..., Aj에서 가장 큰 연속합을 출력한다. (1 ≤ i ≤ j ≤ N) 수열의 인덱스는 1부터 시작

www.acmicpc.net

세그먼트 트리를 사용해서 풀면 된다.

 

왼쪽 끝점을 포함할 때의 구간합 최대 \(Seg_{ls}\)

오른쪽 끝점을 포함할 때의 구간합 최대 \(Seg_{rs}\)

어떤 구간에서의 구간합 최대 \(Seg_{ms}\)

구간의 합 \(Seg_{s}\)

 

자식 노드를 각각 \(L, R\)로 두면,

 

\(Seg_{ls} = max(L_{ls}, L_{s} + R_{ls})\)

\(Seg_{rs} = max(R_{rs}, R_{s} + L_{rs})\)

\(Seg_{ms} = max(L_{ms}, R_{ms}, L_{rs} + R_{ls})\)

\(Seg_{s} = L_{s} + R_{s}\)

 

가 된다.

 

구현하면 끝이다.

74줄부터 보면 된다

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import sys, math
import heapq
 
from dataclasses import dataclass
from collections import deque
from bisect import bisect_left, bisect_right
 
input = sys.stdin.readline
 
hqp = heapq.heappop
hqs = heapq.heappush
 
 
# input
def ip(): return int(input())
def sp(): return str(input().rstrip())
 
def mip(): return map(int, input().split())
def msp(): return map(str, input().split().rstrip())
 
def lmip(): return list(map(int, input().split()))
def lmsp(): return list(map(str, input().split()))
 
 
# gcd, lcm
def gcd(x, y):
    while y:
        x, y = y, x % y
    return x
 
 
def lcm(x, y):
    return x * y // gcd(x, y)
 
 
# prime
def isPrime(x):
    if x <= 1return False
    for i in range(2int(x ** 0.5+ 1):
        if x % i == 0:
            return False
    return True
 
 
# Union Find
# p = [i for i in range(n + 1)]
 
def find(x):
    if x == p[x]:
        return x
    q = find(p[x])
    p[x] = q
    return q
 
 
def union(x, y):
    x = find(x)
    y = find(y)
 
    if x != y:
        p[y] = x
 
 
def getPow(a, x):
    ret = 1
    while x:
        if x & 1:
            ret = (ret * a) % MOD
        a = (a * a) % MOD
        x >>= 1
    return ret
 
 
############### Main! ###############
 
INF = 10 ** 9
 
@dataclass
class Node:
    ls: int = -INF
    rs: int = -INF
    ms: int = -INF
    s: int = 0
 
seg = [Node() for _ in range(404040)]
 
def init(x, s, e):
    if s == e:
        seg[x] = Node(a[s - 1], a[s - 1], a[s - 1], a[s - 1])
        return seg[x]
    m = (s + e) // 2
    L = init(x * 2, s, m)
    R = init(x * 2 + 1, m + 1, e)
    seg[x].ls = max(L.ls, L.s + R.ls)
    seg[x].rs = max(R.rs, R.s + L.rs)
    seg[x].ms = max(L.ms, R.ms, L.rs + R.ls)
    seg[x].s = L.s + R.s
    return seg[x]
 
def get(x, l, r, s, e):
    if e < l or r < s:
        return Node()
    if l <= s <= e <= r:
        return seg[x]
    m = (s + e) // 2
    L = get(x * 2, l, r, s, m)
    R = get(x * 2 + 1, l, r, m + 1, e)
    c = Node()
    c.ls = max(L.ls, L.s + R.ls)
    c.rs = max(R.rs, R.s + L.rs)
    c.ms = max(L.ms, R.ms, L.rs + R.ls)
    c.s = L.s + R.s
    return c
 
= ip()
= lmip()
init(11, n)
= ip()
for qqq in range(q):
    i, j = mip()
    print(get(1, i, j, 1, n).ms)
 
######## Priest greedev ########
cs