1 System Environment
Hardware environment (Ascend/GPU/CPU): Ascend
MindSpore version: mindspore=2.4.0
Execution mode (PyNative/Graph): Graph
Python version: Python=3.9.10
Operating system platform: Linux
2 Error Information
2.1 Problem Description
Using mint.masked_select in graph mode raises the error: Parse Lambda Function Fail. Node type must be Lambda, but got Call.
2.2 Error Message
================================================================================== FAILURES ==================================================================================
_____________________________________________________________________ test_masked_select_forward_back[0] _____________________________________________________________________
mode = 0
    @pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
    def test_masked_select_forward_back(mode):
        """Test forward and backward propagation and compare gradients."""
        ms.set_context(mode=mode)
        def forward_ms(x, mask):
            return mint.masked_select(x, mask).sum()
        def forward_torch(x, mask):
            return torch.masked_select(x, mask).sum()
        try:
            input_data = [[1.0, 6.0, 2.0, 4.0], [7.0, 3.0, 8.0, 2.0], [2.0, 9.0, 11.0, 5.0]]
            mask = [
                [True, False, True, False],
                [False, True, False, True],
                [True, False, True, False]
            ]
            ms_tensor, torch_tensor, ms_mask, torch_mask = create_tensors(input_data, ms.float32, torch.float32, mask=mask, requires_grad=True)
            grad_fn_ms = value_and_grad(lambda x: forward_ms(x, ms_mask))
            output_ms, gradient_ms = grad_fn_ms(ms_tensor)
            output_torch = forward_torch(torch_tensor, torch_mask)
            output_torch.backward()
            compare_results(output_ms, output_torch.detach())
            compare_results(gradient_ms, torch_tensor.grad)
        except Exception as e:
>           raise e
test_masked_select.py:228:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_masked_select.py:220: in test_masked_select_forward_back
output_ms, gradient_ms = grad_fn_ms(ms_tensor)
../anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/api.py:960: in staging_specialize
out = _MindsporeFunctionExecutor(func, hash_obj, dyn_args, process_obj, jit_config)(*args, **kwargs)
../anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/api.py:188: in wrapper
results = fn(*arg, **kwargs)
../anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/api.py:582: in __call__
raise err
../anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/api.py:579: in __call__
phase = self.compile(self.fn.__name__, *args_list, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <mindspore.common.api._MindsporeFunctionExecutor object at 0xffff74011f10>, method_name = 'after_grad'
args = (Tensor(shape=[3, 4], dtype=Float32, value=
[[ 1.00000000e+00, 6.00000000e+00, 2.00000000e+00, 4.00000000e+00],
[ ...00000e+00, 8.00000000e+00, 2.00000000e+00],
[ 2.00000000e+00, 9.00000000e+00, 1.10000000e+01, 5.00000000e+00]]),)
kwargs = {}
compile_args = (Tensor(shape=[3, 4], dtype=Float32, value=
[[ 1.00000000e+00, 6.00000000e+00, 2.00000000e+00, 4.00000000e+00],
[ ...00000e+00, 8.00000000e+00, 2.00000000e+00],
[ 2.00000000e+00, 9.00000000e+00, 1.10000000e+01, 5.00000000e+00]]),)
key_id = '1876509713526961735371593348192768'
generate_name = 'mindspore.ops.composite.base.after_grad./home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/ops/composite/base.py.599.1735371593348192768'
echo_function_name = 'function "after_grad" at the file "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/ops/composite/base.py", line 599'
full_function_name = 'mindspore.ops.composite.base.after_grad./home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/ops/composite/base.py.599'
create_time = '1735371593348192768', key = 0, parameter_ids = ''
phase = 'mindspore.ops.composite.base.after_grad./home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/ops/composite/base.py.599.1735371593348192768.0'
jit_config_dict = {'debug_level': 'RELEASE', 'exc_mode': 'auto', 'infer_boost': 'off', 'jit_level': '', ...}
    def compile(self, method_name, *args, **kwargs):
        """Returns pipeline for the given args."""
        # Check whether hook function registered on Cell object.
        if self.obj and hasattr(self.obj, "_hook_fn_registered"):
            if self.obj._hook_fn_registered():
                logger.warning(f"For 'Cell', it's not support hook function when using 'jit' decorator. "
                               f"If you want to use hook function, please use context.set_context to set "
                               f"pynative mode and remove 'jit' decorator.")
        # Chose dynamic shape tensors or actual input tensors as compile args.
        compile_args = self._generate_compile_args(args)
        key_id = self._get_key_id()
        compile_args = get_auto_dynamic_shape_args_with_check_input_signature(compile_args, key_id,
                                                                              self.input_signature)
        # Restore the mutable attr for every arg.
        compile_args = _restore_mutable_attr(args, compile_args)
        self._compile_args = compile_args
        generate_name, echo_function_name = self._get_generate_name()
        # The full Function name
        full_function_name = generate_name
        create_time = ''
        # Add key with obj
        if self.obj is not None:
            if self.obj.__module__ != self.fn.__module__:
                logger.info(
                    f'The module of `self.obj`: `{self.obj.__module__}` is not same with the module of `self.fn`: '
                    f'`{self.fn.__module__}`')
            self.obj.__parse_method__ = method_name
            if isinstance(self.obj, ms.nn.Cell):
                generate_name = generate_name + '.' + str(self.obj.create_time)
                create_time = str(self.obj.create_time)
            else:
                generate_name = generate_name + '.' + str(self._create_time)
                create_time = str(self._create_time)
            generate_name = generate_name + '.' + str(id(self.obj))
            full_function_name = generate_name
        else:
            # Different instance of same class may use same memory(means same obj_id) at diff times.
            # To avoid unexpected phase matched, add create_time to generate_name.
            generate_name = generate_name + '.' + str(self._create_time)
            create_time = str(self._create_time)
        self.enable_tuple_broaden = False
        if hasattr(self.obj, "enable_tuple_broaden"):
            self.enable_tuple_broaden = self.obj.enable_tuple_broaden
        self._graph_executor.set_enable_tuple_broaden(self.enable_tuple_broaden)
        key = self._graph_executor.generate_arguments_key(self.fn, compile_args, kwargs, self.enable_tuple_broaden)
        parameter_ids = _get_parameter_ids(args, kwargs)
        if parameter_ids != "":
            key = str(key) + '.' + parameter_ids
        phase = generate_name + '.' + str(key)
        update_auto_dynamic_shape_phase_with_check_input_signature(compile_args, key_id, phase, self.input_signature)
        if phase in ms_compile_cache:
            # Release resource should be released when CompileInner won't be executed, such as cur_convert_input_
            # generated in generate_arguments_key.
            self._graph_executor.clear_compile_arguments_resource()
            return phase
        _check_recompile(self.obj, compile_args, kwargs, full_function_name, create_time, echo_function_name)
        # If enable compile cache, get the dependency files list and set to graph executor.
        self._set_compile_cache_dep_files()
        if self.jit_config_dict:
            self._graph_executor.set_jit_config(self.jit_config_dict)
        else:
            jit_config_dict = JitConfig().jit_config_dict
            self._graph_executor.set_jit_config(jit_config_dict)
        if self.obj is None:
            # Set an attribute to fn as an identifier.
            if isinstance(self.fn, types.MethodType):
                setattr(self.fn.__func__, "__jit_function__", True)
            else:
                setattr(self.fn, "__jit_function__", True)
>       is_compile = self._graph_executor.compile(self.fn, compile_args, kwargs, phase, True)
E TypeError: Parse Lambda Function Fail. Node type must be Lambda, but got Call. Please check lambda expression to make sure it is defined on a separate line.
E For example, the code 'func = nn.ReLU() if y < 1 else lambda x: x + 1' rewritten as
E 'if y < 1:
E func = nn.ReLU()
E else:
E func = lambda x: x + 1
E 'will solve the problem.
E
E ----------------------------------------------------
E - Framework Unexpected Exception Raised:
E ----------------------------------------------------
E This exception is caused by framework's unexpected error. Please create an issue at https://gitee.com/mindspore/mindspore/issues to get help.
E
E ----------------------------------------------------
E - C++ Call Stack: (For framework developers)
E ----------------------------------------------------
E mindspore/ccsrc/pipeline/jit/ps/parse/parse.cc:570 ParseFuncGraph
E
E ----------------------------------------------------
E - The Traceback of Net Construct Code:
E ----------------------------------------------------
E # 0 In file /home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/ops/composite/base.py:601, 33~35
E return grad_(fn, weights, grad_position)(*args)
E ^~
E (See file '/home/ma-user/work/rank_0/om/analyze_fail.ir' for more details. Get instructions about `analyze_fail.ir` at https://www.mindspore.cn/search?inputValue=analyze_fail.ir)
../anaconda3/envs/MindSpore/lib/python3.9/site-packages/mindspore/common/api.py:674: TypeError
2.3 Script
import pytest
import numpy as np
import mindspore as ms
from mindspore import Tensor, value_and_grad, mint
import torch

def create_tensors(input_data, ms_dtype, torch_dtype, mask=None, requires_grad=False):
    ms_tensor = Tensor(input_data, ms_dtype)
    torch_tensor = torch.tensor(input_data, dtype=torch_dtype, requires_grad=requires_grad)
    if mask is not None:
        mask = np.array(mask, dtype=bool)
        ms_mask = Tensor(mask, ms.bool_)
        torch_mask = torch.tensor(mask, dtype=torch.bool)
    else:
        ms_mask = None
        torch_mask = None
    return ms_tensor, torch_tensor, ms_mask, torch_mask

def perform_masked_select(ms_tensor, torch_tensor, ms_mask, torch_mask):
    ms_result = mint.masked_select(ms_tensor, ms_mask)
    torch_result = torch.masked_select(torch_tensor, torch_mask)
    return ms_result, torch_result

def compare_results(ms_result, torch_result, atol=1e-3):
    assert np.allclose(ms_result.asnumpy(), torch_result.numpy(), atol=atol)

@pytest.mark.parametrize('mode', [ms.GRAPH_MODE, ms.PYNATIVE_MODE])
def test_masked_select_forward_back(mode):
    """Test forward and backward propagation and compare gradients."""
    ms.set_context(mode=mode)
    def forward_ms(x, mask):
        return mint.masked_select(x, mask).sum()
    def forward_torch(x, mask):
        return torch.masked_select(x, mask).sum()
    try:
        input_data = [[1.0, 6.0, 2.0, 4.0], [7.0, 3.0, 8.0, 2.0], [2.0, 9.0, 11.0, 5.0]]
        mask = [
            [True, False, True, False],
            [False, True, False, True],
            [True, False, True, False]
        ]
        ms_tensor, torch_tensor, ms_mask, torch_mask = create_tensors(input_data, ms.float32, torch.float32, mask=mask, requires_grad=True)
        grad_fn_ms = value_and_grad(lambda x: forward_ms(x, ms_mask))
        output_ms, gradient_ms = grad_fn_ms(ms_tensor)
        output_torch = forward_torch(torch_tensor, torch_mask)
        output_torch.backward()
        compare_results(output_ms, output_torch.detach())
        compare_results(gradient_ms, torch_tensor.grad)
    except Exception as e:
        raise e
3 Root Cause Analysis
mindspore.value_and_grad(fn, grad_position=0, weights=None, has_aux=False, return_ids=False)
- fn (Union[Cell, Function]) - The function or network to be differentiated.
In grad_fn_ms = value_and_grad(lambda x: forward_ms(x, ms_mask)), an anonymous function is passed as fn. Lambdas are generally intended as throwaway functions and are not appropriate here: because the lambda is created inline inside the call, MindSpore's graph-mode parser encounters a Call node where it expects a Lambda definition, which is exactly what the error reports.
lambda x: forward_ms(x, ms_mask)
is equivalent to defining a new function with a single parameter x. Instead, forward_ms can be passed directly as fn; the gradient function then just needs to be called with ms_mask as an additional input, as illustrated in the sketch below.
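For illustration, here is a minimal sketch contrasting the failing pattern with two alternatives. It reuses forward_ms, ms_tensor, and ms_mask from the script above; the separate-line lambda variant follows the hint printed in the error message itself, though whether it compiles may depend on the MindSpore version, so the named-function form is the more robust fix.

from mindspore import value_and_grad

# Failing pattern: the lambda is created inline inside the call, so the
# graph-mode parser sees a Call node where it expects a Lambda definition.
# grad_fn_bad = value_and_grad(lambda x: forward_ms(x, ms_mask))  # TypeError in GRAPH_MODE

# Alternative 1 (recommended): pass the named function directly and supply
# the mask at call time; grad_position=0 keeps the gradient w.r.t. x only.
grad_fn = value_and_grad(forward_ms)
output, gradient = grad_fn(ms_tensor, ms_mask)

# Alternative 2 (per the error message's own hint): assign the lambda on a
# separate line before passing it in, instead of writing it inline.
fn = lambda x: forward_ms(x, ms_mask)
grad_fn2 = value_and_grad(fn)
output2, gradient2 = grad_fn2(ms_tensor)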
4 Solution
Change

grad_fn_ms = value_and_grad(lambda x: forward_ms(x, ms_mask))
output_ms, gradient_ms = grad_fn_ms(ms_tensor)

to

grad_fn_ms = value_and_grad(forward_ms)
output_ms, gradient_ms = grad_fn_ms(ms_tensor, ms_mask)

Since grad_position defaults to 0, the gradient is still computed only with respect to ms_tensor, so the behavior matches the original intent. Because the reporter's script has other dependencies, the following standalone code can be used to verify the fix directly.
import mindspore as ms
import numpy as np
from mindspore import ops, mint, value_and_grad, Tensor
import torch

ms.set_context(mode=ms.GRAPH_MODE)

def forward_ms(x, mask):
    return mint.masked_select(x, mask).sum()

def forward_torch(x, mask):
    return torch.masked_select(x, mask).sum()

try:
    input_data = [[1.0, 6.0, 2.0, 4.0], [7.0, 3.0, 8.0, 2.0], [2.0, 9.0, 11.0, 5.0]]
    mask = [
        [True, False, True, False],
        [False, True, False, True],
        [True, False, True, False]
    ]
    ms_tensor = Tensor(input_data, dtype=ms.float32)
    ms_mask = Tensor(mask, dtype=ms.bool_)
    grad_fn_ms = value_and_grad(forward_ms)
    output_ms, gradient_ms = grad_fn_ms(ms_tensor, ms_mask)
    print(output_ms)
    print(gradient_ms)
    torch_tensor = torch.from_numpy(np.array(input_data, dtype=np.float32))
    torch_tensor.requires_grad = True
    torch_mask = torch.from_numpy(np.array(mask, dtype=np.bool_))
    output_torch = forward_torch(torch_tensor, torch_mask)
    output_torch.backward()
    print(output_torch.detach())
    print(torch_tensor.grad)
except Exception as e:
    raise e
Output
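Run in GRAPH_MODE, the script above should print 21.0 for both frameworks (the masked elements 1 + 2 + 3 + 2 + 2 + 11 sum to 21) and matching gradients equal to the mask cast to float32, i.e. [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0]]; the exact tensor formatting differs between MindSpore and PyTorch.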
