概述
第一篇:将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(1)
上一篇,我们分析了 convert_pytorch_checkpoint_to_tf.py
文件中 main()
的参数解析,本篇,我们从模型加载入手。
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir)
BertModel
调用的 from_pretrained()
方法是其父类 PreTrainedModel
中的方法。该方法的作用是:从预训练模型的配置文件实例化 PyTorch 版的预训练模型。
【说明】:以下仅给出和 PyTorch->tf 相关的函数。
PreTrainedModel
类有 4 个全局变量,一个初始化函数、一个 from_pretrained()
函数
class PreTrainedModel(nn.Module):
# 4个全局变量
config_class = None
pretrained_model_archive_map = {}
load_tf_weights = lambda model, config, path: None
base_model_prefix = ""
# 初始化函数
def __init__(self, config, *inputs, **kwargs):
super(PreTrainedModel, self).__init__()
# 判断 config 是否是继承自 PreTrainedConfig 类
if not isinstance(config, PretrainedConfig):
raise ValueError(
"Parameter config in `{}(config)` should be an instance of class `PretrainedConfig`. "
"To create a model from a pretrained model use "
"`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(self.__class__.__name__, self.__class__.__name__))
# Save config in model
self.config = config
【小贴士】
python的 lambda 表达式(匿名函数):https://www.runoob.com/python3/python3-function.html
lambda和map() 函数:https://mp.weixin.qq.com/s/GDC3GeTPXspInK_1DPyuVA
isinstance() 函数:https://www.runoob.com/python/python-func-isinstance.html
raise 语句抛出一个特定异常:https://www.runoob.com/python3/python3-errors-execptions.html
python 常用的内建属性:https://blog.csdn.net/qq_26442553/article/details/82464682
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop('config', None)
state_dict = kwargs.pop('state_dict', None)
cache_dir = kwargs.pop('cache_dir', None)
from_tf = kwargs.pop('from_tf', False)
force_download = kwargs.pop('force_download', False)
proxies = kwargs.pop('proxies', None)
output_loading_info = kwargs.pop('output_loading_info', False)
三个重要参数解析:pretrained_model_name_or_path
、config
、state_dict
pretrained_model_name_or_path 参数是必要的,该参数可以是以下几种形式:
- 一个用来从缓存中加载或下载预训练模型的简称,例如:bert-base-uncased。
- 指向一个目录的路径,该目录包含了由
~pytorch_transformers.PreTrainedModel.save_pretrained 函数保存的模型权重。
- 一个指向 tensorflow index checkpoint 文件的路径或URL,例如 ./tf_model/model.ckpt.index ,在这一情况下,参数 from_tf 应设置成 True,同时提供 configuration 对象实例作为参数 config 的值。使用这种方式加载,速度会较慢。
config 参数是可选的,该参数是某个类(该类继承自PretrainedConfig)的实例。模型需要调用配置文件,而不会自动加载,当出现以下几种情况时,配置文件会被自动加载:
- 某个库提供了模型(使用预训练模型的简写名称来加载)
- 使用 ~pytorch_transformers.PreTrainedModel.save_pretrained 函数保存的模型,以及从保存目录中重新加载的模型。
- 提供一个本地目录作为 pretrained_model_name_or_path 的属性值,并且在该目录下有个名为 config.json 的JSON文件,以这样的方式加载模型。
state_dict 参数是可选的,一个供模型使用的可选状态字典,如果你想从预训练配置文件中创建模型并且加载你自己的权重,可以使用该参数。
【小贴士】
@classmethod
解析:https://www.runoob.com/python/python-func-classmethod.html、https://www.runoob.com/note/33690
不定长参数:https://www.runoob.com/python3/python3-function.html、https://blog.csdn.net/u010376788/article/details/49933511、https://blog.csdn.net/m0_38024592/article/details/82801674
pop() 函数:https://www.runoob.com/python/python-att-dictionary-pop.html
- 加载配置文件
# Load config
# 如果调用from_pretrained()方法时,没有config参数,那么从对应模型的 XxxConfig 中获取配置文件。
if config is None:
# 调用的是 modeling_utils.py 文件中的 PretrainedConfig 类的 from_pretrained() 方法
config, model_kwargs = cls.config_class.from_pretrained(
pretrained_model_name_or_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
**kwargs)
else:
# 若有 config 参数,那么说明 对应模型的 XxxConfig 已经调用其父类 PretrainedConfig 类中的 from_json_file() 方法来获取 配置文件
model_kwargs = kwargs
- 加载模型
# Load model
# 1)如果给出的是模型的名字,那么判断是否在 pretrained_model_archive_map 中
# 若存在,根据 pretrained_model_name 获取到 模型的下载链接
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
# archive_file 模型的下载链接(在线下载)
archive_file = cls.pretrained_model_archive_map[pretrained_model_name_or_path]
# 2)如果给出的是一个指向目录的路径,首先判断路径给出的是否是一个目录
elif os.path.isdir(pretrained_model_name_or_path):
# 如果给出的是目录,那么判断是否是直接从 tf 中加载
# 如果 from_tf 为 true ,那么表示的是直接加载 tf 版的 checkpoint,同时将目录和文件名拼接起来
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = os.path.join(pretrained_model_name_or_path, TF_WEIGHTS_NAME + ".index")
# TF_WEIGHTS_NAME = 'model.ckpt'
else:
# 如果 from_tf 为 false,那么代表的是加载 pytorch_model.bin,同时将目录和文件名拼接起来
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
# WEIGHTS_NAME = "pytorch_model.bin"
else:
# 3)给出的是模型文件所在的路径(不是目录)
# 同样,要判断 from_tf,再决定加载模型
if from_tf:
# Directly load from a TensorFlow checkpoint
archive_file = pretrained_model_name_or_path + ".index"
else:
archive_file = pretrained_model_name_or_path
try:
# redirect to the cache, if necessary
# 返回的可能是 URL(从该URL获取缓存) 或者是 文件名(文件必须存在)
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir, force_download=force_download, proxies=proxies)
except EnvironmentError:
# 如果报错,那么首先判断 pretrained_model_name 是否存在于 map 中,若存在,logger 打印错误
if pretrained_model_name_or_path in cls.pretrained_model_archive_map:
logger.error("Couldn't reach server at '{}' to download pretrained weights.".format(archive_file))
else:
# 除去第一种情况,logger 打印剩下的几种情况出现的错误
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name_or_path,
', '.join(cls.pretrained_model_archive_map.keys()),
archive_file))
return None
# 若出现错误,此时,不返回任何 model
- 打印日志信息
# 如果从 cached_path() 返回的是 文件名,那么判断是否和 archive_file 一样
if resolved_archive_file == archive_file:
logger.info("loading weights file {}".format(archive_file))
else:
# 从 cache 中加载权重文件
logger.info("loading weights file {} from cache at {}".format(
archive_file, resolved_archive_file))
- 实例化模型
# cls() 表示实例化 PreTrainedModel 这个类,
# 由于 PreTrainedModel 中的初始化函数 __init(self, config, *inputs, **kwargs)__包含三个参数,所以这里提供三个参数。
model = cls(config, *model_args, **model_kwargs)
- 加载状态字典
if state_dict is None and not from_tf:
state_dict = torch.load(resolved_archive_file, map_location='cpu')
- 加载 tf 权重
if from_tf:
# Directly load from a TensorFlow checkpoint
return cls.load_tf_weights(model, config, resolved_archive_file[:-6])
# Remove the '.index'
- 转换 PyTorch 版的
state_dict
格式
old_keys = []
new_keys = []
for key in state_dict.keys():
new_key = None
if 'gamma' in key:
new_key = key.replace('gamma', 'weight')
if 'beta' in key:
new_key = key.replace('beta', 'bias')
if new_key:
old_keys.append(key)
new_keys.append(new_key)
for old_key, new_key in zip(old_keys, new_keys):
state_dict[new_key] = state_dict.pop(old_key)
【小贴士】
zip() 函数:https://www.runoob.com/python/python-func-zip.html
- 从 PyTorch 版的
state_dict
加载
# Load from a PyTorch state_dict
missing_keys = []
unexpected_keys = []
error_msgs = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, '_metadata', None)
state_dict = state_dict.copy()
if metadata is not None:
state_dict._metadata = metadata
def load(module, prefix=''):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
module._load_from_state_dict(
state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
for name, child in module._modules.items():
if child is not None:
load(child, prefix + name + '.')
# Make sure we are able to load base models as well as derived(衍生的) models (with heads)
start_prefix = ''
model_to_load = model
if not hasattr(model, cls.base_model_prefix) and any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
start_prefix = cls.base_model_prefix + '.'
if hasattr(model, cls.base_model_prefix) and not any(
s.startswith(cls.base_model_prefix) for s in state_dict.keys()):
model_to_load = getattr(model, cls.base_model_prefix)
load(model_to_load, prefix=start_prefix)
【小贴士】
getattr() 函数:https://www.runoob.com/python/python-func-getattr.html
copy() 函数:https://www.runoob.com/python/att-dictionary-copy.html、https://www.runoob.com/w3cnote/python-understanding-dict-copy-shallow-or-deep.html
hasattr() 函数:https://www.runoob.com/python/python-func-hasattr.html
any() 函数:https://www.runoob.com/python/python-func-any.html
- 打印日志信息
if len(missing_keys) > 0:
logger.info("Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys))
if len(unexpected_keys) > 0:
logger.info("Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys))
if len(error_msgs) > 0:
raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
model.__class__.__name__, "nt".join(error_msgs)))
- 绑定词嵌入权重
if hasattr(model, 'tie_weights'):
model.tie_weights()
# make sure word embedding weights are still tied
- 设置模型的默认模式
# Set model in evaluation mode to desactivate DropOut modules by default
model.eval()
- 返回模型和加载信息
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys, "error_msgs": error_msgs}
return model, loading_info
- 返回最终的模型
# 上述的 实例化模型 代码
# model = cls(config, *model_args, **model_kwargs) 是个入口(Input)
return model
# 这里返回是出口(Output),二者之间的所有过程是 处理(Process)
【总结】:模型加载的全过程就是上述内容,在【加载配置文件】处理中,即该代码:cls.config_class.from_pretrained()
,该操作涉及到 PretrainedConfig
类中的 from_pretrained()
函数,同时,该函数还关联 file_utils.py
类中一些文件操作的工具类,这个放在下一篇解析。
我们完成了 PreTrainedModel
类中有关于 PyTorch->tf 模型加载的解析。
最后
以上就是鲜艳自行车为你收集整理的pytorch checkpoint_将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(2)的全部内容,希望文章能够帮你解决pytorch checkpoint_将 PyTorch 版的 BERT 模型转换成 Tensorflow 版的 BERT 模型(2)所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复