我是靠谱客的博主 苗条白云,最近开发中收集的这篇文章主要介绍tensorflow自定义神经网络模型,觉得挺不错的,现在分享给大家,希望可以做个参考。

概述

建立和训练一个简单的模型通常涉及几个步骤:

  • 定义模型。
  • 定义损失函数。
  • 获取训练数据。
  • 运行训练数据并使用“优化器”调整变量以拟合数据。

1.定义模型

class Model(object):
  def __init__(self):
    # Initialize variable to (5.0, 0.0)
    # In practice, these should be initialized to random values.
    self.W = tf.Variable(5.0)
    self.b = tf.Variable(0.0)
    
  def __call__(self, x):
    return self.W * x + self.b
  
model = Model()

assert model(3.0).numpy() == 15.0

2.定义损失函数

def loss(predicted_y, desired_y):
  return tf.reduce_mean(tf.square(predicted_y - desired_y))

3.获取训练数据

TRUE_W = 3.0
TRUE_b = 2.0
NUM_EXAMPLES = 1000

inputs  = tf.random_normal(shape=[NUM_EXAMPLES])
noise   = tf.random_normal(shape=[NUM_EXAMPLES])
outputs = inputs * TRUE_W + TRUE_b + noise

可视化数据:

import matplotlib.pyplot as plt

plt.scatter(inputs, outputs, c='b')
plt.scatter(inputs, model(inputs), c='r')
plt.show()

print('Current loss: '),
print(loss(model(inputs), outputs).numpy())

4.模型训练

现在拥有了网络和训练数据。让我们训练它,即使用训练数据来更新模型的变量(W和b),以便使用梯度下降来减少损失。

def train(model, inputs, outputs, learning_rate):
  with tf.GradientTape() as t:
    current_loss = loss(model(inputs), outputs)
  dW, db = t.gradient(current_loss, [model.W, model.b])
  model.W.assign_sub(learning_rate * dW)
  model.b.assign_sub(learning_rate * db)
model = Model()

# Collect the history of W-values and b-values to plot later
Ws, bs = [], []
epochs = range(10)
for epoch in epochs:
  Ws.append(model.W.numpy())
  bs.append(model.b.numpy())
  current_loss = loss(model(inputs), outputs)

  train(model, inputs, outputs, learning_rate=0.1)
  print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' %
        (epoch, Ws[-1], bs[-1], current_loss))

# Let's plot it all
plt.plot(epochs, Ws, 'r',
         epochs, bs, 'b')
plt.plot([TRUE_W] * len(epochs), 'r--',
         [TRUE_b] * len(epochs), 'b--')
plt.legend(['W', 'b', 'true W', 'true_b'])
plt.show()

 

最后

以上就是苗条白云为你收集整理的tensorflow自定义神经网络模型的全部内容,希望文章能够帮你解决tensorflow自定义神经网络模型所遇到的程序开发问题。

如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(106)

评论列表共有 0 条评论

立即
投稿
返回
顶部