ABC147 E Balanced Pathでbitset高速化を完全に理解する

本番で通せなかったABC147のE問題『Balanced Path』について、反省していきたいと思います。

問題概要

H*Wのグリッドの各マスに数が2つ書かれている。グリッド上を(0,0)->(H,W)y+1orx+1で移動するとき、経路上の各マスについて、いずれかを選び、「その合計と選ばなかったものの合計の差の最小値」を求めよ。

考察

第一感

  • マス数が小さい(H,W <= 80)
  • 各マスに書かれた数も小さい(0 <= A_ij <= 80, 0 <= B_ij <= 80)
    • したがって各マスごとの差も 0 <= |A_ij - B_ij| <= 80
  • グリッドを右or下で移動するので、左マスと上マスだけに依存するDPになりそう
  • 経路ごとに具体的な集合は持たなくても良い気がする
    • そのマスまで到達した際に、可能な差の絶対値だけを持っておけば良い気がする

取り組み

  1. 上記を実装する
    • TLE
    • 到達できる集合をsetで管理したので、各状態からの遷移先を計算するところでループが重かった
  2. 各経路について差の絶対値が80以下と仮定していい気持ちになる
    • 前半と後半が打ち消し合う場合、前半で打ち消した分と後半で打ち消した分で打ち消しあえるような気がした
      • これは誤り(考察不足)
    • サンプルが通り、反例を見つけられなかったので提出
      • TLEだったものが、WAに変わる
  3. 本番終了後Kiri8128さん提出 #8864256を参考に考察

差が常に1マスの最大値(80)以下で遷移すると仮定できないのはなぜ

経路の各マスの差が15, 35, 45, 80, 15であるような経路を考えると、+++--という途中で差が80を越えるような経路と、-++-+という差が80を越えない経路が取れるため、常に差は80以下になる気がします。 が!これは誤りで前半でお互いに打ち消しあうと、後半で出てくるものが余ってしまうケースがあります。

WAになるケースが見つけられなかったのですが、Twitterで親切な方が教えて下さいました。

累積の差が80までとして、枝刈りする実装だと、例えばこのケースで落ちます。

2 4
0 0 0 0
0 0 0 0
60 60 40 40
60 40 40 40

前半と後半を組み合わせて作る小さなパーツで、更に全体の差を小さくできる場合の考慮ができていなかったですね。

高速化が本筋だった

結局TLE解の高速化が必要になります。今回はbitsetを使った高速化が必要だったようです。

数の集合をbit列で表すことで、集合全体への足し算をbit shiftとして表現することができます。
例: S = set([1, 2, 4]) => '0b1011' = 11。(集合の要素を足し算をしているのではなく、「i番目のビットが1⇒集合Sにiが含まれる」という表現です。)

Pythonだと 1<<2*80*80 といった長いbit列も大きな整数値として扱う事ができます。

正しい実装

MAX_DIFF = 80
h, w = map(int, input().split())

a = [list(map(int, input().split())) for i in range(h)]
b = [list(map(int, input().split())) for i in range(h)]
c = [[abs(a[i][j]-b[i][j]) for j in range(w)] for i in range(h)]

sets = [[0 for j in range(w)] for i in range(h)]
sets[0][0] = (1 << MAX_DIFF+c[0][0]) | (1 << MAX_DIFF-c[0][0])
for i in range(1, h):
    sets[i][0] |= (sets[i-1][0] << MAX_DIFF+c[i][0]
                   ) | (sets[i-1][0] << MAX_DIFF-c[i][0])
for i in range(1, w):
    sets[0][i] |= (sets[0][i-1] << MAX_DIFF+c[0][i]
                   ) | (sets[0][i-1] << MAX_DIFF-c[0][i])
for i in range(1, h):
    for j in range(1, w):
        sets[i][j] |= (sets[i-1][j] << MAX_DIFF+c[i][j]
                       ) | (sets[i-1][j] << MAX_DIFF-c[i][j])
        sets[i][j] |= (sets[i][j-1] << MAX_DIFF+c[i][j]
                       ) | (sets[i][j-1] << MAX_DIFF-c[i][j])

s = bin(sets[h-1][w-1]+(1 << (h+w)*MAX_DIFF))
min_diff = 1 << MAX_DIFF
for i in range(len(s)):
    if s[-1-i] == '1':
        min_diff = min(min_diff, abs(i-(h+w-1)*MAX_DIFF))
print(min_diff)

最大誤差分オフセットすることで、[-80, (h+w)*MAX_DIFF]の範囲をbit列で表現しています。

まとめ

  • グリッドに対して左マス・上マスだけに依存するDPで書けると気付いたところはGood
  • 前半と後半を組みあせて小さなパーツを作るケースを思いつけなかったのはBad
  • bitset: 数の集合をある数に対応させる(各要素を各bitに対応させる)ことができる
    • bitが表現する要素が等間隔に並んでいる場合、bit shiftで集合全体への足し算・引き算が表現できる

この記事はCompetitive Programming (2) Advent Calendar 2019の12/8分の記事です。