MindSpore模型推理时ValueError: shapes (32,10) and (32,10,1,1) not aligned分析和解决

1 系统环境

  • 硬件环境(Ascend/GPU/CPU): Atlas 300i
  • MindSpore版本: mindspore=2.2.1
  • 执行模式(PyNative/ Graph): 不限
  • Python版本: Python=3.9
  • 操作系统平台: Linux

2 报错信息

2.1 问题描述

在模型推理时,预期输出张量形状应为[batch_size, 10](10分类问题),但实际获取到的输出形状为[batch_size, 10, 1, 1],导致后续后处理代码报错。

2.2 脚本信息

# 模型定义关键部分
class CustomNet(nn.Cell):
    def __init__(self):
        super(CustomNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Dense(64 * 54 * 54, 10)  # 全连接层
        
    def construct(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
# 推理代码
model = load_checkpoint("model.ckpt")
net = CustomNet()
load_param_into_net(net, model)
input_data = Tensor(np.random.rand(32, 3, 224, 224), dtype=mstype.float32)
output = net(input_data)
print("Output shape:", output.shape)  # 期望(32,10),实际得到(32,10,1,1)

2.3 报错信息

ValueError: shapes (32,10) and (32,10,1,1) not aligned

3 根因分析

昇腾芯片与GPU在张量布局和计算规则上存在差异。从以上的情况看,应该是全连接层在昇腾硬件上的特殊处理。当全连接层输入未能完全展平时,昇腾后端可能会保持某些维度为1,而不是自动消除。在Ascend环境下,MindSpore的图编译过程对形状推导更加严格。如果x.view(x.shape[0], -1)这行代码中的展平操作没有完全按设计执行,就会保留额外的维度。

4 解决方案

可以直接修正输出的形状,消除掉多余的维度。

class CustomNet(nn.Cell):
    def __init__(self):
        super(CustomNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc = nn.Dense(64 * 54 * 54, 10)  # 全连接层
        self.squeeze = ops.Squeeze()  # 添加压缩操作
        
    def construct(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc(x)
        # 确保输出形状正确
        if len(x.shape) > 2:
            x = self.squeeze(x)
        return x

也可以在模型定义中进行设置,在数据进入全连接层前进行完全展平操作。

import mindspore.ops as ops

class CustomNet(nn.Cell):
    def __init__(self):
        super(CustomNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, pad_mode='same')
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc = nn.Dense(64 * 111 * 111, 10)  # 修正:224→112→111
        self.reshape = ops.Reshape()
        self.flatten = nn.Flatten()
        
    def construct(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        # 使用Flatten替代view,更安全
        x = self.flatten(x)
        # 或者使用显式reshape
        # x = self.reshape(x, (x.shape[0], -1))
        x = self.fc(x)
        return x