クロージャを渡して順列生成、をまねしてみる

ある集合 S の部分集合 Q の中で、要素数が一定以下のもの (|Q| < m < |S|) を生成して、それぞれの Q についてごにょごにょしたい。自分で全部考えるのもまあいい練習問題で、効率を度外視した解は簡単に書けると思うが、世の中にはきっと Python らしい書き方があるに違いない。まずは順列生成ジェネレータとかあるだろうな、と検索したら、次の記事を見つけた。

いやー。すごい勉強になった。こういう風に関数を引数に渡す方法があるんだなあ。部分集合を生成してごにょごにょする場合も、この「ごにょごにょ」を渡しちゃうって発想が使えるわけだ。

で、本質的でない部分でもうちょっと高速化できそうだったのでやってみた。

  • シーケンスの長さは最初に計算しておく
  • i 番目の要素を抜いたシーケンスをスライスの結合で得る

下のプログラムのうち、私が手を加えたのは kperm2, perm2 だけ。

# permutation.py                                                                
def default_term(x):
    return [x]
def default_concat(x, y):
    return [x] + y

def kperm(seq, k,
          terminal_procedure=default_term, concat_procedure=default_concat):
    if k==1:
        for item in seq:
            yield terminal_procedure(item)
    else:
        for i in xrange(len(seq)):
            for p in kperm([seq[idx] for idx in xrange(len(seq)) if i!=idx],
                           k-1,
                           terminal_procedure,
                           concat_procedure):
                yield concat_procedure(seq[i], p)

def perm(seq, tp=default_term, cp=default_concat):
    return kperm(seq, len(seq), tp, cp)


def kperm2(seq, length, k,
          terminal_procedure=default_term, concat_procedure=default_concat):
    if k == 1:
        for item in seq:
            yield terminal_procedure(item)
    else:
        for i in xrange(length):
            for p in kperm2(seq[:i] + seq[i+1:],
                            length - 1,
                            k-1,
                            terminal_procedure,
                            concat_procedure):
                yield concat_procedure(seq[i], p)


def perm2(seq, tp=default_term, cp=default_concat):
    return kperm2(seq, len(seq), len(seq), tp, cp)

次の補助用の関数も用意して計測してみる。

# temp.py                                                                       
import time
def measure(proc):
    s = time.time()
    proc()
    e = time.time()
    return e-s

def honer(p):
    return reduce(lambda x,y: x*10+y, p)
>>> import permutation
>>> import temp
>>> temp.measure(lambda: [temp.honer(p) for p in permutation.perm(range(1,10))])
4.7288901805877686
>>> temp.measure(lambda: [n for n in permutation.perm(range(1,10), lambda x:x, lambda x,y:y*10+x)])
2.5051350593566895
>>> temp.measure(lambda: [n for n in permutation.perm2(range(1,10), lambda x:x, lambda x,y:y*10+x)])
2.2291669845581055

お、速度改善された。

追記

2.6 だと itertools に combinations とか permutations とかあるぞ、と。