概述
问题描述:
【功能模块】
nn.metric
【操作步骤&问题现象】
1、对训练过的模型进行验证时,报错:ValueError: Classification case, dims of y_pred equal dims of y add 1, but got y_pred: 2 dims and y: 2 dims
2、
【截图信息】
def _check_shape(self, y_pred, y):
"""
Checks the shapes of y_pred and y.
Args:
y_pred (Tensor): Predict array.
y (Tensor): Target array.
"""
if self._type == 'classification':
if y_pred.ndim != y.ndim + 1:
raise ValueError('Classification case, dims of y_pred equal dims of y add 1, '
'but got y_pred: {} dims and y: {} dims'.format(y_pred.ndim, y.ndim))
if y.shape != (y_pred.shape[0],) + y_pred.shape[2:]:
raise ValueError('Classification case, y_pred shape and y shape can not match. '
'got y_pred shape is {} and y shape is {}'.format(y_pred.shape, y.shape))
解答:
`y_pred` stands for logits, `y` stands for labels. `y_pred` and `y` must be a `Tensor`, a list or an array.
For the 'classification' evaluation type:
`y_pred` is a list of floating numbers in range :math:`[0, 1]` and the shape is :math:`(N, C)` in most cases (not strictly), where :math:`N` is the number of cases and :math:`C` is the number of categories.
`y` must be in one-hot format that shape is :math:`(N, C)`, or can be transformed to one-hot format that shape is :math:`(N,)`.
For 'multilabel' evaluation type, the value of `y_pred` and `y` can only be 0 or 1, indices with 1 indicate the positive category. The shape of `y_pred` and `y` are both :math:`(N, C)`.
可以参考链接:
https://gitee.com/sun_zhongqian/mindspore/blob/r1.6/mindspore/python/mindspore/nn/metrics/accuracy.py
https://www.mindspore.cn/docs/api/zh-CN/r1.6/api_python/nn/mindspore.nn.Accuracy.html
自定义调试信息 — MindSpore master documentation
可以参考这部分代码去实现model.eval()
from mindspore import Tensor
from mindspore.nn import Accuracy
import numpy as np
x = Tensor(np.array([[0.2, 0.5], [0.3, 0.1], [0.9, 0.6]]))
y = Tensor(np.array([1, 0, 1]))
metric = Accuracy()
metric.clear()
metric.update(x, y)
accuracy = metric.eval()
print('Accuracy is ', accuracy)
最后
以上就是勤恳斑马为你收集整理的【mindspore产品】【model.功能】预测值维度为什么要比实际值维度多一维的全部内容,希望文章能够帮你解决【mindspore产品】【model.功能】预测值维度为什么要比实际值维度多一维所遇到的程序开发问题。
如果觉得靠谱客网站的内容还不错,欢迎将靠谱客网站推荐给程序员好友。
发表评论 取消回复