1 系统环境
硬件环境(Ascend/GPU/CPU): Ascend
MindSpore版本: mindspore=2.4.1
执行模式(PyNative/ Graph):不限
Python版本: Python=3.9
操作系统平台: Ubuntu 22.04
2 报错信息
2.1 问题描述
torch.where()参数condition和mindspore.mint.where()参数condition在需要多个逻辑表达式情况下的写法不一样。
对于需要两个逻辑表达式的情况,如:
(neg >= -delta) & (neg <= delta)
在MindSpore中需要使用ops.logical_and接口来实现。
比如参数condition为:
condition = ops.logical_and((neg >= -delta), (neg <= delta))
如果参数condition使用如下代码会报错:
condition = (neg >= -delta) & (neg <= delta)
2.2 脚本信息
def test_where_value_and_grad():
    """Compare torch.where and mindspore.mint.where on values and gradients.

    The same random (3, 3) float32 array is pushed through a ternarization
    routine implemented once with torch and once with MindSpore. The forward
    results and the gradients w.r.t. the input are printed and compared with
    ``np.allclose(..., atol=1e-3)``; the outcome is reported via ``print``.
    """

    def ternarize_torch(tensor):
        """Apply three chained torch.where passes, then scale by alpha."""
        delta = get_delta_torch(tensor)
        alpha = get_alpha_torch(tensor, delta)
        # Entries above delta become 1, entries below -delta become -1, and
        # whatever then lies inside [-delta, delta] becomes 0.
        pos = torch.where(tensor > delta, 1, tensor)
        neg = torch.where(pos < -delta, -1, pos)
        ternary = torch.where((neg >= -delta) & (neg <= delta), 0, neg)
        return ternary * alpha

    def get_alpha_torch(tensor, delta):
        """Return, per slice along dim 0, sum(|x|) / count over |x| > delta.

        The result is reshaped to (-1, 1, ..., 1) so it broadcasts against
        ``tensor``.
        """
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        i_delta = torch.abs(tensor) > delta
        i_delta_count = i_delta.view(i_delta.shape[0], -1).sum(1)
        tensor_thresh = torch.where(i_delta, tensor, 0)
        abs_sum = torch.abs(tensor_thresh.view(tensor.shape[0], -1)).sum(1)
        alpha = (1 / i_delta_count) * abs_sum
        return alpha.view(view_dims)

    def get_delta_torch(tensor):
        """Return 0.75 / n times the summed 1-norm of each dim-0 slice.

        ``n`` is the element count of ``tensor[0]``; the result is reshaped
        to (-1, 1, ..., 1) so it broadcasts against ``tensor``.
        """
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        n = tensor[0].nelement()
        norm = tensor.norm(1, ndim - 1).view(tensor.shape[0], -1)
        norm_sum = norm.sum(1)
        delta = (0.75 / n) * norm_sum
        return delta.view(view_dims)

    def ternarize_ms(tensor):
        """MindSpore counterpart of ``ternarize_torch`` using mint.where."""
        delta = get_delta_ms(tensor)
        alpha = get_alpha_ms(tensor, delta)
        pos = mint.where(tensor > delta, 1, tensor)
        neg = mint.where(pos < -delta, -1, pos)
        # NOTE: writing `(neg >= -delta) & (neg <= delta)` here raises
        # TypeError: `&` is dispatched to the BitwiseAnd operator, which
        # rejects bool inputs, so ops.logical_and is used instead.
        condition = ops.logical_and((neg >= -delta), (neg <= delta))
        ternary = mint.where(condition, 0, neg)
        return ternary * alpha

    def get_alpha_ms(tensor, delta):
        """MindSpore counterpart of ``get_alpha_torch``."""
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        i_delta = ops.abs(tensor) > delta
        i_delta_count = i_delta.view(i_delta.shape[0], -1).sum(1)
        tensor_thresh = mint.where(i_delta, tensor, 0)
        abs_sum = ops.abs(tensor_thresh.view(tensor.shape[0], -1)).sum(1)
        alpha = (1 / i_delta_count) * abs_sum
        return alpha.view(view_dims)

    def get_delta_ms(tensor):
        """MindSpore counterpart of ``get_delta_torch``."""
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        n = tensor[0].nelement()
        norm = tensor.norm(1, ndim - 1).view(tensor.shape[0], -1)
        norm_sum = norm.sum(1)
        delta = (0.75 / n) * norm_sum
        return delta.view(view_dims)

    # Both frameworks are fed the same NumPy data.
    x_np = np.random.randn(3, 3).astype(np.float32)
    x_torch = torch.from_numpy(x_np)
    x_torch.requires_grad = True
    x_ms = Parameter(Tensor.from_numpy(x_np))

    # Forward comparison.
    y_torch = ternarize_torch(x_torch)
    y_ms = ternarize_ms(x_ms)
    print(f"<value torch> {y_torch}\n"
          f"<value_ms> {y_ms}")
    if np.allclose(y_torch.detach().numpy(), y_ms.asnumpy(), atol=1e-3):
        print("[value] WITHIN TOLERANCE (from test_where_value_and_grad): "
              "value discrepancy between torch.where and "
              "mindspore.mint.where is less than 1e-3.")
    else:
        print("[value] BEYOND TOLERANCE (from test_where_value_and_grad): "
              "value discrepancy is beyond tolerance.")
    print('=' * 50)

    # Backward comparison: torch back-propagates the summed output, while
    # MindSpore differentiates ternarize_ms w.r.t. its first argument.
    y_sum = y_torch.sum()
    y_sum.backward()
    grad_ms = ms.grad(ternarize_ms, grad_position=0)(x_ms)
    print(f"<grad torch> {x_torch.grad}\n"
          f"<grad> {grad_ms}")
    if np.allclose(x_torch.grad.numpy(), grad_ms.asnumpy(), atol=1e-3):
        print("[grad] WITHIN TOLERANCE (from test_where_value_and_grad): "
              "grad discrepancy between torch.where and "
              "mindspore.mint.where is less than 1e-3.")
    else:
        print("[grad] BEYOND TOLERANCE (from test_where_value_and_grad): "
              "grad discrepancy is beyond tolerance.")
    print('=' * 50)
2.3 报错信息
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[4], line 1
----> 1 test_where_value_and_grad()
Cell In[2], line 70, in test_where_value_and_grad()
67 y_torch = ternarize_torch(x_torch)
68 y_ms = ternarize_ms(x_ms)
---> 70 print(f"<value torch> {y_torch}\n"
71 f"<value_ms> {y_ms}")
72 if np.allclose(y_torch.detach().numpy(), y_ms.asnumpy(), atol=1e-3):
73 print("[value] WITHIN TOLERANCE (from test_stack_value_and_grad): "
74 "value discrepancy between torch.stack "
75 "mindspore.mint.stack is less than 1e-3.")
File ~/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/_stub_tensor.py:48, in _stub_method.<locals>.fun(*arg, **kwargs)
46 def fun(*arg, **kwargs):
47 stub = arg[0]
---> 48 arg = (stub.stub_sync(),) + arg[1:]
49 return method(*arg, **kwargs)
File ~/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/_stub_tensor.py:159, in StubTensor.stub_sync(self)
157 """sync real tensor."""
158 if self.stub:
--> 159 val = self.stub.get_value()
160 self.tensor = Tensor(val, internal=True)
161 if hasattr(self, "member_cache"):
TypeError: The supported input and output data types for the current operator are: node is Default/BitwiseAnd-op1
InputDesc [0] support {int16,int32,int64,int8,uint16,uint32,uint64,uint8,}
InputDesc [1] support {int16,int32,int64,int8,uint16,uint32,uint64,uint8,}
OutputDesc [0] support {int16,int32,int64,int8,uint16,uint32,uint64,uint8,}
But current operator's input and output data types is:
InputDesc [0] is Bool
InputDesc [1] is Bool
OutputDesc [0] is Bool
----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc:586 HandleKernelSelectFailure
[WARNING] DEVICE(3671,ffff657fa1e0,python):2024-12-25-16:48:33.998.529 [mindspore/ccsrc/transform/acl_ir/op_api_convert.h:114] GetOpApiFunc] Dlsym aclSetAclOpExecutorRepeatable failed!
[WARNING] KERNEL(3671,ffff657fa1e0,python):2024-12-25-16:48:33.998.585 [mindspore/ccsrc/transform/acl_ir/op_api_cache.h:54] SetExecutorRepeatable] The aclSetAclOpExecutorRepeatable is unavailable, which results in aclnn cache miss.
[WARNING] DEVICE(3671,ffff62ff91e0,python):2024-12-25-16:48:34.028.472 [mindspore/ccsrc/transform/acl_ir/op_api_convert.h:114] GetOpApiFunc] Dlsym aclDestroyAclOpExecutor failed!
3 根因分析
当 `condition = (neg >= -delta) & (neg <= delta)` 的时候,因为有操作符 `&`,框架会把这个操作符自动转换成 `ops.bitwise_and` 算子,而该算子当前并不支持 bool 类型,因此会报 `TypeError: The supported input and output data types for the current operator are: node is Default/BitwiseAnd-op1`。
4 解决方案
用 `ops.logical_and` 算子替换操作符 `&`,该算子可以支持 bool 类型:
ops.logical_and((neg >= -delta), (neg <= delta))