我是靠谱客的博主 苗条白云,这篇文章主要介绍tensorflow自定义神经网络模型,现在分享给大家,希望可以做个参考。

建立和训练一个简单的模型通常涉及几个步骤:

  • 定义模型。
  • 定义损失函数。
  • 获取训练数据。
  • 运行训练数据并使用“优化器”调整变量以拟合数据。

1.定义模型

复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
class Model(object): def __init__(self): # Initialize variable to (5.0, 0.0) # In practice, these should be initialized to random values. self.W = tf.Variable(5.0) self.b = tf.Variable(0.0) def __call__(self, x): return self.W * x + self.b model = Model() assert model(3.0).numpy() == 15.0

2.定义损失函数

复制代码
1
2
def loss(predicted_y, desired_y): return tf.reduce_mean(tf.square(predicted_y - desired_y))

3.获取训练数据

复制代码
1
2
3
4
5
6
7
TRUE_W = 3.0 TRUE_b = 2.0 NUM_EXAMPLES = 1000 inputs = tf.random_normal(shape=[NUM_EXAMPLES]) noise = tf.random_normal(shape=[NUM_EXAMPLES]) outputs = inputs * TRUE_W + TRUE_b + noise

可视化数据:

复制代码
1
2
3
4
5
6
7
8
import matplotlib.pyplot as plt plt.scatter(inputs, outputs, c='b') plt.scatter(inputs, model(inputs), c='r') plt.show() print('Current loss: '), print(loss(model(inputs), outputs).numpy())

4.模型训练

现在拥有了网络和训练数据。让我们训练它,即使用训练数据来更新模型的变量(W和b),以便使用梯度下降来减少损失。

复制代码
1
2
3
4
5
6
def train(model, inputs, outputs, learning_rate): with tf.GradientTape() as t: current_loss = loss(model(inputs), outputs) dW, db = t.gradient(current_loss, [model.W, model.b]) model.W.assign_sub(learning_rate * dW) model.b.assign_sub(learning_rate * db)
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
model = Model() # Collect the history of W-values and b-values to plot later Ws, bs = [], [] epochs = range(10) for epoch in epochs: Ws.append(model.W.numpy()) bs.append(model.b.numpy()) current_loss = loss(model(inputs), outputs) train(model, inputs, outputs, learning_rate=0.1) print('Epoch %2d: W=%1.2f b=%1.2f, loss=%2.5f' % (epoch, Ws[-1], bs[-1], current_loss)) # Let's plot it all plt.plot(epochs, Ws, 'r', epochs, bs, 'b') plt.plot([TRUE_W] * len(epochs), 'r--', [TRUE_b] * len(epochs), 'b--') plt.legend(['W', 'b', 'true W', 'true_b']) plt.show()

 

最后

以上就是苗条白云最近收集整理的关于tensorflow自定义神经网络模型的全部内容,更多相关tensorflow自定义神经网络模型内容请搜索靠谱客的其他文章。

本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
点赞(101)

评论列表共有 0 条评论

立即
投稿
返回
顶部