1 系统环境
硬件环境(Ascend/GPU/CPU): CPU
MindSpore版本: 2.2.10
执行模式(PyNative/ Graph): 不限
2 报错信息
2.1 问题描述
将一个原先使用torch网络的模型迁移到mindspore环境下训练,出现了内存过度增长的问题。迁移到mindspore下训练时,内存占用会以每秒约20M的速度迅速上升,直到主机内存溢出。
2.2 脚本信息
MindSpore脚本
# MindSpore training loop as reported (PPO-style actor/critic update).
# NOTE(review): mindspore.value_and_grad is called INSIDE the mini-batch loop,
# so a new grad function (with its backward graph) is constructed on every
# iteration — presumably the source of the reported ~20 MB/s memory growth;
# confirm by hoisting both constructions out of the loops.
for _ in range(self.K_epochs):
for index in CustomSampler(self.batch_size, self.mini_batch_size):
# Update actor: grad fn w.r.t. actor parameters, rebuilt each iteration.
grad_fn1 = mindspore.value_and_grad(forward_fn1, None, self.optimizer_actor.parameters)
actor_loss, grads1 = grad_fn1(s[index], a[index], a_logprob[index], adv[index])
if self.use_grad_clip: # Trick 7: Gradient clip
grads1 = ops.clip_by_norm(grads1, 0.5)
self.optimizer_actor(grads1)
# Update critic
# Grad fn w.r.t. critic parameters — also rebuilt every iteration.
grad_fn2 = mindspore.value_and_grad(forward_fn2, None, self.optimizer_critic.parameters)
critic_loss, grads2 = grad_fn2(s[index], a[index], v_target[index])
self.critic_loss.append(critic_loss.item())
if self.use_grad_clip: # Trick 7: Gradient clip
grads2 = ops.clip_by_norm(grads2, 0.5)
self.optimizer_critic(grads2)
Torch脚本:
# PyTorch reference implementation (works without memory growth): the autograd
# machinery is invoked via loss.backward() each step; no graph-building helper
# is re-created per iteration, unlike the MindSpore port above.
for _ in range(self.K_epochs):
for index in BatchSampler(SubsetRandomSampler(range(self.batch_size)), self.mini_batch_size, False):
dist_now = self.actor.get_dist(s[index])
dist_entropy = dist_now.entropy().sum(1, keepdim=True)
a_logprob_now = dist_now.log_prob(a[index])
# a/b=exp(log(a)-log(b)) In multi-dimensional continuous action space,we need to sum up the log_prob
ratios = torch.exp(a_logprob_now.sum(1, keepdim=True) - a_logprob[index].sum(1, keepdim=True)) # shape(mini_batch_size X 1)
surr1 = ratios * adv[index] # Only calculate the gradient of 'a_logprob_now' in ratios
surr2 = torch.clamp(ratios, 1 - self.epsilon, 1 + self.epsilon) * adv[index]
actor_loss = -torch.min(surr1, surr2) - self.entropy_coef * dist_entropy # Trick 5: policy entropy
# Update actor
self.optimizer_actor.zero_grad()
actor_loss.mean().backward()
if self.use_grad_clip: # Trick 7: Gradient clip
torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 0.5)
self.optimizer_actor.step()
v_s = self.critic(s[index], a[index])
critic_loss = F.mse_loss(v_target[index], v_s)
self.critic_loss.append(critic_loss.item())
# Update critic
self.optimizer_critic.zero_grad()
critic_loss.backward()
if self.use_grad_clip: # Trick 7: Gradient clip
torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 0.5)
self.optimizer_critic.step()
3 根因分析
上述的 MindSpore 代码在每次 mini-batch 迭代中都重新调用 mindspore.value_and_grad 构建两个 grad_fn。每次调用都会重新构建求导计算图,不断占用新的内存和计算资源,这正是内存以约 20M/s 速度持续增长的根因。应将 grad_fn 的构建移到循环之外只执行一次,也可以进一步将两个 grad_fn 合并计算。
4 解决方案
上述的 MindSpore 代码中有两个通过 mindspore.value_and_grad 构建的 grad_fn,并且传入的参数相似,可以进行合并;更关键的是必须把 grad_fn 的构建移到训练循环之外,只构建一次,循环内部仅调用它计算损失和梯度,避免每次迭代重复构建计算图。
若要合并两个前向计算,需要先定义一个新的前向函数,在其中把两个损失结果相加后返回,再将该函数传入同一个 value_and_grad;注意不能直接写 forward_fn1+forward_fn2(函数对象不支持相加)。修改后的代码示例如下:
# Fixed version: build the grad functions ONCE, outside the training loops.
# Re-creating value_and_grad on every mini-batch builds a fresh backward graph
# each iteration, which is what makes memory grow without bound.  Note that the
# originally proposed `forward_fn1 + forward_fn2` is invalid Python (function
# objects do not support `+`), and that version also dropped the critic update.
grad_fn1 = mindspore.value_and_grad(forward_fn1, None, self.optimizer_actor.parameters)
grad_fn2 = mindspore.value_and_grad(forward_fn2, None, self.optimizer_critic.parameters)
for _ in range(self.K_epochs):
    for index in CustomSampler(self.batch_size, self.mini_batch_size):
        # Update actor
        actor_loss, grads1 = grad_fn1(s[index], a[index], a_logprob[index], adv[index])
        if self.use_grad_clip:  # Trick 7: Gradient clip
            grads1 = ops.clip_by_norm(grads1, 0.5)
        self.optimizer_actor(grads1)
        # Update critic (kept, matching the PyTorch reference)
        critic_loss, grads2 = grad_fn2(s[index], a[index], v_target[index])
        self.critic_loss.append(critic_loss.item())
        if self.use_grad_clip:  # Trick 7: Gradient clip
            grads2 = ops.clip_by_norm(grads2, 0.5)
        self.optimizer_critic(grads2)