1.系统环境
硬件环境(Ascend/GPU/CPU): GPU
MindSpore版本: mindspore=2.0
执行模式(PyNative/ Graph):不限
Python版本:3.7
操作系统平台:Linux
2. 问题描述
迁移网络任务-tacotron2时遇到mindsporeAPI binary_cross_entropy_with_logits描述有问题
API描述:
代码段:
objectness_loss = ops.binary_cross_entropy_with_logits(
logits=objectness[sampled_inds],
label=labels[sampled_inds],
weight=1,
pos_weight=1
)
报错信息:
3. 解决方案
需要传Tensor格式信息,改成
weight = Tensor(np.ones((1,len(objectness[sampled_inds])))[0],dtype=ms.float32)
pos_weight = Tensor(np.ones((1,len(objectness[sampled_inds])))[0],dtype=ms.float32)
objectness_loss = ops.binary_cross_entropy_with_logits(
logits=objectness[sampled_inds],
label=labels[sampled_inds],
weight=weight,
pos_weight=pos_weight
)