Q : N개의 정수 배열에서 a~b까지 합을 구하려면 어떻게 하나요?
A : for문을 돌면서 더해요! $O(N)$
Q : 그럼 Q개의 쿼리형식으로 들어온다면요?
A : 누적합배열을 사용해요! $O(N+Q)$
Q : 그럼 여기에 업데이트도 있다면 있다면 어떻게 할까요?
A : 세그먼트 트리를 쓰면 되죠! $O(Q \log N)$
위의 모든 답변은 실제로 정답이고 더 어려운 알고리즘이 하위 문제를 해결할 수 있다.
예를들어 누적합으로 1번 문제를 해결할 수 있고 세그먼트 트리로 1, 2번 문제를 해결할 수 있다.
하지만 위 문제들을 제곱근 분할법(a.k.a. 루트질)을 사용해서 해결할 수도 있다!
물론 쉬운 알고리즘을 사용할 수 있다면 그걸 사용하는게 좋겠지만 연습용으로 풀기 좋다.
어떻게 해결할 지 알아보자.
버킷
위의 질문에서 나온 수열이 $1,2,3,4,5,6,7,8,9,10$이라고 하자.
이걸 임의의 사이즈의 버킷으로 묶은 후 합을 저장할 것이다.
같은 색은 같은 버킷에 있다는 뜻이고 같은 버킷에 있는 수들의 합은 B값으로 따로 관리해준다.
이제 2~8까지의 합을 구해보면 3~7은 2~8에 속하기때문에 직접 더하는 대신 B2값을 사용하면 된다.
그 외의 범위는 직접 더해주면 2~8까지 덧셈을 6번하는 것 대신 2+B2+8을 함으로써 3번만 해도 된다.
자 그럼 중요한건 버킷의 사이즈를 정하는 방법이다.
우선 위에서는 아무렇게나 정했지만 최악의 경우 4~8, 4~9, 5~8, 6~9, 7~9 이런식으로 들어오면 직접 더하는거랑 차이가 없다.
이 경우 버킷 $i$의 크기를 $b_i$라고 하면 $O(b_2+b_3)$만큼의 계산량을 사용하게 된다.
또한 2~9까지라면 $O(b_1+b_3)$, 2~6까지라면 $O(b_1+b_2)$가 될 것이다.
만약 위의 경우처럼 $O(b_2+b_3)$가 걸리는 쿼리가 $A$개, $O(b_1+b_3)$인 쿼리가 $B$개, $O(b_1+b_2)$가 걸리는 쿼리가 $C$개라면 모든 쿼리를 처리했을 때 $O((B+C)b_1+(A+C)b_2+(A+B)b_3)$이 걸리게 되고 우리는 이를 최소화해야한다. (아주 최악의 최악 상황만 모아놓은 것이다.)
산술평균 기하평균에 의해 위의 시간복잡도는 $(B+C)b_1+(A+C)b_2+(A+B)b_3 \ge 3\sqrt[3]{(B+C)(A+C)(A+B)b_1b_2b_3}$가 되고 $b_1 = b_2 = b_3$일 때 최소가 됨을 알 수 있다.
버킷의 개수가 3일 때를 예로 들었지만 일반적인 상황에서도 유사하게 귀납적으로 적용할 수 있다.
첫 번째 포인트
모든 버킷의 크기는 같아야 한다!
제곱근 분할법
자 그럼 버킷의 크기가 같아야함을 알았으니 모든 버킷의 크기를 $B$라고 하자. $N$이 $B$로 나누어떨어지지 않을 경우 뒤에 남는 원소들이 있을 수 있는데 이는 맨 뒷 버킷에 추가하거나 버킷을 따로 만들어도 상관 없다. (개인적으로 후자를 선호한다.)
임의로 $B = 4$로 정했고 $N \equiv 2 \mod B$이므로 따로 버킷을 만들었다.
이제 2~9 쿼리가 반복적으로 들어오면 $O(B + B + \frac{N-2B}{B})$를 반복해야 할 것이다.
$\frac{N-2B}{B}$은 중간에 버킷들을 더하는 횟수이다. 현재는 B2 한번만 더하지만 $B$의 크기가 달라지면 중간에 있는 버킷의 개수가 달라지므로 그 만큼 더해야 될 것이다.
$\frac{N-2B}{B} = \frac{N}{B}-2$이므로 시간복잡도는 $O(2B+\frac{N}{B}-2)=O(B+\frac{N}{B})$가 되고 이 또한 최적화 해야한다.
또 다시 산술평균 기하평균을 쓰면 $B+\frac{N}{B} \ge 2\sqrt{N}$이 되고 $B = \frac{N}{B}$일 때 최소가 된다.
즉 버킷의 사이즈는 $B = \sqrt{N}$일 때 최소가 된다!!
이렇게 "루트개수"로 버킷을 "나누기" 때문에 제곱근 분할법(sqrt decomposition)이라고 불린다.
두 번째 포인트
버킷의 크기는 수열의 크기의 제곱근이어야 한다!
제곱근 분할법으로 문제풀기
자 그럼 이제 뭔지 알았으니 이것을 이용하여 2042 구간 합 구하기를 풀어보자!
크기가 $B = \lfloor \sqrt{10} \rfloor = 3$인 버킷으로 구성했다.
1. 업데이트 쿼리
query) 8번째 인덱스의 수를 11로 변경하시오
이런식으로 매 쿼리마다 $O(1)$에 해결할 수 있다.
2. 구간 합 쿼리
query) 2 ~ 8 합을 구하시오
이런식으로 구하게 되는데 하나의 버킷은 많아봐야 $\sqrt{N}$개 만큼 있고 B2같이 버킷전체의 값을 사용하는 경우도 많아봐야 $\sqrt{N}$개이므로 쿼리마다 $O(\sqrt{N})$을 사용하게 된다!
문제를 시간복잡도 $O(Q\sqrt{N})$으로 해결할 수 있으므로 충분히 해결할 수 있다.
물론 세그먼트 트리보다 느리니까 제곱근 분할법 연습용으로만 사용하자.
정답코드
#include <bits/stdc++.h>
#define int long long
using namespace std;
signed main()
{
std::ios_base::sync_with_stdio(false);
std::cin.tie(nullptr);
std::cout.tie(nullptr);
int n, m, k;
cin >> n >> m >> k;
int b = sqrt(n); // bucket size
vector<int> arr(n);
for (auto &i : arr)
cin >> i;
vector<int> brr; // bucket
for (int i = 0; i < n; i += b)
{
int s = 0;
for (int j = i; j < min(i + b, n); j++)
s += arr[j];
brr.push_back(s);
}
// O(qn**0.5)
for (int query = 0; query < m + k; query++)
{
int op, x, y;
cin >> op >> x >> y;
--x;
if (op == 1)
{
// 버킷 업데이트 O(1)
int where = x / b;
brr[where] -= arr[x];
arr[x] = y;
brr[where] += arr[x];
}
else if (op == 2)
{
--y;
int left = x / b;
int right = y / b;
int ans = 0;
for (int i = left + 1; i < right; i++)
ans += brr[i];
if (left == right)
{
// 같은 버킷이면 선형합 O(n**0.5)
for (int i = x; i <= y; i++)
ans += arr[i];
}
else
{
// 다른 버킷이면 왼쪽 버킷, 오른쪽 버킷 따로 더함 O(n**0.5)
for (int i = x; i < (left + 1) * b; i++)
ans += arr[i];
for (int i = right * b; i <= y; i++)
ans += arr[i];
}
cout << ans << endl;
}
}
return 0;
}
하지만 출력 쿼리가 최솟값을 출력하는 것이라면?
버킷을 스플레이같은 bbst나 pbds로 관리해서 $O(Q\sqrt{N}\log{N})$에 해결할 수 있긴 하다.
... 그냥 세그먼트 트리를 사용하자...
연습문제
모두 세그먼트 트리로 쉽게 풀리지만 제곱근 분할법을 연습해보자!
'프로그래밍 > 알고리즘' 카테고리의 다른 글
[알고리즘] smaller to larger를 사용해서 분리집합을 구현하자! (0) | 2023.04.13 |
---|---|
[알고리즘] Mo's 알고리즘 (0) | 2023.03.23 |
[알고리즘] Chat-GPT와 Problem Solving (2) | 2023.02.07 |
[알고리즘] KUPC 2022 출제/운영 후기 (1) | 2022.12.05 |
[알고리즘] LG CNS Code Monster 2022 (0) | 2022.11.29 |