Wizard Notes

音楽信号解析の技術録、作曲活動に関する雑記

Kerasでカスタムレイヤーを含むモデルをsave/load_modelする

f:id:Kurene:20210110000321p:plain

Kerasでカスタムレイヤーを含むモデルを保存/読み込むヒントとなる情報があったので試してみました。

まず、以下の日本語のドキュメントの方法のコードでは、Custom layerを含むモデルのsave/loadはできませんでした。

オリジナルのKerasレイヤーを作成する - Keras Documentation

このコードに対して、以下の2点を追加したところsave/load_modelを実行できました。

  • get_config()の追加
  • load_model()にキーワード引数custom_objectsを追加

各項目の詳細については、コードや参考Webページをご覧ください。

特にハマったところは、get_config()の書き方でした。

いろいろ検索したところ、以下のWebページに有力そうな情報が書かれていました。

load_model() with custom layers, and custom layers in general · Issue #4871 · keras-team/keras · GitHub

つまり、

カスタムレイヤー+Layerクラスのコンストラクタ(__init__)のキーワード引数名と値を辞書として返せばよい

とのこと。

以下のソースコードでは、親クラスであるLayerクラスのコンストラクタの引数名・値はbase_configで、カスタムレイヤーでのコンストラクタの引数名・値はconfigに入っており、最終的にそれを結合して返しています。

検証用のソースコード

実行結果

model.save(),model=model_load() の前後でmodel.predict(x) を実施し出力が一致するか検証しました。

なお、トレーニングの内容は適当です。

Epoch 1/10
1/1 [==============================] - 0s 156ms/step - loss: 0.4172
Epoch 2/10
1/1 [==============================] - 0s 0s/step - loss: 0.4125
Epoch 3/10
1/1 [==============================] - 0s 0s/step - loss: 0.4078
Epoch 4/10
1/1 [==============================] - 0s 0s/step - loss: 0.4032
Epoch 5/10
1/1 [==============================] - 0s 0s/step - loss: 0.3987
Epoch 6/10
1/1 [==============================] - 0s 0s/step - loss: 0.3942
Epoch 7/10
1/1 [==============================] - 0s 0s/step - loss: 0.3898
Epoch 8/10
1/1 [==============================] - 0s 0s/step - loss: 0.3855
Epoch 9/10
1/1 [==============================] - 0s 0s/step - loss: 0.3812
Epoch 10/10
1/1 [==============================] - 0s 16ms/step - loss: 0.3770
Before save: model.predict(x) =>
[[ 0.00595104  0.04227125  0.07621967 -0.00854966]
 [-0.02209851  0.05556415  0.06378415 -0.00197074]
 [-0.01329668 -0.0070225   0.05794774 -0.03990902]
 [-0.03630912  0.01880281  0.02773816 -0.01863302]
 [ 0.00922568  0.02720691  0.10599497 -0.02486535]]
After load_model: model.predict(x) =>
[[ 0.00595104  0.04227125  0.07621967 -0.00854966]
 [-0.02209851  0.05556415  0.06378415 -0.00197074]
 [-0.01329668 -0.0070225   0.05794774 -0.03990902]
 [-0.03630912  0.01880281  0.02773816 -0.01863302]
 [ 0.00922568  0.02720691  0.10599497 -0.02486535]]

参考文献