Thursday, 24 October 2019

TensorFlow: Load Model to Continue Training

A low-level TensorFlow model based on tf.Module is just as easy to save as a Keras model, but it is hard to continue training after loading it back, because custom functions must be written to assign the saved weight values. The following example saves and loads a Keras model and then continues training it.

Source code:
%tensorflow_version 2.x
%reset -f

#libs
import tensorflow as tf;
from tensorflow.keras.layers import *;

#constants
#batch size; equal to the number of rows in the 4-sample dataset defined below
BSIZE = 4

#model
class model(tf.keras.Model):
  """Two-layer dense network (2 -> 20 -> 1) built from raw tf.Variable weights.

  The hidden layer is activated with leaky ReLU and the output layer with
  sigmoid, so the outputs fall between 0 and 1.
  """

  def __init__(self):
    """Create the weights and biases, each drawn uniformly between -1 and 1."""
    super().__init__()
    lo, hi = -1, 1

    # Hidden layer: 2 inputs -> 20 units.
    self.W1 = tf.Variable(tf.random.uniform([2, 20], lo, hi))
    self.B1 = tf.Variable(tf.random.uniform([20], lo, hi))

    # Output layer: 20 units -> 1 output.
    self.W2 = tf.Variable(tf.random.uniform([20, 1], lo, hi))
    self.B2 = tf.Variable(tf.random.uniform([1], lo, hi))

  # Optional: uncomment to pin the traced input signature to a [BSIZE,2] tensor.
  #@tf.function(input_signature=[tf.TensorSpec([BSIZE,2])])
  def call(self, Inp):
    """Run the forward pass.

    Args:
      Inp: tensor that can be matrix-multiplied with the [2, 20] weight W1,
        i.e. its last dimension is 2.

    Returns:
      Tensor whose last dimension is 1, holding the sigmoid activations.
    """
    hidden_pre = tf.matmul(Inp, self.W1) + self.B1
    hidden = tf.nn.leaky_relu(hidden_pre)
    out_pre = tf.matmul(hidden, self.W2) + self.B2
    return tf.sigmoid(out_pre)

#data: XOR truth table -- every 2-bit input pair and its expected output
X = tf.convert_to_tensor(
  [[0, 0], [0, 1], [1, 0], [1, 1]],
  dtype=tf.float32,
)
Y = tf.convert_to_tensor(
  [[0], [1], [1], [0]],
  dtype=tf.float32,
)

#train
#instantiate the network; its weights start from fresh random values
Model = model()

#hard to resume training with this low-level training procedure:
#The block below is disabled by wrapping it in a string literal, so none of it runs.
#It shows a manual training loop: MSE loss, SGD with learning rate 1e-1, 100 steps,
#printing the loss every Steps/10 steps and once more after the last step.
#Loss and Optim exist only as script variables here; they are never attached to
#Model through compile().
'''
Loss  = tf.losses.MeanSquaredError();
Optim = tf.optimizers.SGD(1e-1);
Steps = 100;

for I in range(Steps):
  if I%(Steps/10)==0:
    Out = Model(X);
    Lv  = Loss(Y,Out);
    print("Loss:",Lv.numpy());

  with tf.GradientTape() as T:
    Out = Model(X);
    Lv  = Loss(Y,Out);

  Grads = T.gradient(Lv, Model.trainable_variables);
  Optim.apply_gradients(zip(Grads, Model.trainable_variables));

Out = Model(X);
Lv  = Loss(Y,Out);
print("Loss:",Lv.numpy(),"(Last)");
'''

#easier to resume training with keras
#compile() attaches the loss and optimizer to the model itself. Further down, the
#reloaded model M is fit() directly without another compile() call, which relies
#on save_model()/load_model() carrying that training configuration along.

#one name for the save location so the save and load calls cannot drift apart
MODEL_DIR = "/tmp/models/test"

Model.compile(
  loss=tf.losses.MeanSquaredError(),
  optimizer=tf.optimizers.SGD(1e-1),
)
#short initial run; batch size comes from BSIZE instead of a repeated literal
Model.fit(X, Y, batch_size=BSIZE, epochs=10, verbose=0)
print("Test:")
print(Model.predict(X, batch_size=BSIZE, verbose=0))

#save
print("\nSaving model...")
tf.keras.models.save_model(Model, MODEL_DIR)

#load
print("\nLoading model to train more...")
M = tf.keras.models.load_model(MODEL_DIR)
#these predictions can be compared with the "Test:" output printed before saving
print(M.predict(X, batch_size=BSIZE, verbose=0))

#continue training
M.fit(X, Y, batch_size=BSIZE, epochs=5000, verbose=0)
print("\nTest:")
print(M.predict(X, batch_size=BSIZE, verbose=0))
#eof

No comments:

Post a Comment