tensorflow servingの話 その1

まずはじめに

はじめまして、yasu_umiです。弊社では、深層学習を用いたSaaSを提供しており、Deep Learningのモデルを安定的にAPI提供する必要があります。

Deep Learningを利用したサービスの乗ったサーバーは、運用がとても大変です。実サービスを運用するにあたっては学習によってモデルが更新されるということが問題となります。

  • モデルファイルを各サーバーに配布する必要がある
  • モデルが更新されるたびにモデルを読み込みなおす必要がある
  • 安全に運用するためのモデルのバージョニングの実装が大変

また、Deep Learningのモデルはファイルサイズが大きいです。一般にpythonのサーバーはgunicornなどで管理し定期的に再起動する必要があるため、アプリケーションとモデルが同じサーバーに乗っていると、サーバーの再起動に時間がかかり、メモリの負荷も大きくなってしまいます。

そこで今回は、この一番大変な部分をアプリケーションコードで頑張らずに、サーバーを停止することなくロールバックも可能な状態でモデルの更新ができる tensorflow serving (以下serving)について、実際の運用までを紹介していきます。

今回の記事中のコードは github に公開しているので合わせてご覧ください。 動作環境はpyton3.xで、今回必要なライブラリはtensorflow1.xとgrpc1.xです。

servingとは

一言で言うと、tensorflowのservableなオブジェクト読み込んで、gprcでリクエストを受け付けて、実行してくれるとても速いC++で書かれたサーバーです。 servableとは、例えばtensorflowにおけるgraphのような、何か計算をするオブジェクトです。読み込みはデフォルトではローカルのファイルパスが指定できる仕組みになっており、ユーザーは指定したディレクトリ以下にモデルファイルを配置するだけでservingに新しいモデルを読み込ませることができます。 詳しくは公式ドキュメントのarchitecture overview に紹介されています。

ここで大事なのは、servingの実行環境には学習済みモデル以外のコードが一切必要ないということです。詳しく解説すると、serving向けにエクスポートしたモデルにはGraphDef (ネットワークなどの定義)とSignatureDef(引数・その型)が書かれており、servingはモデルファイル内のそれらの値を使い、モデルをロード時にコンパイルするので、pythonで書いたgraphのネットワーク定義が不要になるのです。これにより、全く異なるアーキテクチャのネットワークを同じサーバーで同時に提供できます。

servingはgrpcしか受け付けないので、webアプリケーションなどから利用する場合手前に何かサーバーを立てる必要がありますが、servingがモデルのロードや実行を行ってくれるため、このサーバーは入出力の形式さえ知っていれば中のモデルがどういう形になっているか知らなくて済みます。

まずservingを動かす

前置きはこのくらいにして早速servingを起動しましょう。公式にはdockerhubにコンテナがないので手元でDockerfileを書いていきます。 tensorflow_model_serverはapt-getで入れることができるので、今回はそれを使います。

FROM ubuntu:16.04

RUN apt-get update && \
    apt-get upgrade -qy && \
    apt-get install -y --no-install-recommends \
      ca-certificates curl && \
    echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | tee /etc/apt/sources.list.d/tensorflow-serving.list && \
    curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | apt-key add - && \
    apt-get update && \
    apt-get install -y --no-install-recommends \
      tensorflow-model-server && \
    apt-get clean && \
    rm -rf /var/cache/apt/archives/* /var/lib/apt/lists/*

WORKDIR /root/serving-example
EXPOSE 8500

イメージをビルドし

docker build -t serving-example .

コンテナを起動します

docker run --name serving-example -v `pwd`:/root/serving-example -p 8500:8500 -it serving-example /bin/bash

そしてコンテナ内でservingを起動します

tensorflow_model_server --model_name='default' --model_base_path=/root/serving-example/tmp

/root/serving-exampleにはdocker run実行時のカレントディレクトリをmountしたので、ここにtmpディレクトリを作り、その中にモデルファイルを置くことでservingにモデルをロードすることができます。今はまだ何も置いていないのでNo versions of servable default found under base pathと怒られているはずです。

graphを作る

次に、graphを作りましょう。モデルはなんでもいいのですが、ここでは結果が簡単に検証できるものにします。引数xとyをとり、和を返すモデルをserveします。 コンテナ内のtensorflow_model_serverは起動したまま、ローカルにスクリプトを書いていきます。 完成品はgithubexample1.pyにあります。

まず必要な定数を置き

import tensorflow as tf

tf.app.flags.DEFINE_integer('version', 0, 'version')
tf.app.flags.DEFINE_integer('x', 0, 'x')
tf.app.flags.DEFINE_integer('y', 0, 'y')

MODEL_NAME = 'default'
VERSION = tf.app.flags.FLAGS.version
SERVING_HOST = 'localhost'
SERVING_PORT = 9000
X = tf.app.flags.FLAGS.x
Y = tf.app.flags.FLAGS.y
EXPORT_DIR = os.path.join(os.path.dirname(__file__), 'tmp', str(VERSION))

グラフを定義します

# define graph
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(dtype=tf.int64, shape=(), name='x')
    y = tf.placeholder(dtype=tf.int64, shape=(), name='y')
    x_add_y = tf.add(x=x, y=y)

    # test run graph
    with tf.Session() as sess:
        print('local x_add_y run result: {}'.format(sess.run(x_add_y, feed_dict={x: X, y: Y})))

x + yです。XとYに適当な値を入れて呼ぶと実行されます。

次にこれをserving用にexportします。

# save current graph for serving
builder = tf.saved_model.builder.SavedModelBuilder(EXPORT_DIR)
with tf.Session(graph=graph) as sess:
    builder.add_meta_graph_and_variables(
        sess, [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            'x_add_y': tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'x': tf.saved_model.utils.build_tensor_info(x), 'y': tf.saved_model.utils.build_tensor_info(y)},
                outputs={'x_add_y':  tf.saved_model.utils.build_tensor_info(x_add_y)},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
            ),
        },
    )
    builder.save()

SavedModelBuilderは変数とgraphをSavedModelのprotocol-buffer形式で保存してくれます。 このbuilderに対してメタ情報として保存したい変数とgraphのあるsessionやsignature_def_mapを渡し、saveするという流れになります。今回はただの足し算で、変数の値は必要なくgraphのみで良いので、sess.runしたのとは別のsessionで保存しています。

signature_def_mapは1つのgraphに対して、実行したい計算が通常複数あるので、それに対応するため名前とsignature_defのdictになっています。

signature_defはというと、実際に計算を行ない結果を取り出すために必要な情報が格納されています。inputsoutputsはそれぞれ、計算の入力と出力に必要な変数名とそれを入れるべきtensor_infoのdictです。ここでの名前はgraph内での名前(placeholderのname引数)と一致している必要はありません。graphの方で名前を変えてしまっても、ここを統一しておけば同じinterfaceで呼び出せるというわけです。

この辺はtf.saved_model.utils以下に便利な各種builderが用意されているのでこれらを使うのが良いでしょう。

EXPORT_DIRはコンテナをマウントしたディレクトリからtmp/<version:int>となるようなパスを指定してください。

さて、python example1.py --x=12 --y=34 --version=0 を実行してからmodel_serverの出力を見ると

Loading SavedModel from: /root/serving-example/tmp/0
Restoring SavedModel bundle.
The specified SavedModel has no variables; no checkpoints were restored.
Running LegacyInitOp on SavedModel bundle.
Loading SavedModel: success. Took 3604 microseconds.
Successfully loaded servable version {name: default version: 0}

というログが流れているのではないでしょうか。 ここから、versionを1にして再度実行すると、version 1のモデルが読み込まれ

Quiescing servable version {name: default version: 0}
Done quiescing servable version {name: default version: 0}
Unloading servable version {name: default version: 0}
Calling MallocExtension_ReleaseToSystem() after servable unload with 534
Done unloading servable version {name: default version: 0}

という風にバージョン1のロードに成功した後にバージョン0をアンロードしてくれます。

さらに、servingはディレクトリの削除を検知して、ディレクトリの中のモデルが現在メモリに載っている場合自動で古いモデルを読み込みます。 更新もロールバックもファイルの移動だけで済んでしまうわけです。

servingにリクエストを投げてみる

gprcはpip install grpcioで入るのですが、servingに投げるリクエストを作るためのtensorflow-serving-apiパッケージはpython3向けのものがありません。が、python2向けのものがそのままptyhon3でも動くので、 pip からzipを落としてきてローカルに置きましょう。(issueはあるのですが、ここはあまり対応する気がなさそうなので気長に待ちましょう)

まずクライアントのstubを作成します

from grpc.beta import implementations
from tensorflow_serving.apis import prediction_service_pb2

# create grpc stub
channel = implementations.insecure_channel(SERVING_HOST, SERVING_PORT)
stub = prediction_service_pb2.beta_create_PredictionService_stub(channel)

次にリクエストを作成します

from tensorflow.core.framework import types_pb2
from google.protobuf import wrappers_pb2
from tensorflow_serving.apis import predict_pb2

# create predict request
request = predict_pb2.PredictRequest()
request.model_spec.name = MODEL_NAME
# version = wrappers_pb2.Int64Value()
# version.value = VERSION
# request.model_spec.version.CopyFrom(version)
request.model_spec.signature_name = 'x_add_y'

request.inputs['x'].dtype = types_pb2.DT_INT64
request.inputs['x'].int64_val.append(X)
request.inputs['y'].dtype = types_pb2.DT_INT64
request.inputs['y'].int64_val.append(Y)

リクエストにはいくつか種類がありますが、今回はpredictを使います。servingに投げるリクエストは、model_spec下にあるリクエストを投げる対象のモデルとinputs下にある入力内容の2つの要素からなります。

model_specで指定できる内容はモデルのname、そしてversionsignature_nameです。バージョンの指定は任意であり、指定しない場合は自動で最新のバージョンが呼ばれます。呼び出し側は今どのバージョンがロードされているか意識する必要がないわけです。

inputssignature_def_mapで指定した変数名とその中身というdictになります。placeholderで指定しておいたdtypeと同じ型を明示し、それに対応したプロパティに値をappendすることで入力値をセットできます。

これで準備はOKです。リクエストを投げましょう。

result_future = stub.Predict.future(request, 1)
result = result_future.result()
print('serving x_add_y run result: {}'.format(result.outputs['x_add_y'].int64_val[0]))

上手く値が出力されたでしょうか。local x_add_y run resultserving x_add_y run resultが同じ値になっていれば成功です。 outputsはinputsと同じく変数名をkeyとするdictになっており、graph内でのdtypeと同じプロパティにアクセスすることで中身が取得できます。こちらもinputsと同じく配列になっていますが、今回は出力は1つだけです。

まとめ

生tensorflowでローカルにserving用のモデルを出力し、実際にリクエストを投げるところまでできました。

これで自分で定義したグラフをservingでバージョニングすることができるようになりました!サーバーを再起動しなくてもモデルの更新ができるようになった、のですが、これだと各servingサーバーのローカルにモデルファイルを配布する必要があり、pythonで全てを書いていた時と比較して楽になっていません。さらにtensorflowでplaceholderなどを直接定義していることもコードの可読性を下げています。

次回以降はローカルのファイルシステムではなくS3でモデルファイルを管理する方法や、生のtensorflowで書くことをやめてkerasやestimatorを使う方法、実際のAWS上でのインフラ構成等を紹介する予定です。