image generation from scene graphs 论文+code复现总结

image generation from scene graphs 论文+code复现总结

abstraction

传统方法在一些限制好的领域像鸟或花,这些方法都还不错,但是在如实地分解复杂的段落为多个对象和关系上都很失败。

他们提出了一个方法,从场景图生成图像,明确地推理对象和他们的关系

用图卷积网络处理输入图像,计算一个场景布局通过预测边界框和对象分割遮罩,通过级联优化网络(一对鉴别器)将布局转化为网络

introduction

传统方法 RNN+GAN

句子是线性结构

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OtPcLABp-1617713336517)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20201230100541242.png)]

一个复杂句子传递的信息可以被一个包含对象和他们关系的场景图更明确的表示

场景图是一种可以用来表示文本或者图像结构的表述

可以看到,场景图将场景表示为有向图,其中节点(红色)是对象,边(蓝色)给出对象之间的关系

  1. 图卷积网络沿着图像边缘处理信息
  2. 缝合间隙,图结构的输入和两个维度的图像输出
  3. 通过预测一个边缘盒和语义遮罩建立了一个场景布局
  4. 级联优化网络生成图片
  5. 训练一对鉴别器网络来生成

数据集

  1. Visual Genome-提供人工标注的场景图
  2. COCO-正确标注图像

比较实验结果

与 StackGAN 比较在Amazon Mechanical Turk(众包平台)上

related work

生成模型

GAN

VAE—通过变分推理,共同学习在图像和潜码之间的分布一个编码器和解码器

自回归模型—通过之前的所有像素限制每个像素

条件图像合成、

GANs可以以类别标签为条件,向生成器和鉴别器提供标签作为额外的输入,或者强迫鉴别器预测标签

本文采用了后者的方法

几个方法:

Reed:GAN和多尺度回归模型,

Chen:街景生成—CRN(本文用了这个模型从场景布局生成图片) 是场景布局预测,研究了从文本到3D场景生成的方法

场景图

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TnyJK8Mc-1617713336519)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20201230152328273.png)]
将场景表示为有向图,节点是对象和赋予了两个对象之间关系的边

图像上的深度学习

现有的方法:word2vec

嵌入,给一个文档语料库,一个很大的图

本文的方法:图网络

在任意图上的递归网络

method

困难:

  1. 处理图结构的输入
  2. 生成的图片明确代表图中目标和关系
  3. 保证合成的图像是真实的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kW52ceMc-1617713336520)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20201230153132826.png)]

我们从图卷积网络中的目标嵌入向量表示场景图中的目标和关系来预测边围盒和目标的语义遮罩,这些组合起来组成了语义布局,作为图和图像领域中的过渡

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-gNu1TBYs-1617713336521)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20201230153528034.png)]

场景图的部分是由图卷积神经网络,给嵌入向量赋予每个目标,每个图卷积层混合了图的每条边的信息

scene graphs

O-object;R-relationship;

a scene graph is a tuple(O,E) ;
E ∈ O × R × O E∈O×R×O EO×R×O
第一部处理:我们用学习好的嵌入层将图中的每个结点和边从分类好的标签转化为一个向量

graph convolution network

传统的2维的卷积层

空间网格的特征向量作为输入,输出一个新的空间网格的特征向量

通过共享权重,每个输出向量包含了其对应的输入相邻输入的信息

本文中的图卷积层

在每个节点和边上,Din维度的向量输入,会有Dout维度的输出

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5FyopVvC-1617713336522)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210101184217741.png)]

​ 三个函数g 将一条边的元组(vi,vr,vj)作为输入,分别输出对于subject 主体 oi,预测的关系r,object 客体 oj新向量

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SN5lCDOZ-1617713336522)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210101192808271.png)]

一个客体向量可能在多个关系里

因此,客体oi的输出向量vi’由所有图的边线连接oi的向量vi,对于那些边的向量vr

最后,计算主体(每条边开始在oi)候选向量集合,和客体(每条边结束在oi)候选向量集合

oi的输出向量vi会由h这个函数计算,池化这个向量集合到一个输出向量

在我们的执行过程中,对于gs,gp和go单独的网络连接三个输入向量

将它们提供给一个多层感知器

使用完全连接的输出头计算三个输出向量

池函数h取其输入向量的平均值,并将结果提供给MLP

scene layout

用一系列的图卷积层处理输入场景图,给出每个对象的嵌入向量,该向量聚合了图中所有对象和关系的信息

场景布局给出了图像粗糙的二维结构

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-SozF7J9l-1617713336523)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210101201353429.png)]

我们通过使用对象布局网络预测分割掩模和每个对象的包围框来计算场景布局,如图4所示

目标oi,大小为D的嵌入向量vi通过遮罩回归网络预测一个大小为M×M二值遮罩m。一个盒回归网络预测一个包围盒

掩模回归网络由几个以sigmoid非线性函数结束的转置卷积组成,使掩模的元素位于范围(0,1);盒回归网络是一个MLP

我们将嵌入向量vi与掩模mi相乘,得到形状为D×M×M的掩模嵌入,然后它将弯曲到包围盒的位置使用双线性插值来给出一个对象布局

在训练过程中,我们使用正确标注的包围盒bi来计算场景布局;在测试时,我们用预测包围盒bi

cascaded refinement network

CRN由一系列卷积精化模块组成,模块间空间分辨率加倍;这允许生成以从粗到细的方式进行

每个模块接收场景布局(下采样到模块的输入分辨率)和前一个模块的输出作为输入。这些输入以信道方式连接并传递到一对3×3的卷积层,输出在被传递到下一个模块之前,使用最近邻插值法向上采样。

discriminators

基于补丁的图像鉴别器保证生成的图像的整体外观是真实的、

对象鉴别器确保图像中的每个对象看起来都是真实的

除了对每个对象进行真假分类外,Dobj 还确保每个对象都可以使用辅助分类器来预测对象的类别

该分类器预测对象的类别

training

训练生成网络使6个损失的加权和最小

  1. box loss 惩罚正确标注和预测包围盒之间L1的差值
  2. mask loss 惩罚正确标注和预测遮罩之间L1的差值,交叉熵。忽略了在VG上训练模型的掩码预测损失
  3. pixel loss 正确标注和生成图像之间的差值
  4. image adversarial loss 图像生成器和图像鉴别器之间的损失(图像补丁看起来真实,realistic)
  5. object adversarial loss 对象生成器和对象鉴别器之间的损失(每个生成的对象看起来真实)
  6. auxiliarly classifier loss

implement details

在所有场景图中都增加了一个特殊的图像对象,并在每个真对象与图像对象之间增加了特殊的图像关系;这确保了所有场景图都是连接的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-LOQou86w-1617713336523)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210102130754187.png)]

experiments

dataset

COCO

该数据集为40K的训练图像和5K的 validation 图像标注了边界框和分割掩模,用于80个thing类别(人、车等)和91个stuff类别(天空、草地等)。

我们使用这些标注来构建基于对象二维图像坐标的合成场景图,使用6种互斥的几何关系:左、右、上、下、内和周围

我们会忽略覆盖不到图像2%的物体,使用包含3到8个对象的图像

visual genome

该基因组包含108077张带有场景图注释的图像。我们将数据分为80%的训练,10%的val集和10%的测试集;使用对象和关系类别在训练集中分别出现至少2000次和500次,留下178个对象和45个关系类型

我们忽略小的物体,使用有3到30个物体和至少一个关系的图像;这样我们就有62,565张训练图像、5,506张val和5,088张测试图像,每张图像平均有10个对象和5个关系

视觉基因组不提供分割遮罩,因此忽略在VG上训练模型的遮罩预测损失

qualitative results

图五展示了我们的方法可以生成具有多个对象的场景,甚至是具有相同对象类型的多个实例:

图六,事物位置的变化表示关系受到了遵循

ablation study

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ujP2tlbu-1617713336524)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210102143732830.png)]

我们使用Inception score来测量图像质量,它使用ImageNet分类模型来鼓励图像中的可识别对象和图像的多样性

  1. no gconv,它不能共同推理不同物体的存在,并且只能预测每个类别的一个盒子和遮罩
  2. no relationship,图卷积允许这个模型联合地描述对象。说明场景图关系的实用性的性能较差
  3. no discriminators,生成过度平滑的图像
  4. no Dobj and Dimg(omit one of them),
  5. GT layout,提供了一个性能上限

object localization

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-m228J5DA-1617713336524)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210102145933808.png)]

R@t is object recall with an IoU threshold of t, and measures agreement with ground-truth boxes.

交并比(Intersection over Union)

σx 和 σarea通过计算每个对象类别中盒x-位置和区域的标准偏差,然后在不同类别之间求平均值来测量盒的多样性

衡量标准之一:预测盒和正确标注盒的高度一致

另一种度量方法是多样性:对象的预测盒应该随着图中其他对象和关系的变化而变化

1.no gonv,模型只能学会预测每个对象类别的一个边界框

2.no relationship,如果没有关系,这个模型的预测盒与真实标注盒位置的一致性较差。

user studies

caption matching

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XiAGgJLG-1617713336524)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210102155304303.png)]

object recall

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XszA663A-1617713336525)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210102155731475.png)]

这个实验测量了每种方法图像中可识别对象的数量

conclusion

本文提出了一种从场景图中生成图像的端到端方法。与主要的从文本描述生成图像的方法相比,从结构化的场景图而不是非结构化的文本生成图像允许我们的方法明确地推理对象和关系,并生成具有许多可识别对象的复杂图像。

supplementary material

附录

valvalidation的简称。

training datasetvalidation dataset都是在训练的时候起作用。
而因为validation的数据集和training没有交集,所以这部分数据对最终训练出的模型没有贡献。
validation的主要作用是来验证是否过拟合、以及用来调节训练参数等。

比如训练0-10000次迭代过程中,trainvalidationloss都是不断降低,
但是从10000-20000过程中train loss不断降低,validationloss不降反升。
那么就证明继续训练下去,模型只是对training dataset这部分拟合的特别好,但是泛化能力很差。
所以与其选取20000次的结果,不如选择10000次的结果。
这个过程的名字叫做Early Stopvalidation数据在此过程中必不可少。

如果跑caffe自带的训练demo,你会用到train_val.prototxt,这里面的val其实就是validation
而网络输入的TEST层,其实就是validation,而不是test。你可以通过观察validationlosstrainloss定下你需要的模型。

但是为什么现在很多人都不用validation了呢?
我的理解是现在模型中防止过拟合的机制已经比较完善了,Dropout\BN等做的很好了。
而且很多时候大家都用原来的模型进行fine tune,也比从头开始更难过拟合。
所以大家一般都定一个训练迭代次数,直接取最后的模型来测试。

召回率和交并比IOU

召回率

就是被正确识别出来的正样本个数与测试集中所有正样本的个数的比值

注: Precision和Recall之间往往是一种博弈关系,好的模型让Recall值增长的同时保持Precision的值也在很高的水平,而差的模型性可能会损失很多Precision值才能换来Recall值的提高。通常情况下,都会使用Precision-recall曲线,来显示分类模型在Precision与Recall之间的权衡。

交并比(Intersection-over-Union,IoU)

目标检测中使用的一个概念,是产生的候选框(candidate bound)与原标记框(ground truth bound)的交叠率,即它们的交集与并集的比值。最理想情况是完全重叠,即比值为1。

引用

代码复现

第一次复现 总结

  1. pycharm interpreter setting 里面新建一个虚拟conda环境 然后配包

    这样的话 比较好管理 一个项目一个环境

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NPbPH13u-1617713336525)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210406204051543.png)]

  1. 不用命令行编译,直接Run model 路径比较不容易出错

  2. 在这个model里,我路径出错,解决方式:

    把路径补全

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kr2it1gR-1617713336525)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210406204412513.png)]

    还有扩展路径

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ROszAj0C-1617713336526)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210406204436611.png)]

  3. pytorch 论文用的老版本,应该找兼容再高一点的版本,要不然显卡启动太慢,模型加载的也慢

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XdTos2d4-1617713336526)(C:\Users\Administrator\AppData\Roaming\Typora\typora-user-images\image-20210406204559294.png)]

[外链图片转存中…(img-NPbPH13u-1617713336525)]

  1. 不用命令行编译,直接Run model 路径比较不容易出错

  2. 在这个model里,我路径出错,解决方式:

    把路径补全

    [外链图片转存中…(img-kr2it1gR-1617713336525)]

    还有扩展路径

    [外链图片转存中…(img-ROszAj0C-1617713336526)]

  3. pytorch 论文用的老版本,应该找兼容再高一点的版本,要不然显卡启动太慢,模型加载的也慢

[外链图片转存中…(img-XdTos2d4-1617713336526)]

  1. ubuntu 虚拟机不能利用显卡,window 用wget运行脚本有一大堆bug

你可能感兴趣的