mindspore.ops.MaxPool3D with ceil_mode=True produces different results in MindSpore 1.8.1 and 1.9.0

1 System environment

Hardware environment (Ascend/GPU/CPU): Ascend/GPU/CPU
MindSpore version: 1.8.1, 1.9.0
Execution mode (PyNative/Graph): any
Python version: 3.9.13
Operating system: Windows 10

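2 Error information

2.1 Problem description

In MindSpore 1.9.0, mindspore.ops.MaxPool3D gives an incorrect result when ceil_mode=True: the output shape is identical to the one obtained with ceil_mode=False.

2.2 Error output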

1.8.1
mindspore version: 1.8.1
input shape: (4, 3, 6, 6, 5)
res shape: (4, 3, 3, 3, 2)
no ceil res shape: (4, 3, 2, 2, 2)

1.9.0
mindspore version: 1.9.0
input shape: (4, 3, 6, 6, 5)
res shape: (4, 3, 2, 2, 2)
no ceil res shape: (4, 3, 2, 2, 2)
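
For reference, the expected shapes can be worked out by hand from the formula used for "pad" mode in the inference code in section 3: each spatial size is rounded((in + pad_front + pad_back - kernel) / stride + 1), where rounded is ceil or floor depending on ceil_mode. The following is a minimal sketch of that arithmetic in plain Python. The helper name pool3d_out_shape and its argument layout are assumptions made for illustration, not part of the MindSpore API.

```python
import math

def pool3d_out_shape(in_shape, kernel, stride, pads, ceil_mode):
    """Expected output shape of a 3D pooling op in "pad" mode (NCDHW input).

    pads holds six values, two per spatial dimension, mirroring the
    six-element pad_list indexed in the inference code (hypothetical layout).
    """
    n, c, *spatial = in_shape
    rounding = math.ceil if ceil_mode else math.floor
    out = []
    for i, size in enumerate(spatial):
        total_pad = pads[2 * i] + pads[2 * i + 1]
        # True division keeps the fractional part, so the rounding mode matters.
        out.append(rounding((size + total_pad - kernel) / stride + 1))
    return (n, c, *out)

shape = (4, 3, 6, 6, 5)
no_pad = (0,) * 6
print(pool3d_out_shape(shape, 3, 2, no_pad, ceil_mode=True))   # (4, 3, 3, 3, 2)
print(pool3d_out_shape(shape, 3, 2, no_pad, ceil_mode=False))  # (4, 3, 2, 2, 2)
```

These values agree with the 1.8.1 output, which suggests that the 1.9.0 result for ceil_mode=True is the one that deviates.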

2.3 Reproduction script

import mindspore as ms
import numpy as np

print(f"mindspore version: {ms.__version__}")
a = ms.Tensor(np.random.randn(4, 3, 6, 6, 5), ms.float32)
print(f"input shape: {a.shape}")
maxpool = ms.ops.MaxPool3D(kernel_size=3, strides=2, pad_mode="pad", pad_list=0, ceil_mode=True)
res = maxpool(a)
print(f"res shape: {res.shape}")
maxpool = ms.ops.MaxPool3D(kernel_size=3, strides=2, pad_mode="pad", pad_list=0, ceil_mode=False)
res = maxpool(a)
print(f"no ceil res shape: {res.shape}")

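3 Root cause analysis

MindSpore 1.8 does not contain mindspore/core/ops/max_pool3d.cc.

Starting with 1.9, the Python-side infer logic was ported to C++. In Python, dividing two integers with / yields a floating-point result. The C++ port overlooked this detail and stored the result in an integer type, so the division became an integer division and the later ceil no longer had any effect.
The relevant code is as follows: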

    if (pad_mode == PadMode::VALID) {
      out_d = in_d == -1 ? -1 : MaxPool3DCeilDiv((in_d - (kernel_d - 1)), stride_d);
      out_h = in_h == -1 ? -1 : MaxPool3DCeilDiv((in_h - (kernel_h - 1)), stride_h);
      out_w = in_w == -1 ? -1 : MaxPool3DCeilDiv((in_w - (kernel_w - 1)), stride_w);
    } else if (pad_mode == PadMode::SAME) {
      out_d = in_d == -1 ? -1 : MaxPool3DCeilDiv(in_d, stride_d);
      out_h = in_h == -1 ? -1 : MaxPool3DCeilDiv(in_h, stride_h);
      out_w = in_w == -1 ? -1 : MaxPool3DCeilDiv(in_w, stride_w);
    } else {
      double out_d_tmp =
        in_d == -1
          ? -1
          : static_cast<double>(in_d + pad_list[kInputIndex0] + pad_list[kInputIndex1] - kernel_d) / stride_d + 1;
      double out_h_tmp =
        in_h == -1
          ? -1
          : static_cast<double>(in_h + pad_list[kInputIndex2] + pad_list[kInputIndex3] - kernel_h) / stride_h + 1;
      double out_w_tmp =
        in_w == -1
          ? -1
          : static_cast<double>(in_w + pad_list[kInputIndex4] + pad_list[kInputIndex5] - kernel_w) / stride_w + 1;

      if (ceil_mode) {
        out_d = DoubleToLong(std::ceil(out_d_tmp));
        out_h = DoubleToLong(std::ceil(out_h_tmp));
        out_w = DoubleToLong(std::ceil(out_w_tmp));
      } else {
        out_d = DoubleToLong(std::floor(out_d_tmp));
        out_h = DoubleToLong(std::floor(out_h_tmp));
        out_w = DoubleToLong(std::floor(out_w_tmp));
      }
    }

int64_t out_d = 0;  
int64_t out_h = 0;  
int64_t out_w = 0;

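out_d, out_h, and out_w are all declared as integers (int64_t). In the 1.9.0 code, the expression (in_d + pad_list[kInputIndex0] + pad_list[kInputIndex1] - kernel_d) / stride_d + 1 is therefore evaluated with integer division and its fractional part is already discarded, so the subsequent ceil and floor calls have no effect.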
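
The effect of truncating before rounding can be illustrated with the depth dimension from the reproduction script. This is only a sketch in Python: floor division (//) stands in for C++ integer division, which gives the same result for the non-negative operands used here, and the variable names are chosen for illustration.

```python
import math

# Depth dimension from the reproduction script: in=6, kernel=3, stride=2, no padding.
in_d, kernel_d, stride_d = 6, 3, 2
pad_front = pad_back = 0

numerator = in_d + pad_front + pad_back - kernel_d  # 3

# Floating-point division keeps the fraction, so ceil and floor differ.
exact = numerator / stride_d + 1        # 2.5
# Integer division drops the fraction before any rounding is applied.
truncated = numerator // stride_d + 1   # 2

print(math.ceil(exact), math.floor(exact))          # 3 2
print(math.ceil(truncated), math.floor(truncated))  # 2 2
```

Once the value has been truncated to 2, ceil and floor both return 2, which matches the (4, 3, 2, 2, 2) shape reported by 1.9.0 for both settings of ceil_mode.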

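4 Solution

In the C++ infer function, store the intermediate result as a floating-point value so that ceil takes effect and the behavior stays consistent with the Python-side logic.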

    if (pad_mode == PadMode::VALID) {
      out_d = in_d == -1 ? -1 : MaxPool3DCeilDiv((in_d - (kernel_d - 1)), stride_d);
      out_h = in_h == -1 ? -1 : MaxPool3DCeilDiv((in_h - (kernel_h - 1)), stride_h);
      out_w = in_w == -1 ? -1 : MaxPool3DCeilDiv((in_w - (kernel_w - 1)), stride_w);
    } else if (pad_mode == PadMode::SAME) {
      out_d = in_d == -1 ? -1 : MaxPool3DCeilDiv(in_d, stride_d);
      out_h = in_h == -1 ? -1 : MaxPool3DCeilDiv(in_h, stride_h);
      out_w = in_w == -1 ? -1 : MaxPool3DCeilDiv(in_w, stride_w);
    } else {
      // Cast to double before dividing so the fractional part survives until ceil/floor.
      double out_d_tmp =
        in_d == -1
          ? -1
          : static_cast<double>(in_d + pad_list[kInputIndex0] + pad_list[kInputIndex1] - kernel_d) / stride_d + 1;
      double out_h_tmp =
        in_h == -1
          ? -1
          : static_cast<double>(in_h + pad_list[kInputIndex2] + pad_list[kInputIndex3] - kernel_h) / stride_h + 1;
      double out_w_tmp =
        in_w == -1
          ? -1
          : static_cast<double>(in_w + pad_list[kInputIndex4] + pad_list[kInputIndex5] - kernel_w) / stride_w + 1;

      // Round only after the floating-point result is available, then convert back to int64_t.
      if (ceil_mode) {
        out_d = DoubleToLong(std::ceil(out_d_tmp));
        out_h = DoubleToLong(std::ceil(out_h_tmp));
        out_w = DoubleToLong(std::ceil(out_w_tmp));
      } else {
        out_d = DoubleToLong(std::floor(out_d_tmp));
        out_h = DoubleToLong(std::floor(out_h_tmp));
        out_w = DoubleToLong(std::floor(out_w_tmp));
      }
    }
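
One way to gain some confidence in the floating-point formulation is to compare it against an independent integer-only ceiling division over a range of small sizes. The sketch below does this in Python. Both helper functions are hypothetical and written only for this check; they are not taken from the MindSpore sources, and the tested ranges are arbitrary.

```python
import math

def out_size_float(size, pad_total, kernel, stride):
    """ceil_mode=True output size with a floating-point intermediate, as in the fix."""
    return math.ceil((size + pad_total - kernel) / stride + 1)

def out_size_int(size, pad_total, kernel, stride):
    """Same quantity with integer arithmetic only: ceil(n / s) == -(-n // s)."""
    n = size + pad_total - kernel
    return -(-n // stride) + 1

# Collect every parameter combination on which the two computations disagree.
mismatches = [
    (size, pad, kernel, stride)
    for size in range(1, 40)
    for pad in range(0, 4)
    for kernel in range(1, 6)
    for stride in range(1, 6)
    if size + pad >= kernel
    and out_size_float(size, pad, kernel, stride) != out_size_int(size, pad, kernel, stride)
]
print(len(mismatches))  # 0
```

An empty mismatch list on these ranges indicates that rounding the double intermediate gives the same sizes as exact integer ceiling division, at least for inputs of this scale.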