応答生成モデルの学習(Text-to-Text Transfer Transformer)

ユーザの発話テキストに対して、応答テキストを生成させるためのモデルを作ります。

T5 (Text-to-Text Transfer Transformer)

チャットボットを作るときに必要となるモデルは、Text to Text (Seq to Seq)モデルとなります。今回はT5(Text-to-Text Transfer Transformer)を選択しました。

T5を含めたTransformer系のモデルの特徴は、「事前学習」によって言語(この場合は日本語)そのものを学習しておき、ファインチューニング(転移学習)によって、個々のタスクに応じた内容を学習させます。

以前は、OpenNMTモデルのように、一から学習させなければなりませんでした。この方法だと、学習するデータを作成するのにも人手が必要でしたし、計算量も大規模なためコンピュータコストも負担になりました。実際、Twitterのデータ90MBを収集して試してみましたが、この程度のデータ量では、日本語として意味をなさない応答が返ってきました。

OpenNMTの出力例

しかし、Transformer系のモデルの場合は、事前学習済みのモデルが公開されており、必要なタスクに応じたデータのみ作成・学習すれば良いので、個人レベルでも試してみることができます。

例えていうと、OpenNMTは、何も知らない赤ちゃんの状態から学習させるのに対して、T5は、日本語が話せる子供に、「挨拶を教える」「特定の分野に対して尋ねられたら答える」というような学習をさせますので、教えることは格段に少なくて済みます。

今回は、@sonoisa(日鉄ソリューションズ株式会社)さんが公開をしてくださっている日本語モデルを、使用させていただきました。関連する記事は、こちら。使用される場合には、モデルについての免責事項等もご確認ください。

ファインチューニング (転移学習)

さて、具体的な転移学習です。まずは、先人のやり方をトレースしてみます。参考にした記事は、こちらの記事です。

この中で、転移学習から推論までの一連の流れをGoogle Colabratoryのノートブックとして公開してくださっていますので、一旦この通りに実行していきます。

出来上がったモデルは、MODEL_DIRに定義したフォルダに格納されます。

先ほどのGoogleColabのノートブックの最後の「任意の文章に対する応答生成」で、作成したモデルによる推論も行っています。ここでの例では、bodyに任意の文章を定義してコードを実行すると、10個の応答候補が返ってくるようになっています。

再学習

上記の例では、Twitterのデータを使って学習しています。多少、乱暴な応答ですが、確かに日本語の応答としては成立しているように見えます。

この後、自分が思うような言葉を返してもらうためには、新たにデータを作って学習させていく必要があります。実際には、400件ほどのデータを手作業で作成し、再学習をさせてみました。

データ量も少ないので、学習時間は10分程度で終了します。

この作業を繰り返していくことによって、自分なりの応答モデルを作ることができそうです。