Friday, 20 September 2019

TensorFlow RNN: Text Classification

TensorFlow provides RNN layers similar to those in Keras. An RNN can be used for classification or for generation, and it works best on sequential data in which neighbouring time steps are related. The following example shows how to classify text with a single LSTM layer or with multiple stacked LSTM layers. An LSTM layer is preferred over a basic RNN layer because the basic RNN layer suffers from the vanishing-gradient problem: backpropagating through many time steps is just like backpropagating through a very deep network.

Source code:
import tensorflow as tf;
# Start from an empty default graph so that re-running this script in the same
# process (e.g. the same Colab cell) does not add a second copy of every op and
# variable to the graph built by the previous run.
tf.reset_default_graph();

#data
'''
t0      t1      t2
british gray    is => cat (y=0)
0       1       2
white   samoyed is => dog (y=1)
3       4       2 
'''
# Batch geometry: two sentences, three tokens each.
Bsize = 2
Times = 3
# Largest token id and largest label; used as divisors to scale values into [0, 1].
Max_X = 4
Max_Y = 1

# Token ids per sentence (see the vocabulary table above) and their class labels.
Token_Ids = [[0, 1, 2], [3, 4, 2]]
Labels    = [0, 1]

#normalise
# Every scalar is wrapped in its own one-element list, giving X the nesting
# [sentence][time step][feature] and Y the nesting [sentence][label].
X = [[[Token / Max_X] for Token in Sentence] for Sentence in Token_Ids]
Y = [[Label / Max_Y] for Label in Labels]

#model
# Set to True to stack two LSTM layers (30 units, then Hidden_Units) instead of
# using a single one. Both variants feed the same read-out layer below, so
# switching no longer requires editing comment markers or renaming tensors.
Use_Multi_Layers = False
# Output width of the last LSTM layer; the read-out weight shape must match it.
Hidden_Units = 20

# Inputs carries one feature per time step; Expected carries one label per sentence.
Inputs   = tf.placeholder(tf.float32, [Bsize, Times, 1])
Expected = tf.placeholder(tf.float32, [Bsize,        1])

if Use_Multi_Layers:
  #multi LSTM layers
  Layers = tf.keras.layers.RNN([
    tf.keras.layers.LSTMCell(30),            #hidden 1
    tf.keras.layers.LSTMCell(Hidden_Units)   #hidden 2
  ])
  Hidden = Layers(Inputs)
else:
  #single LSTM layer
  Layer1 = tf.keras.layers.LSTM(Hidden_Units)
  Hidden = Layer1(Inputs)

#read-out: one sigmoid unit on top of the recurrent output
# Keras recurrent layers return only the last time step's output by default,
# so Hidden is [Bsize, Hidden_Units] and the matmul yields [Bsize, 1].
Weight3  = tf.Variable(tf.random_uniform([Hidden_Units, 1], -1, 1))
Bias3    = tf.Variable(tf.random_uniform([1], -1, 1))
Output   = tf.sigmoid(tf.matmul(Hidden, Weight3) + Bias3)

#loss: sum of squared errors, minimised by gradient descent (learning rate 0.1)
Loss     = tf.reduce_sum(tf.square(Expected - Output))
Optim    = tf.train.GradientDescentOptimizer(1e-1)
Training = Optim.minimize(Loss)

#train
# Open a session on the default graph and give every variable its initial value.
Sess = tf.Session()
Init = tf.global_variables_initializer()
Sess.run(Init)

# The whole data set is a single batch, so the same feed is reused for every step.
Feed = {Inputs: X, Expected: Y}

for Step in range(1000):
  # Report the loss (measured before this step's update) on every 100th step.
  if not Step % 100:
    Lossvalue = Sess.run(Loss, Feed)
    print("Loss:", Lossvalue)

  # One gradient-descent update on the batch.
  Sess.run(Training, Feed)

# Loss after the final update.
Lastloss = Sess.run(Loss, Feed)
print("Loss:", Lastloss, "(Last)")

#eval
# Only Output is fetched, and Output depends on Inputs alone, so the labels
# do not need to be fed for prediction.
Results = Sess.run(Output, {Inputs: X})
# This was the last use of the session: close it to release the resources it holds.
Sess.close()

print("\nEval:")
print(Results)

print("\nDone.")
#eof

Colab link:
https://colab.research.google.com/drive/1_TRH5ZshDApJC6JRHdRJAiBWLw3f4wRF

No comments:

Post a Comment