概述
波士顿房价预测
# !/usr/bin/python
# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
from tensorflow import keras
import os
class Regressor(keras.layers.Layer):
"""
定义线性回归的类
"""
def __init__(self):
"""
初始化
"""
super(Regressor,self).__init__()
# 定义两个参数w和b
self.w = tf.Variable(tf.random.uniform([13,1]),name = "w")
self.b = tf.Variable(tf.random.uniform([1]),name = "b")
print(self.w.shape, self.b.shape)
print(type(self.w),tf.is_tensor(self.w),self.w.name)
print(type(self.b),tf.is_tensor(self.b),self.b.name)
def __call__(self, x):
"""
把类变成一个可调用对象
:param x:
:return:
"""
# [batch_size,13] [13,1] = [batch_size,1]
x = tf.matmul(x,self.w) + self.b
return
x
def main():
tf.random.set_seed(22)
np.random.seed(22)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 获取数据
(x_train, y_train), (x_val, y_val) = keras.datasets.boston_housing.load_data()
print(x_train[0])
print(y_train[0])
x_train,x_val = x_train.astype(np.float32),x_val.astype(np.float32)
# (400,13) (404,1)
# (102,13) (102,1)
print(x_train.shape,y_train.shape)
print(x_val.shape,y_val.shape)
# 每一个batch_size有64条数据
db_train = tf.data.Dataset.from_tensor_slices((x_train,y_train)).batch(64)
db_val = tf.data.Dataset.from_tensor_slices((x_val,y_val)).batch(102)
# 构建模型
model = Regressor()
# 损失 均方误差函数
criteon = keras.losses.MeanSquaredError()
# 优化器
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
for epoch in range(200):
# 每一次都是一个batch_size
for step,(x,y) in enumerate(db_train):
with tf.GradientTape() as tape:
# 训练[64*13]*[13*1] = [64*1]
logits = model(x)
print(logits)
# 去掉第1维的括号,二维变一维,拉平。
logits = tf.squeeze(logits,axis = 1)
print(logits)
print(y)
# 计算损失
loss = criteon(y,logits)
# 梯度更新
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
print(epoch,'loss:',loss.numpy())
# 没10个epoch测试一次
if epoch % 10 == 0:
for x,y in db_val:
logits = model(x)
logits = tf.squeeze(logits,axis = 1)
loss = criteon(y,logits)
print(epoch,'val loss :',loss.numpy())
if __name__ == '__main__':
main()
最后
以上就是飘逸芹菜为你收集整理的tf2学习 线性回归的全部内容,希望文章能够帮你解决tf2学习 线性回归所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复