メインコンテンツへスキップ
Keras コールバックを使用すると、実験を管理し、モデル チェックポイントをログし、モデルの予測を可視化できます。Keras コールバックは、Pyhon SDK バージョン 0.13.4 以降で wandb.integration.keras モジュールから利用できます。 W&B Keras インテグレーションでは、次のコールバックを提供しています。
  • WandbMetricsLogger : 実験管理 にはこのコールバックを使用します。トレーニングおよび検証のメトリクスに加えて、システム メトリクスも W&B にログします。
  • WandbModelCheckpoint : モデル チェックポイントを W&B Artifacts にログするには、このコールバックを使用します。
  • WandbEvalCallback: このベース コールバックは、モデルの予測を W&B Tables にログし、インタラクティブに可視化できるようにします.

Keras インテグレーションのインストールとインポート

W&B の最新バージョンをインストールします。
pip install -U wandb
Kerasインテグレーションを使用するには、wandb.integration.keras から必須クラスをインポートします:
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbEvalCallback
以下のセクションでは、各コールバックについてコード例とともに詳しく説明します。

WandbMetricsLogger で実験を管理する

wandb.integration.keras.WandbMetricsLogger() は、on_epoch_endon_batch_end などのコールバックが引数として受け取る Keras の logs 辞書を自動的にログします。 以下の抜粋例では、Keras の workflow で WandbMetricsLogger() を使用する方法を示します。まず、使用したい optimizer、損失関数、メトリクスを指定してモデルをコンパイルします。次に、wandb.init() を使用して W&B run を初期化します。最後に、WandbMetricsLogger() コールバックを model.fit() に渡します。
import wandb
from wandb.integration.keras import WandbMetricsLogger
import tensorflow as tf

model.compile(
    optimizer = "adam",
    loss = "categorical_crossentropy",
    metrics = ["accuracy", tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top@5_accuracy')]
)

# 新しい W&B Run を初期化する
with wandb.init(config={"batch_size": 64}) as run:

    # WandbMetricsLogger を model.fit に渡す
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbMetricsLogger()]
    )
前の例では、各エポックの終了時に、lossaccuracytop@5_accuracy などのトレーニングおよび検証メトリクスを W&B にログします。さらに、以下の内容もログします。

WandbMetricsLogger リファレンス

パラメーター説明
log_freq(epochbatch、または int): epoch の場合は各エポックの終了時にメトリクスをログします。batch の場合は各バッチの終了時にメトリクスをログします。int の場合は指定したバッチ数ごとにメトリクスをログします。デフォルトは epoch です。
initial_global_step(int): initial_epoch からトレーニングを再開し、学習率スケジューラを使用している場合に、学習率を正しくログするためにこの引数を使用します。これは step_size * initial_step として計算できます。デフォルトは 0 です。

WandbModelCheckpoint を使用してモデルをチェックポイントする

WandbModelCheckpoint コールバックを使用すると、Keras モデル (SavedModel 形式) またはモデルの重みを定期的に保存し、モデルのバージョン管理のために wandb.Artifact として W&B にアップロードできます。 このコールバックは tf.keras.callbacks.ModelCheckpoint() のサブクラスであるため、チェックポイントのロジックは親コールバックによって処理されます。 このコールバックでは、次を保存できます。
  • monitor に基づいて最高のパフォーマンスを達成したモデル
  • パフォーマンスに関係なく、各エポックの終了時のモデル
  • エポックの終了時、または一定数のトレーニングバッチごとのモデル
  • モデルの重みのみ、またはモデル全体
  • SavedModel 形式または .h5 形式のモデル
このコールバックは WandbMetricsLogger() と併用してください。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint

# 新しい W&B run を初期化する
with wandb.init(config={"bs": 12}) as run:

    # WandbModelCheckpoint を model.fit に渡す
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbModelCheckpoint("models"),
        ],
    )

WandbModelCheckpoint リファレンス

パラメーター説明
filepath(str): モデルファイルを保存するパス。
monitor(str): 監視するメトリクスの名前。
verbose(int): 詳細表示モード。0 または 1。モード 0 ではメッセージを表示せず、モード 1 ではコールバックがアクションを実行した際にメッセージを表示します。
save_best_only(Boolean): save_best_only=True の場合、monitor 属性と mode 属性で定義された内容に基づき、最新のモデル、または最良と判断されたモデルのみを保存します。
save_weights_only(Boolean): True の場合、モデルの重みのみを保存します。
mode(auto, min, or max): val_acc の場合は maxval_loss の場合は min に設定します。
save_freq(“epoch” or int): ‘epoch’ を使用すると、コールバックは各エポックの終了後にモデルを保存します。整数を使用すると、その数のバッチの終了時にコールバックがモデルを保存します。val_accval_loss などの検証メトリクスを監視する場合、これらのメトリクスはエポックの終了時にのみ利用できるため、save_freq は “epoch” に設定する必要があります。
options(str): save_weights_only が true の場合は省略可能な tf.train.CheckpointOptions オブジェクト、save_weights_only が false の場合は省略可能な tf.saved_model.SaveOptions オブジェクト。
initial_value_threshold(float): 監視するメトリクスの初期「最良」値となる浮動小数点数。

N エポックごとにチェックポイントをログする

デフォルトでは (save_freq="epoch") 、コールバックは各エポックの後にチェックポイントを作成し、それを artifact としてアップロードします。特定のバッチ数ごとにチェックポイントを作成するには、save_freq を整数に設定します。N エポックごとにチェックポイントを作成するには、train データローダーの要素数を計算して、それを save_freq に渡します。
WandbModelCheckpoint(
    filepath="models/",
    save_freq=int((trainloader.cardinality()*N).numpy())
)

TPU アーキテクチャでチェックポイントを効率的にログする

TPU でチェックポイントを作成する際に、UnimplementedError: File system scheme '[local]' not implemented というエラーメッセージが表示されることがあります。これは、モデルディレクトリ (filepath) にはクラウドストレージのバケットパス (gs://bucket-name/...) を使用する必要があり、さらにそのバケットに TPU サーバーからアクセスできなければならないためです。一方、W&B ではチェックポイント作成にローカルパスを使用し、その後 artifact としてアップロードします。
checkpoint_options = tf.saved_model.SaveOptions(experimental_io_device="/job:localhost")

WandbModelCheckpoint(
    filepath="models/,
    options=checkpoint_options,
)

WandbEvalCallback を使用してモデルの予測を可視化する

WandbEvalCallback() は、主にモデル予測、次いでデータセットの可視化を目的とした Keras コールバックを構築するための抽象基底クラスです。 この抽象コールバックは、データセットやタスクに依存しません。これを使用するには、基底コールバッククラス WandbEvalCallback() を継承し、add_ground_truth メソッドと add_model_prediction メソッドを実装します。 WandbEvalCallback() は、次の機能を提供するユーティリティクラスです。
  • データと予測用の wandb.Table() インスタンスを作成する。
  • データと予測のテーブルを wandb.Artifact() としてログする。
  • on_train_begin でデータテーブルをログする。
  • on_epoch_end で予測テーブルをログする。
次の例では、画像分類タスクで WandbClfEvalCallback を使用します。このコールバックは、検証データ (data_table) を W&B にログし、推論を実行して、各エポックの終了時に予測結果 (pred_table) を W&B にログします。
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbEvalCallback


# モデル予測の可視化コールバックを実装する
class WandbClfEvalCallback(WandbEvalCallback):
    def __init__(
        self, validation_data, data_table_columns, pred_table_columns, num_samples=100
    ):
        super().__init__(data_table_columns, pred_table_columns)

        self.x = validation_data[0]
        self.y = validation_data[1]

    def add_ground_truth(self, logs=None):
        for idx, (image, label) in enumerate(zip(self.x, self.y)):
            self.data_table.add_data(idx, wandb.Image(image), label)

    def add_model_predictions(self, epoch, logs=None):
        preds = self.model.predict(self.x, verbose=0)
        preds = tf.argmax(preds, axis=-1)

        table_idxs = self.data_table_ref.get_index()

        for idx in table_idxs:
            pred = preds[idx]
            self.pred_table.add_data(
                epoch,
                self.data_table_ref.data[idx][0],
                self.data_table_ref.data[idx][1],
                self.data_table_ref.data[idx][2],
                pred,
            )


# ...

# 新しい W&B Run を初期化する
with wandb.init(config={"hyper": "parameter"}) as run:

    # Model.fit にコールバックを追加する
    model.fit(
        X_train,
        y_train,
        validation_data=(X_test, y_test),
        callbacks=[
            WandbMetricsLogger(),
            WandbClfEvalCallback(
                validation_data=(X_test, y_test),
                data_table_columns=["idx", "image", "label"],
                pred_table_columns=["epoch", "idx", "image", "label", "pred"],
            ),
        ],
    )

WandbEvalCallback リファレンス

パラメーター説明
data_table_columns(list) data_table の列名の一覧
pred_table_columns(list) pred_table の列名の一覧

メモリ使用量の詳細

on_train_begin method が呼び出されると、data_table を W&B にログします。これが W&B Artifact としてアップロードされると、この表への参照を取得でき、data_table_ref クラス変数を使ってアクセスできます。data_table_ref は 2 次元リストで、self.data_table_ref[idx][n] のようにインデックス指定できます。ここで、idx は行番号、n は列番号です。以下の例で使い方を見てみましょう。

コールバックをカスタマイズする

より細かく制御するには、on_train_begin または on_epoch_end の method をオーバーライドします。N バッチごとにサンプルをログしたい場合は、on_train_batch_end method を実装できます。
WandbEvalCallback を継承してモデルの予測可視化用コールバックを実装していて、不明な点や修正が必要な点があれば、issue を作成してお知らせください。

WandbCallback [レガシー]

W&B ライブラリの WandbCallback() クラスを使用すると、model.fit() でトラッキングされるすべてのメトリクスと損失値を自動的に保存できます。
import wandb
from wandb.integration.keras import WandbCallback

with wandb.init(config={"hyper": "parameter"}) as run:

    # Kerasでモデルを設定するコード

    # コールバックをmodel.fitに渡す
    model.fit(
        X_train, y_train, validation_data=(X_test, y_test), callbacks=[WandbCallback()]
    )
短い動画 1分以内で始める Keras と W&B を視聴できます。 より詳しい動画は、Keras と W&B を統合するをご覧ください。Colab Jupyter Notebookも確認できます。
スクリプトについては、example repoを参照してください。これには、Fashion MNIST の例と、それによって生成される W&B ダッシュボード が含まれます。
WandbCallback クラスは、さまざまなログ設定オプションをサポートします。たとえば、監視するメトリクスの指定、重みと勾配のトラッキング、training_data と validation_data に対する予測のログなどです。 詳しくは、keras.WandbCallback のリファレンスドキュメントを参照してください。 WandbCallback は次のことを行います
  • Keras が収集したすべてのメトリクスの履歴データを自動的にログします。これには、損失や keras_model.compile() に渡されたすべての項目が含まれます。
  • monitor 属性と mode 属性で定義される「最良」のトレーニング step に関連付けられた run の summary メトリクスを設定します。デフォルトでは、これは val_loss が最小のエポックです。WandbCallback はデフォルトで、最良の epoch に対応するモデルを保存します。
  • 必要に応じて、勾配とパラメーターのヒストグラムをログします。
  • 必要に応じて、wandb が可視化できるようにトレーニングデータと検証データを保存します。

WandbCallback リファレンス

引数
monitor(str) 監視するメトリクスの名前。デフォルトは val_loss です。
mode(str) {auto, min, max} のいずれかです。min - monitor が最小になるときにモデルを保存します max - monitor が最大になるときにモデルを保存します auto - モデルを保存するタイミングを自動的に推定します (デフォルト) 。
save_modelTrue - monitor がそれまでのすべてのエポックを上回った場合にモデルを保存します False - モデルを保存しません
save_graph(boolean) True の場合、モデルのグラフを wandb に保存します (デフォルトは True) 。
save_weights_only(boolean) True の場合は、モデルの重みのみを保存します (model.save_weights(filepath)) 。それ以外の場合は、モデル全体を保存します) 。
log_weights(boolean) True の場合、モデルの各レイヤーの重みのヒストグラムを保存します。
log_gradients(boolean) True の場合、トレーニング中の勾配のヒストグラムをログします
training_data(tuple) model.fit に渡す (X,y) と同じ形式です。勾配の計算に必要で、log_gradientsTrue の場合は必須です。
validation_data(tuple) model.fit に渡した (X,y) と同じ形式。wandb が可視化するためのデータです。このフィールドを設定すると、各エポックで wandb は少数の予測を実行し、後で可視化できるようにその結果を保存します。
generator(generator) wandb が可視化するための検証データを返すジェネレーターです。このジェネレーターは (X,y) のタプルを返す必要があります。wandb で特定のデータ例を可視化するには、validate_data または generator のいずれかを設定する必要があります。
validation_steps(int) validation_data がジェネレーターである場合、検証セット全体に対してジェネレーターを何 step 実行するか。
labels(list) wandb でデータを可視化する場合、このラベルのリストを指定すると、複数クラスの分類器を構築しているときに、数値出力が分かりやすい文字列に変換されます。二値分類器の場合は、2 つのラベル [label for false, label for true] からなるリストを渡せます。validate_datagenerator が両方とも false の場合、これは何もしません。
predictions(int) 各エポックで可視化用に行う予測数です。最大値は 100 です。
input_type(string) 可視化しやすくするためのモデル入力のタイプ。次のいずれかを指定できます: (image, images, segmentation_mask)。
output_type(string) 視覚化に役立つモデル出力のタイプ。次のいずれかを指定できます: (image, images, segmentation_mask)。
log_evaluation(boolean) True の場合、各エポックで、検証データとモデルの予測を含む表を保存します。詳しくは validation_indexesvalidation_row_processoroutput_row_processor を参照してください。
class_colors([float, float, float]) 入力または出力がセグメンテーションマスクの場合、各クラスに対応する RGB タプル (範囲 0~1) を含む配列。
log_batch_frequency(integer) None の場合、callback は各エポックでログします。整数を設定すると、callback は log_batch_frequency バッチごとにトレーニングメトリクスをログします。
log_best_prefix(string) None の場合、追加の summary メトリクスは保存されません。文字列を設定すると、監視対象のメトリクスとエポックの先頭にそのプレフィックスを付け、結果を summary メトリクスとして保存します。
validation_indexes([wandb.data_types._TableLinkMixin]) 各検証例に関連付けるインデックスキーの順序付きリスト。log_evaluation が True で validation_indexes を指定すると、検証データのTableは作成されません。代わりに、各予測が TableLinkMixin で表される行に関連付けられます。行キーのリストを取得するには、Table.get_index() を使用します。
validation_row_processor(Callable) 検証データに適用する関数で、通常はデータの可視化に使用します。この関数は ndx (int) と row (dict) を受け取ります。モデルの入力が 1 つだけの場合、row["input"] にはその行の入力データが含まれます。それ以外の場合は、入力スロットの名が含まれます。fit 関数が単一のターゲットを受け取る場合、row["target"] にはその行のターゲットデータが含まれます。それ以外の場合は、出力スロットの名が含まれます。たとえば、入力データが単一の配列で、データを Image として可視化する場合は、プロセッサとして lambda ndx, row: {"img": wandb.Image(row["input"])} を指定します。log_evaluation が False の場合、または validation_indexes が指定されている場合は無視されます。
output_row_processor(Callable) validation_row_processor と同様ですが、モデルの出力に対して適用されます。row["output"] にはモデルの出力結果が含まれます。
infer_missing_processors(Boolean) validation_row_processoroutput_row_processor が存在しない場合に、それらを推論するかどうかを指定します。デフォルトは True です。labels を指定すると、W&B は必要に応じて分類用のプロセッサを推論します。
log_evaluation_frequency(int) 評価結果をどの頻度でログするかを指定します。デフォルトは 0 で、この場合はトレーニング終了時にのみログします。1 に設定すると毎エポック、2 に設定すると 1 エポックおき、というようにログします。log_evaluation が False の場合は効果はありません。

よくある質問

wandbKeras のマルチプロセシングを使用するにはどうすればよいですか?

use_multiprocessing=True を設定すると、次のエラーが発生することがあります。
Error("You must call wandb.init() before wandb.config.batch_size")
これを回避するには、次のようにします。
  1. Sequence クラスの構築時に、wandb.init(group='...') を追加します。
  2. main では、if __name__ == "__main__": を使用していることを確認し、スクリプトの残りのロジックはその中に記述します。