概述
import tensorflow as tf
import tensorflow as keras
from tensorflow.keras import layers,optimizers,losses,Sequential
class MyDense(layers.Layer):
def __init__(self,inp_dim,outp_dim):
super(MyDense,self).__init__()
#创建权值张量并添加到管理列表中,设置为需要优化
self.kernel=self.add_variable('w',[inp_dim,outp_dim],trainable=True)
#tf.Variable()类型的类成员变量也会自动纳入张量管理中
#elf.kernel=tf.Variable(tf.random.nromal([inp_dim,outp_dim]),trainable=True)
def call(self,inputs,training=True):
out=inputs@self.kernel
out=tf.nn.relu(out)
return out
net=MyDense(4,3)#创建输入为4,输出为3节点的自定义层
net.variables,net.trainable_variables#查看自定义层的参数列表
#使用自定义网络层
network=Sequential([MyDense(784,256),
MyDense(256,128),
MyDense(128,64),
MyDense(64,32),
MyDense(32,10)])
network.build(input_shape=(None,28*28))#自动创建所有层的内部张量
network.summary()#打印出网络结构和参数量
Model: "sequential_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
my_dense_13 (MyDense) multiple 200704
_________________________________________________________________
my_dense_14 (MyDense) multiple 32768
_________________________________________________________________
my_dense_15 (MyDense) multiple 8192
_________________________________________________________________
my_dense_16 (MyDense) multiple 2048
_________________________________________________________________
my_dense_17 (MyDense) multiple 320
=================================================================
Total params: 244,032
Trainable params: 244,032
Non-trainable params: 0
最后
以上就是明亮煎蛋为你收集整理的自定义网络层的全部内容,希望文章能够帮你解决自定义网络层所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复