しがない元高専生の競プロ日記

AtCoderとかいろいろ解いてく

第6回 ドワンゴからの挑戦状 予選 B - Fusing Slimes 解説

問題

URLはこちら↓ atcoder.jp

スライムが移動した距離の期待値に(N-1)!をかける問題。

問題を解いても実装面で罠があったりと大変だったので記事にしました。

解説

まず、問題の言い替えが発生します。

サンプルのケースはわかりにくいので、N=4,A={1, 2, 3, 4}のケースを考えます。

この時、スライムの移動する順番は(N-1)!通りあって、以下のようになります。

スライムを左から1,2,3,4とすると、
123
132
213
231
312
321
の6通り

ここから、上記の各移動順序での移動距離の和を計算してみると、以下のようになります。

123 -> 移動距離 3
132 -> 移動距離 4
213 -> 移動距離 4
231 -> 移動距離 5
312 -> 移動距離 4
321 -> 移動距離 6
  • 試しに2個目の移動方法をシミュレーションすると、1->2で距離1、3->4で距離2、2->4で距離4となります。

ここから各移動距離の期待値を求めます。各移動パターンの期待値の求め方は、

移動距離/(N-1)!で求まります。

よって全てのパターンの移動距離の期待値はΣ(各移動距離 / (N-1)!)になります。


しかし、ここであることに気付きます。

Σ(各移動距離 / (N-1)!) = 移動距離の総和 / (N-1)!と言い替えることができます。

で、この問題は最後に(N-1)!をかけたものを答えとして出力する問題です。

つまり、この問題は各移動距離の総和を求める問題に変換できたことが分かります。


それでは、移動距離の総和を求めてみましょう。

まず、i番目のスライムがj番目のスライムに合体する数を数えてみましょう。

上記の例の1番目のスライムを例に考えてみると、

123 -> 1番目のスライムは2と合体
132 -> 1番目のスライムは2と合体
213 -> 1番目のスライムは3と合体
231 -> 1番目のスライムは4と合体
312 -> 1番目のスライムは2と合体
321 -> 1番目のスライムは4と合体

{i,j} = {1,2}の時3個、
{i,j} = {1,3}の時1個、
{i,j} = {1,4}の時2個となる。

では次に2番目のスライムを例にすると、

123 -> 2番目のスライムは3と合体
132 -> 2番目のスライムは4と合体
213 -> 2番目のスライムは3と合体
231 -> 2番目のスライムは3と合体
312 -> 2番目のスライムは4と合体
321 -> 2番目のスライムは4と合体

{i,j} = {2,3}の時3個、
{i,j} = {2,4}の時3個となる。

3番目のスライムはどこにいても4番目に移動するので6個になります。

これを図に表すとこうなります。

f:id:honehaniwa:20200113191930p:plain

この図を見てみると、1->31->4の時、1->2を通っていることに気づきませんか?

つまり、1->3は、1->2 + 2->3 と変形することができます。

これを図に表すと、以下のようになります。

f:id:honehaniwa:20200113194530p:plain [追記]数値が間違ってたので修正しました

なんと、各iからjへの移動距離は(N-1)! / (これまで通った区間の数)になっていることが分かります。

よってこれらの数×各区間の距離%modを求めれば解けそうです!

実装パート

さて、ようやく実装パートに入りました。

実装においても大きな落とし穴が存在しますので、気を引き締めていきましょう。

※今回はPythonで実装を行います。一応私のメイン言語はC++なので分かりやすく書いてるつもりなのですが最後にC++でのACソースコードを貼りましたのでそちらを参考にしていただけると幸いです。

では実装していきましょう。

n=int(input())
A=list(map(int,input().split()))
mod = 10**9 + 7
# (N-1)!を求める
fac=1
for i in range(1,n):
  fac = fac*i%mod
# 個数テーブルを準備する
cnt=[0]*n
for i in range(1,n):
  cnt[i] = fac // i % mod
  cnt[i] = (cnt[i]+cnt[i-1]) % mod
# 計算
ans=0
for i in range(n-1):
  ans += (A[i+1]-A[i]) * cnt[i+1]
  ans %= mod

print(ans)

出来ました!提出します!

Submission #9496276 - Dwango Programming Contest 6th

f:id:honehaniwa:20200113195244p:plain
...

f:id:honehaniwa:20200113195435p:plain

こうなりました。

方針が間違っているのでしょうか?

実は違います。正体はmodにおける割り算の方法に問題があるからです!

qiita.com

詳しい説明はこちらの記事にて紹介されています。(けんちょんさんいつもありがとうございます)

ここでは簡単な説明のみとさせていただきますが、

  1. modの割り算ではどんな数字も割り切れる。
    -> 9/4 (mod 13) = 12(なぜなら4×12=48 ≡ 9(mod13))
  2. ではmod付きの正しい割り算の結果を知りたい。
  3. 求めるには逆元なるものが必要で、割りたい数×(割る数の逆元)で計算できるらしい!
  4. 逆元はフェルマーの小定理で計算できるらしい!
  5. デキタヤッター!

フェルマーの小定理の説明も詳しくはこっちで解説されてます。

qiita.com

では、逆元の計算を実装して完成です!

mod=10**9+7
# mod inverse(逆元)
# 今回はfermatの小定理で求める
def mod_inv(x):
  return pow(x,mod-2,mod)

n=int(input())
A=list(map(int,input().split()))
# (N-1)!を求める
fac=1
for i in range(1,n):
  fac = fac*i%mod
# 各区間の個数テーブルを先に求める
cnt=[0]*n
for i in range(1,n):
  cnt[i] = fac * mod_inv(i) % mod
  cnt[i] = (cnt[i]+cnt[i-1]) % mod
# 計算
ans=0
for i in range(n-1):
  ans += (A[i+1]-A[i]) * cnt[i+1]
  ans %= mod

print(ans)

AC結果はこちらです。

atcoder.jp

感想

考察は450くらいでした(水になれて冷えた人がコンテストほぼギリギリまで考えて考察できるレベル)が、実装で罠が張って合ってやっぱり600だなぁと感じました。

今回でmodの割り算ができるようになったので今後のコンテストで差をつけていきたいです!