バグから学ぶ計算機科学 Scalaのハッシュテーブルにおいて並列コレクションのためのコード変更が大量の衝突を引き起こした事例

書いた人: ると

書いた日: 2012年1月21日

はじめに

Twitterで「有名なオープンソースソフトで今まであったおもしろいバグを解説した本とかないだろうか」とツイートしたらそれなりに需要があるようでした。そこで先ず隗より始めよという故事にのっとり、死馬の骨としてバグ解説記事を書いてみます。

今回のバグはScala 2.9の標準ライブラリに含まれるmutable.HashSet(ハッシュテーブルを使った重複無しコレクション)のコピーがJavaの標準ライブラリに含まれるHashSetの100倍遅いというバグです。並列コレクションのためにぱっと見問題の無い変更を加えたら思わぬところで影響が出たというものです。

なお、今回はScalaに関するバグですが、Scalaに関する知識は必要ありません。また、タイトルには「並列」と入っていますが、同期などの並列計算におけるめんどうな部分は今回のバグとは関係ありませんので、並列計算に関する知識もほとんど必要ありません。ただしハッシュテーブルの基本的な考えについては既に知っていると仮定します。

ハッシュテーブルへの挿入順で速度が変わる!?

問題の再現はとても簡単で、Scala 2.9の対話的環境で次のようなコードを実行してみれば再現できます。

import scala.collection.mutable
import scala.collection.JavaConverters._

// 空のハッシュセットを作成
val javaSet1 = new java.util.HashSet[Int]
val scalaSet1 = mutable.HashSet.empty[Int]

// 10万個要素を追加。
// 今回は0から99999までの連番の整数だが、問題には関係ない。
javaSet1.addAll((0 until 100000).asJava)
scalaSet1 ++= (0 until 100000)

// もう1組、空のハッシュセットを作成
val javaSet2 = new java.util.HashSet[Int]
val scalaSet2 = mutable.HashSet.empty[Int]

// 要素をコピー
// Javaの場合はすぐ終わる
javaSet2.addAll(javaSet1)

// Scalaの場合は遅い!
scalaSet2 ++= scalaSet1

ハッシュテーブルが極端に遅い場合は、ハッシュテーブルの同じ場所に要素を追加しようとする「衝突」が頻繁に置きていると疑われます。

しかしここで興味深いのは、一旦値をソートしてから挿入すると遅くないという点です。

val scalaSet3 = mutable.HashSet.empty[Int]

// これは遅くない
scalaSet3 ++= scalaSet1.toSeq.sorted

++=メソッドは要素を1つずつ順番に挿入していく実装になっていますので、要素を挿入する順番がこのバグに関係しているということです。

しかし、ハッシュテーブルにおける衝突頻度は普通は挿入順序によらないはずです。実際に挿入が終わった状態のテーブルを調べてみて衝突している要素を調べてみると、ソートした場合もしていない場合も同じ数だけ衝突しています。

ではなぜこんな奇妙な動作をするのでしょうか。

Scalaのmutable.HashSetの実装

原因の解説のために、まずはScalaのmutable.HashSetの実装を簡単に説明します。

mutable.HashSetは要素の重複を含まないコレクションをハッシュテーブルで実装したクラスです。JavaやScalaでは全てのオブジェクトはhashCodeという整数を返すメソッドを実装していて、ハッシュ値にはそれを利用しています。

なお、ハッシュテーブルの実際の実装はFlatHashTableというトレイト(実装を持てるインターフェース)にありますので興味を持った方は読んでみて下さい。

FlatHashTableではハッシュテーブルのサイズは2のべき乗になっていて、要素数がテーブルサイズの45%以上になるとサイズを2倍にしてテーブルを作り直します。要素を置くインデックスとしてはハッシュ値の上位数ビットを利用します。このとき利用するビット数はハッシュテーブルの大きさによって異なります。例えばテーブルのサイズが16で、要素1から要素4のハッシュ値が次の表の2列目のような場合、インデックスは次の表の3列目のようになります(実際にはhashCodeメソッドの返り値そのものを使うのではなく、分散がよくなるような関数を適用してから使いますが、今回のバグには関連しないので省略します)

ハッシュ値(2進) インデックス
要素1 0001 1101 ... 0001 = 1
要素2 0010 0110 ... 0010 = 2
要素3 0011 0100... 0011 = 3
要素4 0011 1111 ... 0011 = 3

インデックスの値が同じになってしまった場合、つまり衝突が起きた場合は後から挿入した要素のインデックスを1つ増やします。例えば要素を1から4まで順に挿入した場合は次のようになります。

ハッシュ値(2進) インデックス
要素1 0001 1101 ... 0001 = 1
要素2 0010 0110 ... 0010 = 2
要素3 0011 0100... 0011 = 3
要素4 0011 1111 ... 0100 = 4

ところでハッシュテーブルの実装を読んだ経験がある方には上位数ビットを使うという実装は変に思えるかもしれません。普通は下位数ビットを使う方がコードが簡単になります。実際Javaでは下位数ビットを使っていますし、Scalaでも2.8までは下位数ビットを使っていました。実はこの変更が今回のバグの肝です。この変更は2.9における並列コレクションの導入に伴うものですが、変更の理由を見る前にバグの詳しい原因を見てみましょう。

中で何が起きたのか

今回のバグには3つの要因があります。今回のバグは次の3つが揃って初めて発生しました。

1つ目は先程説明したインデックスの計算方法です。

2つ目は複数の要素を一度に挿入する際の動作です。通常ハッシュテーブルに大量の要素を一度に挿入する場合は最終的なテーブルのサイズを見積もれるため、一気にテーブルを大きくします。しかしScala 2.9の実装ではこの処理をせずに要素を1つずつ挿入しながら少しずつテーブルを大きくしています。サイズを一気に大きくするsetSizeメソッドは用意されているのですが、中身は何もしないメソッドになっていて、呼ばれていません。

3つ目は要素の列挙順に関するもので、ハッシュテーブル内の要素を1つずる列挙する場合はハッシュテーブルに格納されている順に列挙されます。これは極めて普通の実装で大抵の実装ではこのようにします。

では、この3つが揃うとどうなるでしょうか。先程の要素1から4が入ったテーブルをコピーしてみましょう。各要素のインデックスは1, 2, 3, 4なので、要素1から4まで順に値を挿入していきます。ここでは初期テーブルサイズは8とします

テーブルのサイズは8なので上位3ビットを見て要素1を0番目に入れます。

インデックス 要素 ハッシュ値(参考)
0 要素1 0001 1101 ...
1
2
3
...

次に要素2を挿入します。上位3ビットを見て1番目に入れます。

インデックス 要素 ハッシュ値(参考)
0 要素1 0001 1101 ...
1 要素2 0010 0110 ...
2
3
...

ここまでは問題ありません。次に要素3を挿入します。上位3ビットを見るとこれも001であり衝突していますのでその次の2番目に入れます。

インデックス 要素 ハッシュ値(参考)
0 要素1 0001 1101 ...
1 要素2 0010 0110 ... 衝突!
2 要素3 0011 0100...
3
...

そして要素4も上位3ビットは001なので衝突します。その次のインデックス2でも衝突しますので3番目に入れます。

インデックス 要素 ハッシュ値(参考)
0 要素1 0001 1101 ...
1 要素2 0010 0110 ... 衝突!
2 要素3 0011 0100... 衝突!
3 要素4 0011 1111 ...
...

ここで要素数がテーブルサイズの45%以上になったためテーブルサイズが倍の16になり、ハッシュ値の上位4ビットを元に値を再挿入します。途中経過は省略しますが、最終的には次のようになります。

インデックス 要素 ハッシュ値(参考)
0 要素1 0001 1101 ...
1 要素2 0010 0110 ...
2 要素3 0011 0100... 衝突!
3 要素4 0011 1111 ...
...

これを見ると最終状態では1回しか衝突していないにも関らず、中間状態では2回衝突が起きているのがわかります。

ハッシュテーブル中で連続した要素は似た上位ビットを持つため、少ない上位ビットでインデックスを決めると頻繁に衝突してしまいます。例えば最終テーブルサイズが65536で衝突が無かったとしても、最初の256番目までに格納されている要素の上位8ビットは全て0ですし、最初の512番目までに格納されている要素の上位9ビットも4パターンしかありません。つまり最終的にはハッシュテーブル全体に値が分散していたとしても中間状態では大量の衝突が発生してしまいます。

もしインデックスとして下位数ビットを使っていたとしたらこのようなことにはなりません。ハッシュテーブル中で連続していたとしても下位ビットは異なる値になるためです。そのためJavaではこのような問題は起きません。

さて、現象がわかったところでこのバグが生まれた原因について見てみます。つまりなぜインデックスとして上位数ビットを使うようになったのか見てみましょう。

なぜこのような実装になったのか

このバグの原因となった変更はScala 2.9で導入された並列コレクションのための変更です。

並列コレクションとは、プログラマがほとんど意識せずにマルチコアを活かした処理を書くための仕組みです。従来からScalaにはコレクションの各要素に対して同じ処理をするような機能がありました。簡単な例を示します。

// xsは1, 2, 3を含む集合
val xs = mutable.HashSet(1, 2, 3)

// xsの各要素に1を足した新しい集合を作る
// ysは2, 3, 4を含む集合
val ys = xs.map(x => x + 1)

// xsの中から2で割った余りが1である要素のみを集めた新しい集合を作る
// zsは1, 3含む集合
val zs = xs.filter(x => x % 2 == 1)

// xsの各要素xを、xと-xで置き換えた新しい集合を作る
// wsは1, -1, 2, -2, 3, -3を含む集合
val ws = xs.flatMap(x => mutable.HashSet(x, -x))

従来ではこれらの機能は要素を1つずつ順番に処理していましたので、マルチコアを持つCPUでもマルチコアを活かせていませんでした。そこで導入されたのが並列コレクションで、parというメソッドを挟むだけで各要素を並列に処理するコレクションになります。例えば次のようになります。

// xsは0から9999までを含む列
val xs = (0 until 10000)

// xsの各要素に1を足した新しい列を作る
// ysは1から10000までを含む列
// 計算は並列に実行される
val ys = xs.par.map(x => x + 1)

この例の場合、10000要素を持つ列が数個に分割され、それぞれを別のコアで処理したあと結果を統合します。さて、この並列コレクションはmutable.HashSetでも利用でき、ごく単純化すると次のような動作をします。

  1. ハッシュテーブルを一定サイズのブロックに分割する
  2. 各ブロックを各コアが処理する
  3. 結果はブロックごと・ハッシュ値の上位5ビットごとに用意したリストに貯めておく(この時点では重複を許す)。つまりブロック数×32個のリストができる
  4. 各ブロックの計算結果を連結して、上位5ビットごとのリストを作る。つまり32個のリストができる
  5. 各リストの値からハッシュテーブルを作る

この最後のステップも並列に処理します。つまり各リストを別のコアが担当します。

また、ハッシュテーブルの大きさはリストが完成した時点でリストのサイズから見積もっておき、値を入れる前に十分な大きさを確保しておきます。

ここでインデックスとしてハッシュ値の上位数ビットを使うようにしておけば、同じリスト内の値はハッシュテーブル内でも同じ領域に含まれることになります。そのため各コアが同期を使わずにハッシュテーブルを構築できます(実際には衝突があるとコアの担当範囲から要素があふれてしまう場合がありますが、それは後でまとめて処理します)。また、1つのコアが連続した狭い領域のみにアクセスするのでCPUのキャッシュ効率も良くなります。

なぜ下位5ビットでリストを作ってはいけないかというと、5ビットでリストに分類しておくというのはあくまで仮の分類であって、最終的なハッシュテーブルのサイズによってインデックスに使うビット数は変わります。テーブルサイズが小さければ5ビットよりも少なくなりますし、大きければ5ビットよりも多くなります。それはリストの構築が済んでからしかわからないため、下位5ビットでリストを作ってしまうと各リストの値がハッシュテーブル上で連続した領域に位置しません。例えばハッシュ値の下位8ビットが0000 0000である要素と1110 0000である要素があり、最終的なテーブルサイズが256だったとすると、この2つの要素は同じ下位5ビットを持っているにも関わらずハッシュテーブル中で離れた場所に位置してしまいます。一方で上位5ビットでリストを作っておければ上位何ビットを使ったとしても連続した領域に位置すると保障できます。

このような理由からハッシュテーブルのインデックスとして上位5ビットを使うように変更されました。しかしこれがバグの原因となりました。

解決に向けての取り組み

さて、解決策ですが、このバグは現在Prokopec氏が担当しています。Prokopec氏はScalaの作者であるOdersky教授の研究室で並列処理を研究している院生の方であり、Scalaの並列コレクションを担当している方です。Prokopec氏はいくつか解決案を挙げています。

まず最初は要素を追加する前にハッシュテーブルを大きくしておくsetSizeメソッドを実装して++=メソッドから呼ぶというものです。要素を1つずつ挿入した場合には効きませんが、全体的な効率化が見込めます。

もう1つはハッシュ値を使う前にハッシュテーブルごとに固有の値の分だけビット回転させるという案です。インデックスを計算する前や並列処理の結果を32個のリストに分割する前に回転させます。こうすると各リストの要素がハッシュテーブルの連続した領域に位置するという性質を保ったまま、ハッシュテーブルをコピーする際に上位ビットをばらばらにでき、衝突を避けられます。テーブル固有の値としては乱数を使う方法の他、テーブルのサイズに依存した固定値を使うという案も挙げられています。乱数を使った場合、衝突も減らせて、パフォーマンスにも影響しなかったそうです。追記: @kmizuさんによると、Scala 2.10.M1では改善されているようです。レポジトリの最新のソースを見てみるとテーブルサイズから回転する量を決めているようです。

おわりに

さて、これまで1つのバグを通じてハッシュテーブルにおける衝突や並列コレクションの処理を見てきました。このバグを通じて次のようなことがわかりました。

きっと世の中のBTSにはこういうおもしろいバグが沢山あるはずです。みなさんも今まで見たおもしろいバグを紹介してみませんか。