백준 문제풀이

백준 10848 - 팔렘방의 다리 [Python]

Vermeil 2022. 4. 12. 22:15

https://www.acmicpc.net/problem/10848

 

10848번: 팔렘방의 다리

입력의 첫 줄에는 K와 N이 주어진다. 이후 N개의 줄에는 4개의 값 Pi, Si, Qi, Ti가 각각 주어진다. Pi와 Qi는 한글자 'A' 혹은 'B'이다. 0 ≤ Si, Ti ≤ 1, 000, 000, 000 사는 곳이나 사무실이 서로 다른 시민에

www.acmicpc.net

정말정말 좋은 문제이다.

 

일단, 한 도시에 집과 사무실이 모두 있는 경우는 없다고 가정하고, 다리의 길이를 무시해보자.

 

 

\(k = 1\)일 때는, S와 T를 합쳐 정렬한 후, 중앙값이 다리의 좌표가 된다.

 

 

\(k = 2\)일 때...

 

\(f_i(x)\) 를 i번째 사람이 다리 x를 통해서 출근할 때의 거리라고 하면,

\(f_i(x) = |S_i - x| + |T_i - x|\) 가 된다.

 

다리의 좌표를 \(a, b(a < b)\) 라고 하면, 이때의 거리는

\(\sum_{i=1}^{m}{min(f_i(a), f_i(b))}\) 가 된다.

 

\(f_i(x)\)를 그려보면,

이런식으로 나온다. 절댓값 그래프 두 개를 더한 형태이다.

 

여기서 \( f_i(a) < f_i(b) \) 인 곳, 즉 다리 a를 선택해야 할 조건을 찾아보자.

 

\(S_i - a < b - T_i \)여야 하므로, \(S_i + T_i < a + b \)로 식을 정리할 수 있다.

이때, \(S_i + T_i \)는 고정된다는 점을 활용해보자.

 

집과 사무실의 좌표 쌍을 \(S_i + T_i \) 순으로 정렬하면, \(i\)를 기준으로 왼쪽은 다리 \(a\), 오른쪽은 다리 \(b\)를 사용하도록 하면 된다! 이는 곧, 왼쪽과 오른쪽을 각각 \(k = 1\)일 때의 방법으로 풀면 된다는 말이다.

 

나는 굳이 중앙값을 구하려고 해서 여기서 막혔다.

조금 생각해보니, 중앙값이 \(x\)일 때 거리의 합은 \( (x - sum_{left}) + (sum_{right} - x) = sum_{right} - sum_{left} \)인 것을 알 수 있었다. 이때 \(sum_{left} \)와 \(sum_{right} \)는 각각 중앙값을 기준으로 왼쪽 합과 오른쪽 합이다.

 

\(sum_{left}\)와 \(sum_{right} \)를 구하기 위해서는, 우선순위 큐 2개를 사용하여 \(left\)와 \(right\)를 각각 관리해주면 된다.

 

 

왼쪽부터 돌리면 \(k = 1\)에 대해 풀 수 있고, 그리고 오른쪽부터 다시 돌리면 \(k = 2\)에 대해서도 풀 수 있다.

왼쪽에서 돌릴 때를 \([1, i]\)에서 다리 a를 사용하는 경우라고 보면 되므로, 거리 합을 미리 저장해두면 된다.

 

이렇게 구한 값에, 무시했던 다리의 길이와 같은 도시에 있는 쌍을 더해주면 된다.

 

 

코드는 71줄부터 보면 된다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import sys, math
import heapq
 
from collections import deque
 
input = sys.stdin.readline
 
hqp = heapq.heappop
hqs = heapq.heappush
 
 
# input
def ip(): return int(input())
def sp(): return str(input().rstrip())
 
def mip(): return map(int, input().split())
def msp(): return map(str, input().split().rstrip())
 
def lmip(): return list(map(int, input().split()))
def lmsp(): return list(map(str, input().split()))
 
 
# gcd, lcm
def gcd(x, y):
    while y:
        x, y = y, x % y
    return x
 
 
def lcm(x, y):
    return x * y // gcd(x, y)
 
 
# prime
def isPrime(x):
    if x <= 1return False
    for i in range(2int(x ** 0.5+ 1):
        if x % i == 0:
            return False
    return True
 
 
# Union Find
# p = [i for i in range(n + 1)]
 
def find(x):
    if x == p[x]:
        return x
    q = find(p[x])
    p[x] = q
    return q
 
 
def union(x, y):
    x = find(x)
    y = find(y)
 
    if x != y:
        p[y] = x
 
 
def getPow(a, x):
    ret = 1
    while x:
        if x & 1:
            ret = (ret * a) % MOD
        a = (a * a) % MOD
        x >>= 1
    return ret
 
############### Main! ###############
 
INF = 10 ** 9 + 7
 
lhq = []
rhq = []
lp = 0
rp = 0
 
def push(x):
    global lp, rp
    if lhq:
        m = -hqp(lhq)
        hqs(lhq, -m)
    else:
        m = INF
 
    if x <= m:
        hqs(lhq, -x)
        lp += x
    else:
        hqs(rhq, x)
        rp += x
 
    if len(lhq) - 1 > len(rhq):
        tmp = -hqp(lhq)
        hqs(rhq, tmp)
        lp -= tmp
        rp += tmp
 
    if len(lhq) < len(rhq):
        tmp = hqp(rhq)
        hqs(lhq, -tmp)
        lp += tmp
        rp -= tmp
 
k, n = mip()
= []
plus = 0
for i in range(n):
    q = lmsp()
    if q[0== q[2]:
        plus += abs(int(q[1]) - int(q[3]))
    else:
        a.append([int(q[1]), int(q[3])])
 
a.sort(key=lambda x: x[0+ x[1])
= []
for i in range(len(a)):
    push(a[i][0])
    push(a[i][1])
    p.append(rp - lp)
 
if p:
    ans = p[-1]
else:
    ans = 0
 
if k == 2:
    p.append(0)
    lhq = []
    rhq = []
    lp = 0
    rp = 0
    for i in range(len(a) - 1-1-1):
        push(a[i][0])
        push(a[i][1])
        ans = min(ans, rp - lp + p[i - 1])
 
print(ans + len(a) + plus)
 
######## Priest greedev ########
cs

 

여담으로, 코포 레이팅으로 볼록 껍질을 만들어버리는 대참사가 일어났다....