1 系统环境
- 硬件环境(Ascend/GPU/CPU): Atlas 300i
- MindSpore版本: mindspore=2.2.1
- 执行模式(PyNative/ Graph): 不限
- Python版本: Python=3.9
- 操作系统平台: Linux
2 报错信息
2.1 问题描述
在模型推理时,预期输出张量形状应为[batch_size, 10](10分类问题),但实际获取到的输出形状为[batch_size, 10, 1, 1],导致后续后处理代码报错。
2.2 脚本信息
# 模型定义关键部分
class CustomNet(nn.Cell):
def __init__(self):
super(CustomNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.pool = nn.MaxPool2d(kernel_size=2)
self.fc = nn.Dense(64 * 54 * 54, 10) # 全连接层
def construct(self, x):
x = self.conv1(x)
x = self.pool(x)
x = x.view(x.shape[0], -1)
x = self.fc(x)
return x
# 推理代码
model = load_checkpoint("model.ckpt")
net = CustomNet()
load_param_into_net(net, model)
input_data = Tensor(np.random.rand(32, 3, 224, 224), dtype=mstype.float32)
output = net(input_data)
print("Output shape:", output.shape) # 期望(32,10),实际得到(32,10,1,1)
2.3 报错信息
ValueError: shapes (32,10) and (32,10,1,1) not aligned
3 根因分析
昇腾芯片与GPU在张量布局和计算规则上存在差异。从以上的情况看,应该是全连接层在昇腾硬件上的特殊处理。当全连接层输入未能完全展平时,昇腾后端可能会保持某些维度为1,而不是自动消除。在Ascend环境下,MindSpore的图编译过程对形状推导更加严格。如果x.view(x.shape[0], -1)这行代码中的展平操作没有完全按设计执行,就会保留额外的维度。
4 解决方案
可以直接修正输出的形状,消除掉多余的维度。
class CustomNet(nn.Cell):
def __init__(self):
super(CustomNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
self.pool = nn.MaxPool2d(kernel_size=2)
self.fc = nn.Dense(64 * 54 * 54, 10) # 全连接层
self.squeeze = ops.Squeeze() # 添加压缩操作
def construct(self, x):
x = self.conv1(x)
x = self.pool(x)
x = self.flatten(x)
x = self.fc(x)
# 确保输出形状正确
if len(x.shape) > 2:
x = self.squeeze(x)
return x
也可以在模型定义中进行设置,在数据进入全连接层前进行完全展平操作。
import mindspore.ops as ops
class CustomNet(nn.Cell):
def __init__(self):
super(CustomNet, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, pad_mode='same')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc = nn.Dense(64 * 111 * 111, 10) # 修正:224→112→111
self.reshape = ops.Reshape()
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)
x = self.pool(x)
# 使用Flatten替代view,更安全
x = self.flatten(x)
# 或者使用显式reshape
# x = self.reshape(x, (x.shape[0], -1))
x = self.fc(x)
return x