迁移网络任务-tacotron2时遇到mindsporeAPI binary_cross_entropy_with_logits描述有问题

1.系统环境

硬件环境(Ascend/GPU/CPU): GPU

MindSpore版本: mindspore=2.0

执行模式(PyNative/ Graph):不限

Python版本:3.7

操作系统平台:Linux

2. 问题描述

迁移网络任务-tacotron2时遇到mindsporeAPI binary_cross_entropy_with_logits描述有问题

API描述:

image.png

代码段:

objectness_loss = ops.binary_cross_entropy_with_logits(  
        logits=objectness[sampled_inds],   
        label=labels[sampled_inds],  
        weight=1,  
        pos_weight=1  
)

报错信息:

3. 解决方案

需要传Tensor格式信息,改成

weight = Tensor(np.ones((1,len(objectness[sampled_inds])))[0],dtype=ms.float32)  
pos_weight = Tensor(np.ones((1,len(objectness[sampled_inds])))[0],dtype=ms.float32)  
objectness_loss = ops.binary_cross_entropy_with_logits(  
        logits=objectness[sampled_inds],   
        label=labels[sampled_inds],  
        weight=weight,  
        pos_weight=pos_weight  
)