2010年9月6日月曜日

logsumexpとスケーリング法

少し前にtwitter上でCRFSuiteはスケーリング法を使っているから速い,的なことを書いたのでその解説です.

linear-chain CRFのパラメタ推定に必要なのは対数尤度関数の微分です.これの計算に必要なのが,前向き・後ろ向きのスコアαとβです.時刻t(系列上での位置)とラベルiに対する前向きスコアαは,以下の式で計算されます.fは特徴ベクトル,wは重みベクトルです.



ところがこのままだと問題が起こります.αの値はexp個の足し算で構成されるため,最終的にかなり大きくて,簡単に倍精度の限界を超えてしまうのです.困った.そこで,logの世界に落とします.αの代わりにlog(α)を計算します.すると,expの世界の掛け算はlogの世界の足し算になります.問題は,足し算です.expの世界の足し算を,logの世界で行う2項関数がlogsumexpです.



で定義されます.expをかけてるから,結局オーバーフローしてしまうように見えますが,以下のように式変形することで,見事にオーバーフローしなくなりました.ただし,y < x とします.



これを使うとlogαの計算は以下のようになります.



ここまでがlogsumexpの話.


logsumexpを使い出すと,今までただの足し算でよかったのにlogやらexpやら,いかにも重そうな関数を使わなくてはならなくて激しく遅くなります.できる事ならexpの世界のまま計算したい.そこで使われるのがスケーリング法です.ここで大事な点を思い出します.最後に計算したかったのは周辺確率です.つまり,0から1の間に収まるような値が出力です.最後にZで割れば,彼らは器用にoverflowもunderflowもおこしません.ということは,少しずつ割っていって,最終的にZで割った値が出るように調整してあげれば,オーバーフローせずに計算が終わるはずです.

時刻 t ごとに和が 1 になるように正規化したα'を以下で定義します.



これを,普通のα同様にα'の再帰計算で計算できればlogの世界で計算する必要がなくなるという寸法です.結局α'は,毎時刻ごとに1になるように正規化しているだけなので,正規化するときに割った値を覚えておけばよいのでした.ちなみにこのテクは,PRMLのHMM(だったかな?)の章に書いてあります.って教わりました.

ところで,これで本当に速くなるんでしょうか.改めてαの式とlogαの式を見返して見ます.確かにαにはlogsumexpがありません.しかし,logαでは実は重みベクトルと特徴ベクトルの内積に対するexpが消えました.結局,logsumexpの代わりにexpになっただけ,このままではexpの実行回数で両者に差がありません.ここにもう一つトリックが存在します.一般にlinear-chain CRFではラベル遷移の特徴(遷移素性)と各時刻ごとの特徴(状態素性)は別々にあらわされます.つまり,

という2種類の特徴量に分けられています.ラベル数L,文長Tとします.これらのexpはあらかじめ計算しておくと,必要なexpの計算量はO(TL + L2),一方でlogαではO(TL2).結果的にαを直接計算した方が,つまりスケーリング法の方が速くなるのでした.


ところで,そんなわけで私も以前スケーリング法を実装したのですが,なんだか思わぬタイミングでoverflowやunderflowを起こしてしまいます.なので,私も完璧に理解したわけではありません.まだ実装上のテクニックがあるのかもしれません.logsumexpもスケーリング法も,たまたま何かで見たり,たまたま教えてもらったり,罠がいっぱい.

0 件のコメント:

コメントを投稿