CF 631E 题解

admin 2021-10-07 PM 24℃ 0条

题目链接

分析

设 $sum$ 为 $a$ 的前缀和数组。

下面只考虑 $u>v$ 的情况。$u<v$ 时类似。

考虑将 $a_u$ 移动到 $a_v$ 会给答案带来多大的贡献。

显然是 $a_u\times(v-u)+sum_{u-1}-sum_{v-1}$。

整理得 $(-a_u\times u+sum_{u-1})+a_u\times v-sum_{v-1}$。

也就是所有直线 $y=vx-sum_{v-1}$ 在 $x=a_u$ 时的最大取值。

可以用李超线段树维护。时间复杂度 $O(n\log n)$。

实际有效代码只有 25 行,而且我用的是大括号换行的码风,所以其实很好写。

解决

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>

using namespace std;
using LL = long long;
using LLL = __int128;

const int N = 2e5 + 5, S = 1e6 + 5;
int n, a[N];
LL sum[N], ans;

class LCST
{
private:
    int mx[N], ls[N], rs[N], k[N], cnt, root, tot;
    LL b[N];

    int New(int _k, LL _b)
    {
        cnt++, k[cnt] = _k, b[cnt] = _b;
        return cnt;
    }

    LL y(int id, int x)
    {
        return 1LL * x * k[id] + b[id];
    }

    void insert(int &k, int l, int r, int id)
    {
        if (!k)
            k = ++tot;
        if (l == r)
        {
            if (y(id, l) > y(mx[k], l))
                mx[k] = id;
            return;
        }
        int mid = ((l + r) >> 1);
        if (y(id, mid) > y(mx[k], mid))
            swap(id, mx[k]);
        if (y(id, l) > y(mx[k], l))
            insert(ls[k], l, mid, id);
        if (y(id, r) > y(mx[k], r))
            insert(rs[k], mid + 1, r, id);
    }

    LL query(int k, int l, int r, int x)
    {
        if (!k)
            return -1e18;
        if (l == r)
            return y(mx[k], x);
        int mid = ((l + r) >> 1);
        if (x <= mid)
            return max(y(mx[k], x), query(ls[k], l, mid, x));
        else
            return max(y(mx[k], x), query(rs[k], mid + 1, r, x));
    }

public:
    void insert(int _k, LL _b)
    {
        insert(root, -S, S, New(_k, _b));
    }

    LL query(int x)
    {
        return query(root, -S, S, x);
    }

    void init()
    {
        cnt = 0, tot = 1, root = 0, b[0] = -1e18;
        memset(mx, 0, sizeof(mx));
        memset(ls, 0, sizeof(ls));
        memset(rs, 0, sizeof(rs));
    }
} lcst;

template<class T>
void read(T &ret)
{
    ret = 0;
    char ch = getchar(), flag = 0;
    while ((ch < '0' || ch > '9') && ch != '-')
        ch = getchar();
    if (ch == '-')
        ch = getchar(), flag = 1;
    while (ch >= '0' && ch <= '9')
        ret = ret * 10 + ch - '0', ch = getchar();
    if (flag)
        ret = -ret;
}

template<class T, class... Args>
void read(T &ret, Args &... rest)
{
    read(ret);
    read(rest...);
}

int main()
{
    read(n);
    for (int i = 1; i <= n; i++)
    {
        read(a[i]);
        sum[i] = sum[i - 1] + a[i];
    }
    lcst.init();
    for (int i = 1; i <= n; i++)
    {
        ans = max(ans, -1LL * a[i] * i + sum[i - 1] + lcst.query(a[i]));
        lcst.insert(i, -sum[i - 1]);
    }
    lcst.init();
    for (int i = n; i >= 1; i--)
    {
        ans = max(ans, -1LL * a[i] * i + sum[i] + lcst.query(a[i]));
        lcst.insert(i, -sum[i]);
    }
    for (int i = 1; i <= n; i++)
        ans += 1LL * i * a[i];
    cout << ans << endl;
    return 0;
}
标签: none

非特殊说明,本博所有文章均为博主原创。

上一篇 洛谷 P3607 题解
下一篇 HDU 2829 题解

评论啦~