しがない元高専生の競プロ日記

AtCoderとかいろいろ解いてく

ABC149 - E Handshake 解説

[追記]スマホで見るとMathJaxが死んでる可能性が高いです。PCで見ることを推奨します。

解説でもヒントが少なく、人によってソースコードが大きく違ったりと苦戦したので残しておこうと思いました。

今回は具体的な解法の説明を重視しています(思考過程等の解説は無し、実装方針の説明に寄ってます)。

コンテスト自体はunratedなので途中で諦めた人も多い気がします。 atcoder.jp

問題の意訳

配列 \(A\) をタテ、ヨコにした100マス計算みたいな感じの表を作った時、できた数の上位 \(M\) 個の総和を求める。

簡単にするとこんな感じ(入出力例1の例)

\begin{array}{|c|c|c|c|c|c|} \hline + & 10 & 14 & 19 & 33 & 34 \\ \hline 10 & 20 & 24 & 29 & 43 & 44 \\ \hline 14 & 24 & 28 & 33 & 47 & 48 \\ \hline 19 & 29 & 33 & 38 & 52 & 53 \\ \hline 33 & 43 & 47 & 52 & 66 & 67 \\ \hline 34 & 44 & 48 & 53 & 67 & 68 \\ \hline \end{array}

\(68 + 67 + 67 = 202\) で \(202\) が答えとなる。

簡単な解法の説明

今回の問題は2つのパートに分かれていて、

  1. 増える幸福度が上位M個に入る最小の和を二分探索する。
  2. 累積和を用いて総和を求める。

という2パートに分かれています。

ここで大きな罠があって、実は1,2番で具体的に書かれていない二分探索の判定パートと総和の求め方が分かりにく過ぎる(というか解説PDFで具体的な解説無し)ので大きく苦戦します(しました)。

という訳で具体的な実装までできるだけ丁寧に説明しようと思います。分かりにくかったらごめんなさい。

(heno239さんのソースコードがクソ分かりやすくて助かりました。ありがとうございました)

にぶたんパート

前提として、二分探索を行うには単調性があることが条件です。

今回は、\(X\)が上位\(M\)個未満に入れるかという判定を行います。上位\(M\)個未満というのは図のように単調性がある(ある位置で〇になったらその先も全部〇になる)といえます。

\begin{array}{|c|ccc|cc|} sum & 0 & \ldots & X-1 & X & \ldots & max(A)+max(A) \\ \hline / & × & × & × & 〇 & 〇 & 〇 \\ \end{array}


それでは遂に二分探索最大の山場、探索範囲の縮小のための判定です。

判定を行うためにはlower_boundを使う必要があります。

そのため事前に配列\(A\)をソートしておきましょう。

判定方法は以下のアルゴリズムで実装できます。

  1. まずa[i]を左手に持っているとする。
  2. a[i]を持っているときに和がX未満となる右手に持っている数字はX - a[i]未満の数字である。
  3. 2を lower_bound を使って判定、X - a[i]未満の個数を求める。
  4. 3で求めた数をNから引くことでX - a[i]以上の個数を判定できる。
  5. 1~4を繰り返し、すべてのX - a[i]以上の個数の和がM未満かを判定する。

これをC++で記述するとこうなります。

//判定
bool chk(long long x) {
    long long cnt = 0;
    for(int i=0; i<n; i++) {
        long long pos = lower_bound(a.begin(), a.end(), x - a[i]) - a.begin();
        cnt += (n - pos);
    }
    return cnt < m;//X以上の和がM個未満かどうか
}

int main() {
    cin >> n >> m;
    a.resize(n);
    for(int i=0; i<n; i++)cin >> a[i];
    sort(a.begin(), a.end());
    //二分探索
    long long ng = 0, ok = M;
    while (abs(ok - ng) > 1) {
        long long mid = (ok + ng) / 2;
        if (chk(mid)) ok = mid;
        else ng = mid;
    }
    //終了時にngにX-1,okにXが入っている。
    ...

和を求めるパート

続いて、答えとなる総和を求めるパートに入ります。
ここで、ある区間の総和を高速に求める必要があるので累積和を求めておきます。

累積和が分からない人はコチラ

ここからは以下のようなアルゴリズムで総和を求めていきます。
1. upper_boundを使い、左手にa[i]があると仮定した場合のX-a[i]以下の数の個数を判定
2.1で求めた数をNから引いてa[i]を左手に持った時のX以上の数を計算。
3.ans2で求めた数*a[i]+右手に持つ幸福度の総和を足す。
4.M2で求めた数を引く。
5.1~4繰り返し

ここで、`M`を見てみると、数が余っています。何故でしょう?
何故なら、今回の`X`とは、上位`M`個に確実に入る数となっているからです。
具体例としては、`{35,35,35,35,38,39}`、`M = 3`とします。
こうすると、`X`には`35`は入らず、`36`が答えとなります。
何故なら、`35`は複数存在するため、確実にすべての`35`は上位`M`個に入らなくなり、上位`M`個に確実に入ることができる数は`36`になります。
よって、最後に余った`M`にはすべて入るか不確定だった \(X-1\) 、つまり `ng` を余りの `M` 個足し合わせることで総和を求めることができます。

具体的なソースコードはコチラになります。

 ...
    vector<long long> wa(n + 1); //Aの累積和
    for(int i=0; i<n; i++) wa[i + 1] = wa[i] + a[i];
    for(int i=0; i<n; i++) {
        long long pos = upper_bound(a.begin(), a.end() , ng - a[i]) - a.begin();
        long long cnt = n - pos;
        ans += cnt * a[i] + (wa[n] - wa[pos]);
        m -= cnt;
    }
    ans += m * ng;
    cout << ans << endl;
}

総括

二分探索は結構得意な部類だと思っていたのですがコレとその前のAGCで完全に自信を無くしました。

判定パートが499点分あると死にます。助けてください...

ついでに同じような類題を1つ紹介しておきます(割と天才向けです)。

atcoder.jp

最後に最終的なソースコードと提出結果を貼って終わろうと思います。

ここまで見ていただきありがとうございました。

Submission #9301825 - AtCoder Beginner Contest 149

#include <bits/stdc++.h>
using namespace std;


long long n, m, ans;
vector<long long> a;

//判定
bool chk(long long x) {
    long long cnt = 0;
    for(int i=0; i<n; i++) {
        long long pos = lower_bound(a.begin(), a.end(), x - a[i]) - a.begin();
        cnt += (n - pos);
    }
    return cnt < m;//X以上の和がM個未満かどうか
}

int main() {
    cin >> n >> m;
    a.resize(n);
    for(int i=0; i<n; i++)cin >> a[i];
    sort(a.begin(), a.end());
    //二分探索
    long long ng = 0, ok = LLONG_MAX;
    while (abs(ok - ng) > 1) {
        long long mid = (ok + ng) / 2;
        if (chk(mid)) ok = mid;
        else ng = mid;
    }
    //終了時にngにX-1,okにXが入っている。


    vector<long long> wa(n + 1); //Aの累積和
    for(int i=0; i<n; i++) wa[i + 1] = wa[i] + a[i];
    for(int i=0; i<n; i++) {
        long long pos = upper_bound(a.begin(), a.end() , ng - a[i]) - a.begin();
        long long cnt = n - pos;
        ans += cnt * a[i] + (wa[n] - wa[pos]);
        m -= cnt;
    }
    ans += m * ng;
    cout << ans << endl;
}