A scene graph is a graph structure in which nodes represent the entities (objects) in an image and edges represent the relationships between those entities. Scene Graph Generation (SGG) is the task of producing the scene graph that corresponds to a given image, which requires the model not only to recognize the entities in the image but also to detect the semantic relationships between them. SGG goes beyond ordinary object detection, is closely related to visual relationship detection, and has shown promise in multimodal tasks.
Most mainstream SGG methods follow a two-stage pipeline: a detector first generates entity proposals, and a neural network then classifies the relationships between the proposal pairs. This works well in terms of accuracy, but it suffers from a large parameter count and high computational complexity, since scoring all entity pairs scales as O(n^2).
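For intuition, the quadratic growth of relation candidates can be reproduced in a couple of lines of Python (the numbers are purely illustrative):

```python
# Purely illustrative: a two-stage pipeline has to score every ordered pair of
# detected entities, so the number of relation candidates grows as O(n^2),
# while a fixed budget of triplet queries stays constant regardless of n.
n_entities = 50
candidate_pairs = [(i, j) for i in range(n_entities) for j in range(n_entities) if i != j]
print(len(candidate_pairs))  # 2450 relation candidates for only 50 detections
```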
RelTR proposes a one-stage, end-to-end SGG method. Architecturally, RelTR is built on an encoder-decoder design: the encoder extracts the visual feature representation of the image, and the decoder uses learned query embeddings together with several attention mechanisms to directly predict a fixed-size set of relationship triplets.
01
Key Contributions
Compared with previous methods, RelTR's main contributions and innovations are as follows:
1. Unlike previous two-stage SGG methods, RelTR generates a sparse scene graph directly from visual appearance, with fewer parameters and lower computational complexity. As illustrated in the figure below, RelTR only predicts a fixed number of relationship triplets instead of classifying the relationships between all entity pairs.
2. The paper designs a set prediction loss that matches the predicted triplets to the ground-truth annotations, complemented by an IoU-based assignment strategy, which makes it possible for the model to predict triplets directly. RelTR's loss is discussed in more detail later in this article.

3. RelTR's triplet decoder leverages the entity detection results produced by the entity decoder to further improve the localization and classification of subjects and objects.
02
RelTR
In this section, we describe RelTR's design in more detail.
The figure below shows RelTR's overall architecture. RelTR follows an encoder-decoder design with three core components: a feature encoder that extracts contextual visual features, an entity decoder based on the DETR framework that captures entity representations, and a triplet decoder. The triplet decoder is where RelTR's main innovation lies, so we describe it next.
Triplet Decoder
1. Subject and Object Queries
RelTR introduces a fixed number of subject queries and object queries. These queries are learned embeddings, and each subject query is paired one-to-one with an object query. RelTR represents the sparse scene graph as a set of <subject-object-relation> triplets. In addition, trainable encodings E_o and E_t are introduced during training.
2. Multiple Attention Mechanisms
RelTR designs three attention mechanisms inside the triplet decoder: coupled self-attention (CSA), decoupled visual attention (DVA), and decoupled entity attention (DEA). Writing Q_s for the subject query representations, Q_o for the object query representations, and E_o and E_t for the learnable subject/object encodings, the three operations work as follows:
1) CSA: the subject and object queries, together with their encodings, are coupled and attend to one another, modeling the dependencies between the two branches of each triplet.
2) DVA: the decoupled subject and object branches each cross-attend to the visual context produced by the feature encoder.
3) DEA: the subject and object branches cross-attend to the entity representations produced by the entity decoder, reusing the entity detection results mentioned above.
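To make the three operations more concrete, here is a minimal single-head sketch. It only illustrates what each attention step attends to; the helper name scaled_dot_attention, the single-head formulation, and the way the encodings are added are simplifying assumptions, not the repository code.

```python
# Minimal single-head sketch of the triplet decoder's attention steps (illustration only).
import mindspore.ops as ops

def scaled_dot_attention(q, k, v):
    """softmax(Q K^T / sqrt(d)) V for batched inputs of shape (batch, tokens, d)."""
    d = q.shape[-1]
    scores = ops.BatchMatMul(transpose_b=True)(q, k) / (d ** 0.5)
    weights = ops.Softmax(axis=-1)(scores)
    return ops.BatchMatMul()(weights, v)

def triplet_decoder_step(q_s, q_o, e_s, e_o, visual_feat, entity_feat):
    """One conceptual pass of CSA -> DVA -> DEA over coupled subject/object queries.

    q_s, q_o: subject/object queries, (batch, num_triplets, d)
    e_s, e_o: learnable subject/object encodings, broadcastable to the queries
    visual_feat: encoder output tokens, entity_feat: entity decoder outputs
    """
    # CSA: subject and object queries (plus their encodings) are coupled and
    # attend to one another, so each triplet slot can see its counterpart.
    coupled = ops.Concat(axis=1)((q_s + e_s, q_o + e_o))
    coupled = scaled_dot_attention(coupled, coupled, coupled)
    q_s, q_o = ops.Split(axis=1, output_num=2)(coupled)

    # DVA: the decoupled subject/object branches each cross-attend to the
    # visual context produced by the feature encoder.
    q_s = scaled_dot_attention(q_s, visual_feat, visual_feat)
    q_o = scaled_dot_attention(q_o, visual_feat, visual_feat)

    # DEA: both branches additionally cross-attend to the entity decoder's
    # outputs, reusing the entity detection results.
    q_s = scaled_dot_attention(q_s, entity_feat, entity_feat)
    q_o = scaled_dot_attention(q_o, entity_feat, entity_feat)
    return q_s, q_o
```

The key design choice is that CSA lets a subject query and its paired object query exchange information first, before each branch separately gathers visual and entity evidence.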
3. Set Prediction Loss
RelTR extends DETR's loss formulation with a set prediction loss tailored to triplet detection. The loss computes a cost matrix over the predicted triplets (combining the subject/object confidences, the predicate confidence, and the bounding-box IoU) and assigns each ground-truth triplet to the best-matching prediction. Predictions left unmatched are assigned the label <background-no relation-background>.
During training, several queries may focus on the same ground-truth triplet (for example, predictions A and D both fit the same target, but D is treated as background because its cost is slightly higher). To address this, RelTR adds an IoU-based assignment strategy: when a prediction's subject/object class is correct and its bounding-box IoU exceeds a threshold, the loss for that subject/object is not computed even if the prediction was not selected by the Hungarian matching. This prevents high-quality detections from being wrongly penalized because of a local error such as a misclassified predicate; a minimal sketch of the exemption rule follows.
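The sketch below illustrates the exemption described above. The helper names, the numpy formulation, and the 0.7 threshold are illustrative assumptions, not the exact implementation.

```python
# Illustrative sketch of the IoU-based assignment exemption (not the repository code).
import numpy as np

def box_iou(box_a, box_b):
    """IoU of two boxes given as [x0, y0, x1, y1]."""
    x0, y0 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    x1, y1 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0.0, x1 - x0) * max(0.0, y1 - y0)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter + 1e-6)

def entity_loss_mask(pred_labels, pred_boxes, gt_label, gt_box, matched_idx, iou_thresh=0.7):
    """0/1 mask over the queries: 1 means the subject/object loss is computed.

    The query chosen by the Hungarian matching (matched_idx) always keeps its
    loss; any other query whose class is already correct and whose box overlaps
    the ground truth strongly enough is exempted instead of being pushed
    towards the background class.
    """
    mask = np.ones(len(pred_labels), dtype=np.float32)
    for i, (label, box) in enumerate(zip(pred_labels, pred_boxes)):
        if i == matched_idx:
            continue
        if label == gt_label and box_iou(box, gt_box) > iou_thresh:
            mask[i] = 0.0  # well localized and correctly classified: do not penalize as background
    return mask

mask = entity_loss_mask(pred_labels=[3, 3, 7],
                        pred_boxes=[[0, 0, 10, 10], [1, 1, 10, 10], [5, 5, 9, 9]],
                        gt_label=3, gt_box=[0, 0, 10, 10], matched_idx=0)
print(mask)  # [1. 0. 1.]: the second query is exempted rather than treated as background
```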
03
Model Inference with MindSpore
The RelTR model implemented with MindSpore (the original code uses a custom Transformer and a modified ResNet, which are too long to show in full; only part of the code is shown here, and the complete MindSpore implementation has been published to the Gitee repository: https://gitee.com/yanrui2025/ms-rel-tr):
```python
class RelTR(nn.Cell):
    """ RelTR: Relation Transformer for Scene Graph Generation """
    def __init__(self, backbone, transformer, num_classes, num_rel_classes, num_entities, num_triplets, aux_loss=False, matcher=None):
        """ Initializes the model.
        Parameters:
            backbone: MindSpore Cell of the backbone to be used. See backbone.py
            transformer: MindSpore Cell of the transformer architecture. See transformer.py
            num_classes: number of entity classes
            num_rel_classes: number of relation (predicate) classes
            num_entities: number of entity queries
            num_triplets: number of coupled subject/object queries
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            matcher: matcher used by the set prediction loss
        """
        super().__init__()
        self.num_entities = num_entities
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.hidden_dim = hidden_dim
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1, has_bias=True)
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.entity_embed = nn.Embedding(num_entities, hidden_dim*2)
        self.triplet_embed = nn.Embedding(num_triplets, hidden_dim*3)
        self.so_embed = nn.Embedding(2, hidden_dim)  # subject and object encoding

        # entity prediction
        self.entity_class_embed = nn.Dense(hidden_dim, num_classes + 1)
        self.entity_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

        # mask head
        self.so_mask_conv = nn.SequentialCell([nn.Upsample(size=(28, 28)),
                                               nn.Conv2d(2, 64, kernel_size=3, stride=2, pad_mode='pad', padding=3, has_bias=True),
                                               nn.ReLU(),
                                               nn.BatchNorm2d(64),
                                               nn.MaxPool2d(kernel_size=3, pad_mode='pad', stride=2, padding=1),
                                               nn.Conv2d(64, 32, kernel_size=3, stride=1, pad_mode='pad', padding=1, has_bias=True),
                                               nn.ReLU(),
                                               nn.BatchNorm2d(32)])
        self.so_mask_fc = nn.SequentialCell([nn.Dense(2048, 512),
                                             nn.ReLU(),
                                             nn.Dense(512, 128)])

        # predicate classification
        self.rel_class_embed = MLP(hidden_dim*2+128, hidden_dim, num_rel_classes + 1, 2)

        # subject/object label classification and box regression
        self.sub_class_embed = nn.Dense(hidden_dim, num_classes + 1)
        self.sub_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
        self.obj_class_embed = nn.Dense(hidden_dim, num_classes + 1)
        self.obj_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

    def construct(self, samples):
        """ The forward expects a NestedTensor, which consists of:
               - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
               - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels

            It returns a dict with the following elements:
               - "pred_logits": the entity classification logits (including no-object) for all entity queries.
                                Shape= [batch_size x num_queries x (num_classes + 1)]
               - "pred_boxes": the normalized entity boxes coordinates for all entity queries, represented as
                               (center_x, center_y, height, width). These values are normalized in [0, 1],
                               relative to the size of each individual image (disregarding possible padding).
                               See PostProcess for information on how to retrieve the unnormalized bounding box.
               - "sub_logits": the subject classification logits
               - "obj_logits": the object classification logits
               - "sub_boxes": the normalized subject boxes coordinates
               - "obj_boxes": the normalized object boxes coordinates
               - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
                                dictionaries containing the two above keys for each decoder layer.
        """
        if isinstance(samples, (list, Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None
        hs, hs_t, so_masks, _ = self.transformer(self.input_proj(src), mask, self.entity_embed.embedding_table,
                                                 self.triplet_embed.embedding_table, pos[-1], self.so_embed.embedding_table)

        so_masks = self.so_mask_conv(so_masks.view(-1, 2, src.shape[-2], src.shape[-1])).view(hs_t.shape[0], hs_t.shape[1], hs_t.shape[2], -1)
        so_masks = self.so_mask_fc(so_masks)

        split = ops.Split(axis=-1, output_num=2)
        hs_sub, hs_obj = split(hs_t)

        outputs_class = self.entity_class_embed(hs)
        outputs_coord = self.entity_bbox_embed(hs).sigmoid()

        outputs_class_sub = self.sub_class_embed(hs_sub)
        outputs_coord_sub = self.sub_bbox_embed(hs_sub).sigmoid()

        outputs_class_obj = self.obj_class_embed(hs_obj)
        outputs_coord_obj = self.obj_bbox_embed(hs_obj).sigmoid()

        concat = ops.Concat(axis=-1)
        concat_feat = concat((hs_sub, hs_obj, so_masks))
        outputs_class_rel = self.rel_class_embed(concat_feat)

        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1],
               'sub_logits': outputs_class_sub[-1], 'sub_boxes': outputs_coord_sub[-1],
               'obj_logits': outputs_class_obj[-1], 'obj_boxes': outputs_coord_obj[-1],
               'rel_logits': outputs_class_rel[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, outputs_class_sub, outputs_coord_sub,
                                                    outputs_class_obj, outputs_coord_obj, outputs_class_rel)
        return out
```
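The box-regression heads above rely on an MLP helper that is not shown. Below is a minimal sketch of such a DETR-style helper; its exact form in the repository may differ.

```python
# Minimal sketch of a DETR-style MLP head (hidden layers with ReLU, linear output).
import mindspore.nn as nn

class MLP(nn.Cell):
    """Simple multi-layer perceptron used for box regression and predicate heads."""
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        layers = []
        for i, (in_d, out_d) in enumerate(zip(dims, dims[1:] + [output_dim])):
            layers.append(nn.Dense(in_d, out_d))
            if i < num_layers - 1:          # no activation after the last layer
                layers.append(nn.ReLU())
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        return self.layers(x)
```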
Data loading with MindSpore. The code below implements data loading and preprocessing.
```python
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) Institute of Information Processing, Leibniz University Hannover.
"""
dataset (COCO-like) which returns image_id for evaluation.
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
"""
from pathlib import Path
import json
import os

import numpy as np
import cv2
from PIL import Image
from pycocotools import mask as coco_mask
from pycocotools.coco import COCO

import mindspore as ms
import mindspore.numpy as mnp
import mindspore.ops as ops
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as c_vision
import mindspore.dataset.vision.py_transforms as py_vision
from mindspore import Tensor
from mindspore.dataset.transforms import Compose


class CocoDetection:
    def __init__(self, img_folder, ann_file, transforms=None, return_masks=False):
        self.img_folder = img_folder
        self.coco = COCO(ann_file)
        self.ids = list(sorted(self.coco.imgs.keys()))
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks)

        # load the relationship annotations
        rel_json_path = '/'.join(ann_file.split('/')[:-1]) + '/rel.json'
        with open(rel_json_path, 'r') as f:
            all_rels = json.load(f)
        if 'train' in ann_file:
            self.rel_annotations = all_rels['train']
        elif 'val' in ann_file:
            self.rel_annotations = all_rels['val']
        else:
            self.rel_annotations = all_rels['test']
        self.rel_categories = all_rels['rel_categories']

    def __getitem__(self, idx):
        """
        Return the image and its annotations (including relationship annotations).
        """
        image_id = self.ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=image_id)
        target = self.coco.loadAnns(ann_ids)  # load the entity annotations

        # load the image
        img_info = self.coco.loadImgs(image_id)[0]
        img_path = os.path.join(self.img_folder, img_info['file_name'])
        img = Image.open(img_path).convert('RGB')

        # load the relationship annotations for this image
        rel_target = self.rel_annotations[str(image_id)]

        # build the target dict (same layout as the PyTorch version)
        target_dict = {
            'image_id': image_id,
            'annotations': target,
            'rel_annotations': rel_target
        }

        # convert the annotations (e.g. polygons to masks)
        img, target_dict = self.prepare(img, target_dict)

        # apply image transforms (Resize, ToTensor, etc.)
        if self._transforms is not None:
            img, target_dict = self._transforms(img, target_dict)
        return img, target_dict

    def __len__(self):
        return len(self.ids)
```
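To feed this dataset into MindSpore, it can be wrapped in a GeneratorDataset. The snippet below is a minimal sketch with assumed paths and column names; in practice the target dictionary usually has to be converted into array-like columns first, because dataset columns must be numpy-convertible.

```python
# Minimal sketch of wrapping CocoDetection into a MindSpore data pipeline.
import mindspore.dataset as ds

def build_loader(img_folder, ann_file, batch_size=1):
    """Wrap the random-access CocoDetection object into a MindSpore dataset."""
    source = CocoDetection(img_folder, ann_file, transforms=None)
    loader = ds.GeneratorDataset(source, column_names=["image", "target"], shuffle=False)
    return loader.batch(batch_size)

# hypothetical paths; adjust to your local Visual Genome / rel.json layout
# loader = build_loader("data/vg/images", "data/vg/annotations/val.json")
```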
To evaluate the model, the raw tensors output by the model need to be converted into the COCO evaluation format:
```python
class PostProcess(nn.Cell):
    """ This module converts the model's output into the format expected by the coco api"""

    def construct(self, outputs, target_sizes):
        """ Perform the computation
        Parameters:
            outputs: raw outputs of the model
            target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
                          For evaluation, this must be the original image size (before any data augmentation)
                          For visualization, this should be the image size after data augmentation, but before padding
        """
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(target_sizes)
        assert target_sizes.shape[1] == 2

        softmax = ops.Softmax(axis=-1)
        prob = softmax(out_logits)
        scores, labels = prob[..., :-1].max(-1)

        # convert to [x0, y0, x1, y1] format
        boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = target_sizes.unbind(1)
        stack = ops.Stack(axis=1)
        scale_fct = stack([img_w, img_h, img_w, img_h])
        boxes = boxes * scale_fct[:, None, :]

        results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]

        return results
```
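PostProcess relies on box_ops.box_cxcywh_to_xyxy, which is not shown above. Below is a minimal MindSpore sketch of this DETR-style conversion, assumed to match the helper used in the repository.

```python
# Sketch of the cxcywh -> xyxy box conversion used by PostProcess.
import mindspore.ops as ops

def box_cxcywh_to_xyxy(boxes):
    """Convert boxes from (center_x, center_y, w, h) to (x0, y0, x1, y1)."""
    cx, cy, w, h = boxes[..., 0], boxes[..., 1], boxes[..., 2], boxes[..., 3]
    stack = ops.Stack(axis=-1)
    return stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h])
```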
Run inference on the data and evaluate the results:
```python
def evaluate(model, criterion, postprocessors, data_loader, base_ds):
    dataset = 'vg'
    do_eval = True
    model.set_train(False)
    criterion.set_train(False)

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('sub_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('obj_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('rel_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    # initialize evaluator
    # TODO merge evaluation programs
    if dataset == 'vg':
        evaluator = BasicSceneGraphEvaluator.all_modes(multiple_preds=False)
        if do_eval:
            evaluator_list = []
            for index, name in enumerate(data_loader.dataset.rel_categories):
                if index == 0:
                    continue
                evaluator_list.append((index, name, BasicSceneGraphEvaluator.all_modes()))
        else:
            evaluator_list = None
    else:
        all_results = []

    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

    for samples, targets in metric_logger.log_every(data_loader, 100, header):
        outputs = model(samples)
        loss_dict = criterion(outputs, targets)
        weight_dict = criterion.weight_dict

        # single-process evaluation: no cross-device reduction is needed
        loss_dict_reduced = loss_dict
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(sub_error=loss_dict_reduced['sub_error'])
        metric_logger.update(obj_error=loss_dict_reduced['obj_error'])
        metric_logger.update(rel_error=loss_dict_reduced['rel_error'])

        if dataset == 'vg':
            evaluate_rel_batch(outputs, targets, evaluator, evaluator_list)

        orig_target_sizes = ops.Stack(axis=0)([t["orig_size"] for t in targets])
        results = postprocessors['bbox'](outputs, orig_target_sizes)

        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(res)

    if dataset == 'vg':
        evaluator['sgdet'].print_stats()

    if do_eval and dataset == 'vg':
        calculate_mR_from_evaluator_list(evaluator_list, 'sgdet')

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        coco_evaluator.summarize()

    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()

    return stats, coco_evaluator
```
04
Test Results and Visualization
We test on the Visual Genome dataset. The experimental results are as follows:
Below are visualization results produced by the MindSpore model:
05
A Few Hiccups Along the Way
1. Model migration issue:
In our first tests, we directly loaded the PyTorch pretrained weights into the reimplemented MindSpore model. However, after saving the migrated MindSpore parameters to a ckpt file and reloading it, the MindSpore model's output was abnormal. This is likely an issue with the ckpt conversion; assigning the weights directly by mapping parameter names, in the style sketched below, works without problems.
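A minimal sketch of this name-based assignment follows. The rename rules and the helper name are illustrative assumptions; the real mapping depends on how the two implementations name their layers (for example, BatchNorm weight/bias usually also map to gamma/beta).

```python
# Sketch: copy PyTorch weights into MindSpore parameters matched by (mapped) name.
import torch
from mindspore import Tensor

# illustrative rename rules; extend them for your own layer naming
RENAME_RULES = [
    ("running_mean", "moving_mean"),
    ("running_var", "moving_variance"),
]

def load_torch_weights(ms_model, torch_ckpt_path):
    """Assign PyTorch state-dict tensors to MindSpore parameters by name."""
    torch_state = torch.load(torch_ckpt_path, map_location="cpu")
    ms_params = {p.name: p for p in ms_model.get_parameters()}
    for torch_name, tensor in torch_state.items():
        ms_name = torch_name
        for old, new in RENAME_RULES:
            ms_name = ms_name.replace(old, new)
        if ms_name in ms_params:
            ms_params[ms_name].set_data(Tensor(tensor.detach().cpu().numpy()))
        else:
            print(f"no MindSpore parameter for {torch_name} -> {ms_name}")
```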
2. MindSpore computation issue:
When running inference on a local CPU, for the same image the PyTorch model produces normal results while the MindSpore model outputs NaN.
Output of the PyTorch model:
Output of the MindSpore model:
The first half of the two outputs is identical, which indicates that the model structure and parameters are fine. The problem likely lies in the MindSpore version and the hardware environment; running inference with the MindSpore model on the OpenI (启智) platform shows no such anomaly.






