2014年5月29日木曜日

[C++] next_combination() 書いてみた

C++ 標準ライブラリの <algorithm> には next_permutation() という関数があります。辞書順で順列を列挙してくれる大変便利な関数で競プロ関連でお世話になったことのある人も結構いるのではないかと思います。順列があるなら組み合わせはないの?ということで next_combination() で検索をかけると nyama++ さんの 全ての組み合わせを作る【C++】 - Programming Magic という記事が引っかかります(他の検索結果からも結構参照されているようです)。この記事のリンクで指し先が消失しているものがありますが、内容自体は ISO C++ 標準化委員会(JTC1/SC22/WG21) に提案されているペーパー n2639 Algorithms for permutations and combinations, with and without repetitions です(実装は Reference implementation のところ)。next_pertial_permutation() (next_permutation() が n! の列挙なら、これは nPr を列挙する) の実装が

template <class BidirectionalIterator>
bool next_partial_permutation(BidirectionalIterator first, BidirectionalIterator middle, BidirectionalIterator last)
{
  reverse(middle, last);
  return next_permutation(first , last);
}

というのも私的に感動ポイントですが、next_combination() は確かに訳が分かりません。ってことで一度自分で書いてみたところ最終的にほぼ同様のアルゴリズムとなり理解が進みましたので解説記事っぽいものを書いてみようと思います。なお図としては全ての要素が1ずつ異なるようなイメージで書いていますが、値が跳んでいてもいいですし重複があっても問題ないはずです。

まず最初に意味不明と書かれていた prev_combination() について。処理上何かと便利なので [first, middle), [middle, last) がソートされている、という前提を置いておきます。このとき、[first, middle) に対して next_combination() した場合の例は以下のようになります。

見ての通り、[middle, last) に対しては prev_combination() になっています(両方の領域がソート済み、かつ、辞書式順序で列挙なのでこうなる)。ということで先の前提の下で next_combination() が正しく実装できれば他方の領域に対して next_combination() をかけてやれば元々の領域に対する prev_combination() になります。

さて、では本論のコードについて。さすがに真っ白の状態だと書けないので n2639 で参考文献として挙げられている Knuth 先生の The Art of Computer Programming. Volume 4, Fascicle 3 GeneratingAll Combinations and Partitions を見ると(※立ち読み。分冊が一冊になったら買います、多分)確かに列挙アルゴリズムが記述されてはいますが添字の列挙としてのアルゴリズムであって実際に要素の並び替えを行う next_permutation() 風の next_combination() の場合だとそのままでは使えなさそうです。そうすると指針としては、汎用的な方法として挙げられている「最も右端の増加可能な要素を見つけて増加、以降の要素は設定可能な最小値を設定していく」になります。これを実装してみましょう。

まずは「最も右端の増加可能な要素」を見つける必要があります。右端(middle - 1)が最大値でない場合はそこが増加可能です。では右端ではない部分が「最も右端の増加可能な要素」となる場合を考えてみましょう。この時その要素(*targetとしておきます)の右側には、右端(*(middle-1))が最大値となる状態で連続して昇順に要素が並んでいることになります。さもなくば、より右側の要素が増加可能になるからです。[middle, last) もソート済であることと合わせると、*target < *(last-1) となる最も右端の target を見つければいいことになります。これが存在しない場合には列挙終了です。

次に、*target をどの値に変更すれば良いかを考えましょう。target の右側が最大値から連続で詰まっているので、[target, middle) の間の値は考慮する必要がありません(target から middle まで埋めるには要素が足りない)。ということで、[middle, last) のうちで *target より大きくなる最小値に設定すれば良いことになります(next とします)。*target < *(last-1) なので [middle, last) で必ず見つかります。

では、[target+1, middle) (と [next, last) をどのように埋めれば良いでしょうか? これまでの条件から [target, middle) [next, last) 近傍の順序は下図上段のようになっています。[target, middle) はできるだけ最小になるように設定すること、[middle, last) もソートされている必要があることを考えると最終結果は下図下段のようになっている必要があります。これは target と next を iter_swap した後、[target+1, middle), [next+1, last) が一つの領域だとみなして、next+1 が先頭となるように rotate() することに他なりません。

というのをコードにまとめると以下のようになります。C++er な人には言うまでもないとは思いますが、わざわざ !(a < b) みたいな書き方になってるのは operator< のみ使用にしたかったからです。なお、rotate() は en.cppreference.com の参考実装(Forward Iterator 向けですが)に分割2領域用にちょっとだけ修正したものです。

// possible implementation introduced at http://en.cppreference.com/w/cpp/algorithm/rotate with slight modification to handle parted ranges
template<typename FI>
void parted_rotate(FI first1, FI last1, FI first2, FI last2)
{
 if(first1 == last1 || first2 == last2) return;
 FI next = first2;
 while (first1 != next) {
  std::iter_swap(first1++, next++);
  if(first1 == last1) first1 = first2;
  if (next == last2) {
   next = first2;
  } else if (first1 == first2) {
   first2 = next;
  }
 }
}

template<typename BI>
bool next_combination_imp(BI first1, BI last1, BI first2, BI last2)
{
 if(first1 == last1 || first2 == last2) return false;
 auto target = last1; --target;
 auto last_elem = last2; --last_elem;
 // find right-most incrementable element: target
 while(target != first1 && !(*target < *last_elem)) --target;
 if(target == first1 && !(*target < *last_elem)) {
  parted_rotate(first1, last1, first2, last2);
  return false;
 }
 // find the next value to be incremented: *next
 auto next = first2;
 while(!(*target < *next)) ++next;
 std::iter_swap(target++, next++);
 parted_rotate(target, last1, next, last2);
 return true;
}

// INVARIANT: is_sorted(first, mid) && is_sorted(mid, last)
template<typename BI>
inline bool next_combination(BI first, BI mid, BI last)
{
 return next_combination_imp(first, mid, mid, last);
}

// INVARIANT: is_sorted(first, mid) && is_sorted(mid, last)
template<typename BI>
inline bool prev_combination(BI first, BI mid, BI last)
{
 return next_combination_imp(mid, last, first, mid);
}

基本的に n2639 とほぼ同じことをするコードです。nyama++ さんの記事で②③となっているところが実は合わせて rotate() だったわけです。

reverse(first, mid);
reverse(mid, last);
reverse(first, last);

という reverse() 3連打の rotate() 実装を 2 領域対応にした形が②③になります。後の違いは、false を返す時に先に return する(rotate() が 2 ヶ所になる)か、1ヶ所で rotate() するかくらいですね。

ついでなので、計算量(iter_swap の呼び出し回数)について実際に計数させてみたところ、2nCn 列挙時での 1 next_combination() 辺りの平均 iter_swap() 回数は 1.7 くらいに収束しそうな感じです。 もちろん組み合わせの数自体が指数関数的に増えるので全体的にはすぐに苦しくなるのですが、かなり悪くない数字に感じます。

#n,r,the number of iter_swap,the number of combinations,average iter_swap per one next_combination() call
2,1,2,2,1
4,2,8,6,1.33333
6,3,30,20,1.5
8,4,112,70,1.6
10,5,414,252,1.64286
12,6,1540,924,1.66667
14,7,5754,3432,1.67657
16,8,21656,12870,1.68267
18,9,81994,48620,1.68643
20,10,312068,184756,1.68908
22,11,1192954,705432,1.6911
24,12,4577356,2704156,1.69271
26,13,17619034,10400600,1.69404
28,14,68003992,40116600,1.69516
30,15,263097002,155117520,1.69611
32,16,1019997844,601080390,1.69694
34,17,3961678122,2333606220,1.69766
36,18,15412309480,9075135300,1.6983
38,19,60046904394,35345263800,1.69887
40,20,234252753696,137846528820,1.69937