Using mindspore.mint.where() raises the error "The supported input and output data types for the current operator are: node is Default/BitwiseAnd"

1 System Environment

Hardware environment (Ascend/GPU/CPU): Ascend
MindSpore version: mindspore=2.4.1
Execution mode (PyNative/Graph): either
Python version: Python=3.9
OS platform: Ubuntu 22.04

2 Error Information

2.1 Problem Description

The condition argument of torch.where() and the condition argument of mindspore.mint.where() must be written differently when the condition combines multiple logical expressions.
For a condition built from two logical expressions, such as:

(neg >= -delta) & (neg <= delta)

the ops.logical_and interface must be used instead.
For example, condition should be written as:

condition = ops.logical_and((neg >= -delta), (neg <= delta))

Writing condition with the & operator, as below, raises an error:

condition = (neg >= -delta) & (neg <= delta)
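
The failure can be reproduced in isolation. The following is a minimal sketch with made-up values (the names neg and delta mirror the script in section 2.2), assuming an Ascend backend and MindSpore 2.4.1:

import numpy as np
import mindspore as ms
from mindspore import Tensor, mint, ops

ms.set_context(device_target="Ascend")

neg = Tensor(np.array([-0.6, 0.1, 0.7], np.float32))
delta = Tensor(0.3, ms.float32)

# Works: logical_and accepts bool tensors as inputs
condition = ops.logical_and(neg >= -delta, neg <= delta)
print(mint.where(condition, 0, neg))

# Raises the TypeError shown in section 2.3, because "&" selects the
# BitwiseAnd kernel, which rejects bool inputs on Ascend:
# condition = (neg >= -delta) & (neg <= delta)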

2.2 Script

import numpy as np
import torch
import mindspore as ms
from mindspore import Parameter, Tensor, mint, ops


def test_where_value_and_grad():
    def ternarize_torch(tensor):
        delta = get_delta_torch(tensor)
        alpha = get_alpha_torch(tensor, delta)
        pos = torch.where(tensor > delta, 1, tensor)
        neg = torch.where(pos < -delta, -1, pos)
        ternary = torch.where((neg >= -delta) & (neg <= delta), 0, neg)
        return ternary * alpha

    def get_alpha_torch(tensor, delta):
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        i_delta = (torch.abs(tensor) > delta)
        i_delta_count = i_delta.view(i_delta.shape[0], -1).sum(1)
        tensor_thresh = torch.where(i_delta, tensor, 0)
        alpha = (1 / i_delta_count) * (torch.abs(tensor_thresh.view(tensor.shape[0], -1)).sum(1))
        alpha = alpha.view(view_dims)
        return alpha

    def get_delta_torch(tensor):
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        n = tensor[0].nelement()
        norm = tensor.norm(1, ndim - 1).view(tensor.shape[0], -1)
        norm_sum = norm.sum(1)
        delta = (0.75 / n) * norm_sum
        return delta.view(view_dims)

    def ternarize_ms(tensor):
        delta = get_delta_ms(tensor)
        alpha = get_alpha_ms(tensor, delta)
        pos = mint.where(tensor > delta, 1, tensor)
        neg = mint.where(pos < -delta, -1, pos)
        # ops.logical_and supports bool inputs; the "&" operator is lowered
        # to bitwise_and, which does not, and raises the TypeError below.
        condition = ops.logical_and((neg >= -delta), (neg <= delta))
        # condition = (neg >= -delta) & (neg <= delta)  # raises TypeError
        ternary = mint.where(condition, 0, neg)
        return ternary * alpha

    def get_alpha_ms(tensor, delta):
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        i_delta = (ops.abs(tensor) > delta)
        i_delta_count = i_delta.view(i_delta.shape[0], -1).sum(1)
        tensor_thresh = mint.where(i_delta, tensor, 0)
        alpha = (1 / i_delta_count) * (ops.abs(tensor_thresh.view(tensor.shape[0], -1)).sum(1))
        alpha = alpha.view(view_dims)
        return alpha

    def get_delta_ms(tensor):
        ndim = len(tensor.shape)
        view_dims = (-1,) + (ndim - 1) * (1,)
        n = tensor[0].nelement()
        norm = tensor.norm(1, ndim - 1).view(tensor.shape[0], -1)
        norm_sum = norm.sum(1)
        delta = (0.75 / n) * norm_sum
        return delta.view(view_dims)

    x_np = np.random.randn(3, 3).astype(np.float32)
    x_torch = torch.from_numpy(x_np)
    x_torch.requires_grad = True
    x_ms = Parameter(Tensor.from_numpy(x_np))
    y_torch = ternarize_torch(x_torch)
    y_ms = ternarize_ms(x_ms)

    print(f"<value torch> {y_torch}\n"
          f"<value_ms> {y_ms}")
    if np.allclose(y_torch.detach().numpy(), y_ms.asnumpy(), atol=1e-3):
        print("[value] WITHIN TOLERANCE (from test_where_value_and_grad): "
              "value discrepancy between torch.where and "
              "mindspore.mint.where is less than 1e-3.")
    else:
        print("[value] BEYOND TOLERANCE (from test_where_value_and_grad): "
              "value discrepancy is beyond tolerance.")

    print('=' * 50)
    y_sum = y_torch.sum()
    y_sum.backward()
    grad_ms = ms.grad(ternarize_ms, grad_position=0)(x_ms)
    print(f"<grad torch> {x_torch.grad}\n"
          f"<grad> {grad_ms}")
    if np.allclose(x_torch.grad.numpy(), grad_ms.asnumpy(), atol=1e-3):
        print("[grad] WITHIN TOLERANCE (from test_where_value_and_grad): "
              "grad discrepancy between torch.where and "
              "mindspore.mint.where is less than 1e-3.")
    else:
        print("[grad] BEYOND TOLERANCE (from test_where_value_and_grad): "
              "grad discrepancy is beyond tolerance.")
    print('=' * 50)

2.3 Error Message

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[4], line 1
----> 1 test_where_value_and_grad()
Cell In[2], line 70, in test_where_value_and_grad()
     67 y_torch = ternarize_torch(x_torch)
     68 y_ms = ternarize_ms(x_ms)
---> 70 print(f"<value torch> {y_torch}\n"
     71       f"<value_ms> {y_ms}")
     72 if np.allclose(y_torch.detach().numpy(), y_ms.asnumpy(), atol=1e-3):
     73     print("[value] WITHIN TOLERANCE (from test_stack_value_and_grad): "
     74           "value discrepancy between torch.stack "
     75           "mindspore.mint.stack is less than 1e-3.")
File ~/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/_stub_tensor.py:48, in _stub_method.<locals>.fun(*arg, **kwargs)
     46 def fun(*arg, **kwargs):
     47     stub = arg[0]
---> 48     arg = (stub.stub_sync(),) + arg[1:]
     49     return method(*arg, **kwargs)
File ~/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/_stub_tensor.py:159, in StubTensor.stub_sync(self)
    157 """sync real tensor."""
    158 if self.stub:
--> 159     val = self.stub.get_value()
    160     self.tensor = Tensor(val, internal=True)
    161     if hasattr(self, "member_cache"):
TypeError: The supported input and output data types for the current operator are: node is Default/BitwiseAnd-op1
InputDesc [0] support {int16,int32,int64,int8,uint16,uint32,uint64,uint8,}
InputDesc [1] support {int16,int32,int64,int8,uint16,uint32,uint64,uint8,}
OutputDesc [0] support {int16,int32,int64,int8,uint16,uint32,uint64,uint8,}
But current operator's input and output data types is:
InputDesc [0] is Bool
InputDesc [1] is Bool
OutputDesc [0] is Bool
----------------------------------------------------
- C++ Call Stack: (For framework developers)
----------------------------------------------------
mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc:586 HandleKernelSelectFailure
[WARNING] DEVICE(3671,ffff657fa1e0,python):2024-12-25-16:48:33.998.529 [mindspore/ccsrc/transform/acl_ir/op_api_convert.h:114] GetOpApiFunc] Dlsym aclSetAclOpExecutorRepeatable failed!
[WARNING] KERNEL(3671,ffff657fa1e0,python):2024-12-25-16:48:33.998.585 [mindspore/ccsrc/transform/acl_ir/op_api_cache.h:54] SetExecutorRepeatable] The aclSetAclOpExecutorRepeatable is unavailable, which results in aclnn cache miss.
[WARNING] DEVICE(3671,ffff62ff91e0,python):2024-12-25-16:48:34.028.472 [mindspore/ccsrc/transform/acl_ir/op_api_convert.h:114] GetOpApiFunc] Dlsym aclDestroyAclOpExecutor failed!

3 Root Cause Analysis

When condition = (neg >= -delta) & (neg <= delta) is written, the & operator is automatically lowered by the framework to the ops.bitwise_and operator (the Default/BitwiseAnd node in the traceback). This operator currently does not support the bool type, while both comparison results are bool tensors, hence the TypeError: The supported input and output data types for the current operator are: node is Default/BitwiseAnd-op1.
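
The type constraint can be observed by calling the two operators directly on bool tensors. This is a small check, assuming the same Ascend environment as in section 1:

import numpy as np
from mindspore import Tensor, ops

a = Tensor(np.array([True, False, True]))
b = Tensor(np.array([True, True, False]))

# Supported: logical_and is defined for bool inputs
print(ops.logical_and(a, b))  # [ True False False]

# Both of the following select the BitwiseAnd kernel, whose inputs and
# output must be integer types, so bool tensors trigger the TypeError:
# print(ops.bitwise_and(a, b))
# print(a & b)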

4 Solution

Replace the & operator with the ops.logical_and operator, which supports the bool type:

condition = ops.logical_and((neg >= -delta), (neg <= delta))
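
Applied to ternarize_ms from section 2.2, the fix is a one-line change:

def ternarize_ms(tensor):
    delta = get_delta_ms(tensor)
    alpha = get_alpha_ms(tensor, delta)
    pos = mint.where(tensor > delta, 1, tensor)
    neg = mint.where(pos < -delta, -1, pos)
    # logical_and accepts bool tensors, unlike the BitwiseAnd kernel
    # selected by "&" on Ascend:
    condition = ops.logical_and(neg >= -delta, neg <= delta)
    ternary = mint.where(condition, 0, neg)
    return ternary * alpha

If your build of mindspore.mint also exposes logical_and, then mint.logical_and(neg >= -delta, neg <= delta) behaves the same way and keeps the whole expression within the mint namespace.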