用 Keras/TensorFlow subclassing API 实现的 model,如何使用 plot_model 绘制模型结构图?
class my_model(Model): def __init__(self, dim): super(my_model, self).__init__() self.Base = VGG16(input_shape=(dim), include_top = False, weights = 'imagenet') self.GAP = L.GlobalAveragePooling2D() self.BAT = L.Ba