使用mindspore.numpy.broadcast_to 算子报错及解决

1 系统环境

硬件环境(Ascend/GPU/CPU): 不限
MindSpore版本: 2.0.0-alpha
执行模式(PyNative/ Graph): 不限
Python版本: 3.7.5
操作系统平台: 不限

2 报错信息

2.1 问题描述

官方demo 能跑通

import mindspore.numpy as np

x = np.array([1, 2, 3])
output = np.broadcast_to(x, (3, 3))
print(output)

shape改成(3,4)后报错。


import mindspore.numpy as np

x = np.array([1, 2, 3])
output = np.broadcast_to(x, (3, 4))
print(output)

2.2 报错信息

>>> ms.numpy.broadcast_to(A,(3,4))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/ma-user/miniconda3/envs/py38ms1.9/lib/python3.8/site-packages/mindspore/numpy/array_ops.py", line 1290, in broadcast_to
    return _raise_value_error('cannot broadcast with ', shape)
  File "/home/ma-user/miniconda3/envs/py38ms1.9/lib/python3.8/site-packages/mindspore/ops/primitive.py", line 716, in __call__
    return fn(*args)
  File "/home/ma-user/miniconda3/envs/py38ms1.9/lib/python3.8/site-packages/mindspore/numpy/utils_const.py", line 227, in _raise_value_error
    raise ValueError(info + f"{param}")
ValueError: cannot broadcast with (3, 4)

2.3 脚本代码(代码格式,可上传附件)

import mindspore.numpy as np

x = np.array([1, 2, 3])
output = np.broadcast_to(x, (3, 4))
print(output)

3 根因分析

用户代码执行报错的具体原因如下:


代码的报错根因分析是_check_can_broadcast_to函数报错,接下来看下_check_can_broadcast_to具体实现

针对用户使用的样例,我们的实际调用就是_check_can_broadcast_to((3,), (3,4))),
首先分析if 分支代码块,我们的用户是满足要求的
其实我们再看for代码块,reversed是反转的作用,zip是压缩的作用。
最后我们的i 是3,j=4.代码的1258行是判定i 是否等于1或者等于j,显然3即不等于4也不等于1.代码就会返回False。

4 解决方案

针对_check_can_broadcast_to()代码的分析,我们可以的了解到输入np.array的shape(3,),所以我们目标shape(N,3),N可以是任意维度任意数值

import mindspore.numpy as np

>>> output = np.broadcast_to(x, (4,3 ))  
>>> print(output)  
[[1 2 3]  
[1 2 3]  
[1 2 3]  
[1 2 3]]  

>>> output = np.broadcast_to(x, (4,2,3))  
>>> print(output)  
[[[1 2 3]  
[1 2 3]]  
[[1 2 3]  
[1 2 3]]  
[[1 2 3]  
[1 2 3]]  
[[1 2 3]  
[1 2 3]]]