The magic in machine learning with TensorFlow lies in the optimisers. They change the values of the variables in the network to optimise it, i.e. to minimise the final loss.
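As a minimal sketch of what an optimiser does on each step (plain NumPy, with a made-up one-weight model; the learning rate matches the script below), gradient descent nudges each variable against the gradient of the loss:

import numpy as np

#hypothetical one-weight model: y = w*x, loss = (y - t)^2
w  = 0.5    #initial weight
x  = 2.0    #one input
t  = 3.0    #its target output
lr = 1e-1   #learning rate, same value as in the script below

for step in range(100):
    y = w * x                 #forward pass
    grad = 2.0 * (y - t) * x  #dLoss/dw, worked out by hand
    w = w - lr * grad         #gradient descent update: w <- w - lr*dL/dw

print(w)  #approaches t/x = 1.5, where the loss is 0

A real optimiser does exactly this, but for every variable at once, with the gradients computed automatically by backpropagation.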
Below, a model is trained (its weights are fitted) to do XOR regression:
#core
import sys;

#libs
import tensorflow as tf;
import numpy as np;
import matplotlib.pyplot as pyplot;

''' \brief Log '''
def log(*Args):
  Str = "";
  for I in range(len(Args)):
    Str += str(Args[I]);
    if (I<len(Args)-1): Str += "\x20";
  tf.logging.info(Str);
#end def

''' \brief Log a separator '''
def log_sep():
  tf.logging.info("\n"+"="*80);
#end def

#PROGRAMME ENTRY POINT==========================================================
tf.logging.set_verbosity(tf.logging.INFO);
log("TensorFlow version:",tf.__version__);

#data: the 4 input pairs of XOR and their expected outputs
X = [[0,0],[0,1],[1,0],[1,1]];
Y = [[0],  [1],  [1],  [0] ];
X = np.array(X, dtype=np.float32);
Y = np.array(Y, dtype=np.float32);

#dnn: 2 inputs -> 20 hidden relu units -> 1 sigmoid output
Batch_Size = 4;
Input    = tf.placeholder(tf.float32, shape=[Batch_Size,2]);
Expected = tf.placeholder(tf.float32, shape=[Batch_Size,1]);

Weight1 = tf.Variable(tf.random_uniform([2,20], -1,1), name="weight1");
Bias1   = tf.Variable(tf.random_uniform([  20], -1,1), name="bias1");
Hidden1 = tf.nn.relu(tf.matmul(Input,Weight1) + Bias1);

Weight2 = tf.Variable(tf.random_uniform([20,1], -1,1), name="weight2");
Bias2   = tf.Variable(tf.random_uniform([   1], -1,1), name="bias2");
Output  = tf.sigmoid(tf.matmul(Hidden1,Weight2) + Bias2);

#sum of squared errors as the loss, minimised by plain gradient descent
Loss      = tf.reduce_sum(tf.square(Expected-Output));
Optimiser = tf.train.GradientDescentOptimizer(1e-1);
Training  = Optimiser.minimize(Loss);

#training
Sess = tf.Session();
Init = tf.global_variables_initializer();
Sess.run(Init);

Losses = [];
for I in range(10000):
  if (I%100==0):
    Lossvalue = Sess.run(Loss, feed_dict={Input:X, Expected:Y});
    print("Loss:",Lossvalue);
    Losses += [Lossvalue];
  #end if
  Sess.run(Training, feed_dict={Input:X, Expected:Y});
#end for

print("Eval:",Sess.run(Output, feed_dict={Input:X, Expected:Y}));
print("Loss curve:");
pyplot.plot(Losses);
pyplot.show(); #needed when run as a plain script; Colab shows the plot inline
#eof
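Note: the script uses the TensorFlow 1.x graph API (tf.placeholder, tf.Session, etc.), which was removed from the top-level namespace in TensorFlow 2.x. A hedged sketch, assuming you are on a TensorFlow 2.x install and want to run it with minimal changes via the compat shim:

import tensorflow as tf

#TF 2.x runs eagerly by default; the graph/Session style above needs this off
tf.compat.v1.disable_eager_execution()

#the 1.x names used above live on under tf.compat.v1, e.g.:
# tf.placeholder                    -> tf.compat.v1.placeholder
# tf.random_uniform                 -> tf.compat.v1.random_uniform
# tf.train.GradientDescentOptimizer -> tf.compat.v1.train.GradientDescentOptimizer
# tf.Session                        -> tf.compat.v1.Session
# tf.global_variables_initializer   -> tf.compat.v1.global_variables_initializer
# tf.logging                        -> tf.compat.v1.logging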
Colab link:
https://colab.research.google.com/drive/1trM0syMQzNTEMXPztDIYjKfU9cFdW1lt