Đề bài: http://www.spoj.com/problems/APIO10A/
Cho 1 dãy N số nguyên. Một hàm số bậc 2 : f(x) = a * x ^ 2 + b * x + c. Phân dãy trên thành các đoạn liên tiếp sao cho tổng các hàm f trên các dãy là lớn nhất (giá trị f của 1 dãy là f(x) với x là tổng của dãy đó).
Input format :
- Dòng đầu là số test case T
- Mỗi test case gồm 3 dòng :
- Dòng đầu là số nguyên dương N – số phần tử của dãy.
- Dòng 2 là 3 số nguyên a, b, c.
- Dòng còn lại gồm n số x1, x2, …, xn là n phần tử của dãy.
Output format :
- Mỗi test case gồm 1 dòng, là kết quả của bài toán.
Giới hạn :
T<=3
n ≤ 1, 000, 000,
−5 ≤ a ≤ −1
b <= 10,000,000
c <= 10,000,000
1 ≤ xi ≤ 100.
Thuật toán:
Gọi f(x) = a * x ^ 2 + b * x + c
Thuật O(n^2):
Gọi dp(i) là chi phí lớn nhất khi phân hoạch đoạn từ 1 -> i.
sum(i) là tổng các phần tử từ 1 -> i.
dp(i) = max(dp(j) + f(sum(i) – sum(j)) (1 <= i <= n; 0 <= j < i)
Thuật O(n): dùng Convex Hull Trick
dp(i) = max(dp(j) + f(sum(i) – sum(j)) (1 <= i <= n; 0 <= j < i)
⇔ dp(i) = dp(j) + a * (sum(i) – sum(j))^ 2 + b * (sum(i) – sum(j)) + c
⇔ dp(i) = (a * sum(i) ^ 2 + b * sum(i) + c) + (-2 * a * sum(i) * sum(j)) + a * sum(j) ^ 2 – b * sum(j) ^ 2
Đặt A = -2 * a * sum(j), X = sum(i), B = a * sum(j) ^ 2 – b * sum(j) ^ 2
⇔ ta được đường thẳng y = A * X + B.
Vì mảng sum tăng dần -> ta có thể dùng two-pointer để giảm đpt xuống O(n)
Code:
#include
using namespace std;
const int N = 1e6 + 10;
class ConvexHull {
private:
int head, tail;
long long A[N], B[N];
public:
void init() { head = tail = 0; }
bool bad(int l1, int l2, int l3) {
return (long double) (B[l3] - B[l1]) / (A[l1] - A[l3]) < (long double) (B[l2] - B[l1]) / (A[l1] - A[l2]);
}
void add(long long a, long long b) {
A[tail] = a; B[tail++] = b;
while (tail > 2 && bad(tail - 3, tail - 2, tail - 1)) {
A[tail - 2] = A[tail - 1];
B[tail - 2] = B[tail - 1];
tail--;
}
}
long long query(long long x) {
if (head >= tail) head = tail - 1;
while (head < tail - 1
&& A[head + 1] * x + B[head + 1] >= A[head] * x + B[head]) head++;
return A[head] * x + B[head];
}
} hull;
int n, a, b, c;
long long sum[N];
long long f(long long x) { return a * x * x + b * x + c; }
void load() {
scanf("%d%d%d%d", &n, &a, &b, &c);
for (int i = 1; i <= n; ++i) {
scanf("%lld", sum + i);
sum[i] += sum[i - 1];
}
}
void process() {
hull.init();
long long cost = f(sum[1]);
hull.add(-2 * a * sum[1], cost + a * sum[1] * sum[1] - b * sum[1]);
for (int i = 2; i <= n; ++i) {
cost = f(sum[i]) + max(0ll, hull.query(sum[i]));
hull.add(-2 * a * sum[i], cost + a * sum[i] * sum[i] - b * sum[i]);
}
printf("%lld\n", cost);
}
int main() {
// freopen("input.in", "r", stdin);
// freopen("output.out", "w", stdout);
int test; scanf("%d", &test);
while (test--) {
load();
process();
}
return 0;
}

Recent Comments