Spatial Transformer Networks(STN)-论文笔记

Spatial Transformer Networks(STN)-论文笔记

  • 论文: Spatial Transformer Networks
  • 作者:Max Jaderberg Karen Simonyan Andrew Zisserman Koray Kavukcuoglu
  • code1:https://github.com/oarriaga/STN.keras
  • code2:https://github.com/kevinzakka/spatial-transformer-network

1. 问题提出

  1. CNN在图像分类中取得了显著的成效,主要是得益于 CNN 的深层结构具有 : 平 移 不 变 性 、 缩 小 不 变 性 \color{red}平移不变性、缩小不变性 ;还对缺失的 空 间 不 变 性 ( s p a t i a l l y    i n v a r i a n c e ) \color{red}空间不变性(spatially\;invariance) spatiallyinvariance做了相应的实验。

    • 平 移 不 变 性 平移不变性 主要是由于 Pooling 层 和 步长不为1的卷积层 的存在带来的。实际上主要是池化层的作用:
      • 层越多,越深,池化核或卷积核越大,空间不变性也越强;
      • 但是随之而来的问题是局部信息丢失,所以这些层越多准确率肯定是下降的,所以主流的CNN分类网络一般都很深,但是池化核都比较小,比如2×2。
    • 缩 小 不 变 性 缩小不变性 主要是通过降采样来实现的。降采样比例要根据数据集调整,找到合适的降采样比例,才能保证准确率的情况下,有较强的空间不变性。
      • 比如ResNet,GoogLeNet,VGG,FCN,这些网络的总降采样比例一般是 16或32,基本没有见过 64倍,128倍或者更高倍数的降采样(会损失局部信息降低准确率),也很少见到 2倍或者4倍的降采样比例(空间不变性太弱,泛化能力不好)。
  2. 空 间 不 变 性 ( s p a t i a l l y    i n v a r i a n c e ) \color{red}空间不变性(spatially\;invariance) spatiallyinvariance这些不变性的本质就是图像处理的经典手段:空间变换,又服从于同一方法:坐标矩阵的仿射变换。因此DeepMind设计了 S p a t i a l    T r a n s f o r m e r    N e t w o r k s \color{red}Spatial\;Transformer\;Networks SpatialTransformerNetworks(简称STN),目的就是显式地赋予网络对于以上各项变换(transformation)的不变性(invariance) .

2. 图像处理技巧

2.1 仿射变化

主要是要处理 ( 2 × 3 ) (2\times 3) (2×3)的变换矩阵:
T θ = [ θ 11 θ 12 θ 13 θ 21 θ 22 θ 23 ] (2.1) \mathcal{T}_{\theta} = \begin{bmatrix} \theta _{11} & \theta _{12} & \theta _{13} \\ \theta _{21} & \theta _{22} & \theta _{23} \end{bmatrix}\tag{2.1} Tθ=[θ11θ21θ12θ22θ13θ23](2.1)

  • 平移:
    [ 1 0 θ 13 0 1 θ 23 ] [ x y 1 ] = [ x + θ 13 y + θ 23 ] (2.2) \left[\begin{array}{ccc} 1 & 0 & \theta_{13} \\0 & 1 & \theta_{23} \end{array}\right]\left[\begin{array}{l}x \\y \\1 \end{array}\right]=\left[\begin{array}{l} x+\theta_{13} \\y+\theta_{23} \end{array}\right]\tag{2.2} [1001θ13θ23]xy1=[x+θ13y+θ23](2.2)

  • 缩放:
    [ θ 11 0 0 0 θ 22 0 ] [ x y 1 ] = [ θ 11 x θ 22 y ] (2.3) \left[\begin{array}{ccc} \theta_{11} & 0 & 0 \\0 & \theta_{22} & 0 \end{array}\right]\left[\begin{array}{l}x \\y \\1 \end{array}\right]=\left[\begin{array}{l} \theta_{11} x \\\theta_{22} y\end{array}\right]\tag{2.3} [θ1100θ2200]xy1=[θ11xθ22y](2.3)

  • 旋转:
    对于旋转操作,设绕原点顺时针旋转 α \alpha α 度,坐标仿射矩阵为:
    [ cos ⁡ ( α ) sin ⁡ ( α ) 0 − sin ⁡ ( α ) cos ⁡ ( α ) 0 ] [ x y 1 ] = [ cos ⁡ ( α ) x + sin ⁡ ( α ) y − sin ⁡ ( α ) x + cos ⁡ ( α ) y ] (2.4) \left[\begin{array}{ccc} \cos (\alpha) & \sin (\alpha) & 0 \\ -\sin (\alpha) & \cos (\alpha) & 0 \end{array}\right]\left[\begin{array}{l}x \\y \\1 \end{array}\right]=\left[\begin{array}{c}\cos (\alpha) x+\sin (\alpha) y \\-\sin (\alpha) x+\cos (\alpha) y \end{array}\right]\tag{2.4} [cos(α)sin(α)sin(α)cos(α)00]xy1=[cos(α)x+sin(α)ysin(α)x+cos(α)y](2.4)

    由于图像的坐标不是中心坐标系,通常需要做Normalization,把坐标调整到 [ − 1 , 1 ] [-1,1] [1,1]。这样,就绕图像中心旋转了。

2.2 逆向坐标映射

假设fixed image 的坐标点是 [ x t a r , y t a r ] [x^{tar}, y^{tar}] [xtar,ytar],source iamge 的坐标点是 [ x s o u r , y s o u r ] [x^{sour}, y^{sour}] [xsour,ysour],则一般的坐标映射可以表示为:
[ θ 11 θ 12 θ 13 θ 21 θ 22 θ 23 ] [ x s o u r y s o u r 1 ] = [ x t a r y t a r 1 ] (2.5) \begin{bmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta _{21} & \theta _{22} & \theta _{23} \end{bmatrix}\begin{bmatrix} x^{sour} \\ y^{sour} \\ 1 \end{bmatrix}=\begin{bmatrix} x^{tar} \\ y^{tar} \\1 \end{bmatrix}\tag{2.5} [θ11θ21θ12θ22θ13θ23]xsourysour1=xtarytar1(2.5)

逆向坐标映射表示为( θ ′ \theta' θ and θ \theta θ are different):
[ θ 11 ′ θ 12 ′ θ 13 ′ θ 21 ′ θ 22 ′ θ 23 ′ ] [ x t a r y t a r 1 ] = [ x s o u r y s o u r 1 ] (2.6) \begin{bmatrix} \theta'_{11} & \theta'_{12} & \theta'_{13} \\ \theta' _{21} & \theta' _{22} & \theta' _{23} \end{bmatrix}\begin{bmatrix} x^{tar} \\ y^{tar} \\ 1 \end{bmatrix}=\begin{bmatrix} x^{sour} \\ y^{sour} \\1 \end{bmatrix}\tag{2.6} [θ11θ21θ12θ22θ13θ23]xtarytar1=xsourysour1(2.6)
STN采用逆向映射,因为:target image 是固定的,正向的插值过程,都是引用像素坐标是浮点数,相对来说很难插值;对应逆向变换,得到的Source坐标是浮点数,用Source像素插值更加便捷

2.3 双线性插值

  • 一个[1,10]图像放大10倍问题,我们需要将10个像素,扩展到为100的数轴上,整个图像应该有100个像素。
    但其中90个对应Source图的坐标是非整数的,是不存在的,如果我们用黑色(RGB(0,0,0))填充,此时图像是惨不忍睹的。所以需要对缺漏的像素进行插值,利用图像数据的局部性近似原理,取邻近像素做平均生成。

  • 双线性插值是一个兼有质量与速度的方法:
    Spatial Transformer Networks(STN)-论文笔记_第1张图片

  • 插值一般表达式:
    V i c = ∑ n H ∑ m W U n m c k ( x i s − m ; Φ x ) k ( y i s − n ; Φ y ) ∀ i ∈ [ 1 … H ′ W ′ ] ∀ c ∈ [ 1 … C ] (2.7) V_{i}^{c}=\sum_{n}^{H} \sum_{m}^{W} U_{n m}^{c} k\left(x_{i}^{s}-m ; \Phi_{x}\right) k\left(y_{i}^{s}-n ; \Phi_{y}\right) \forall i \in\left[1 \ldots H^{\prime} W^{\prime}\right] \forall c \in[1 \ldots C]\tag{2.7} Vic=nHmWUnmck(xism;Φx)k(yisn;Φy)i[1HW]c[1C](2.7)

    • U n m c U_{n m}^{c} Unmc 是输入feature map上第 c c c 个通道上坐标为 ( n , m ) (n, m) (n,m) 的像素值;
    • V i c V_{i}^{c} Vic 是输出 feature map上第 c c c 个通道上坐标为 ( x i t , y i t ) \left(x_{i}^{t}, y_{i}^{t}\right) (xit,yit) 的像素值;
    • k ( ) k() k() 表示插值核函数;
    • Φ x , Φ y \Phi x, \Phi y Φx,Φy 代表 x \mathrm{x} x y \mathrm{y} y 方向的揷值核函数的参数;
    • H , W H, W H,W 输入 U U U的尺寸;
    • H ′ , W ′ H^{\prime}, W^{\prime} H,W 输出 V V V的尺寸;
  • 双线性插值的公式:
    V i c = ∑ n H ∑ m W U n m c max ⁡ ( 0 , 1 − ∣ x i s − m ∣ ) max ⁡ ( 0 , 1 − ∣ y i s − n ∣ ) (2.8) V_{i}^{c}=\sum_{n}^{H} \sum_{m}^{W} U_{n m}^{c} \max \left(0,1-\left|x_{i}^{s}-m\right|\right) \max \left(0,1-\left|y_{i}^{s}-n\right|\right)\tag{2.8} Vic=nHmWUnmcmax(0,1xism)max(0,1yisn)(2.8)
    这个插值核函数做的是利用 U U U中离 当前源坐标 ( x i s , y i s ) \left(x_{i}^{s}, y_{i}^{s}\right) (xis,yis) (小数坐标) 最近的 4个整数坐标 ( n , m ) (n, m) (n,m) 处的像素值做双线性插值然后拷贝到 V V V中的 ( x i t , y i t ) \left(x_{i}^{t}, y_{i}^{t}\right) (xit,yit) 坐 标处。
    Spatial Transformer Networks(STN)-论文笔记_第2张图片


3. 整体框架

3.1 整体描述

Spatial Transformer Networks的结构,主要的部分—共有三个,它们的功能和名称如下:

  • L o c a l i s a t i o n    n e t \color{blue}Localisation\;net Localisationnet(参数预测):
    是自己定义的网络,它输入 U U U,输出变化参数 θ \theta θ,这个参数用来映射 U U U V V V的坐标关系(公式(2.1))。
  • G r i d    g e n e r a t o r \color{green}Grid\;generator Gridgenerator(坐标映射):
    根据 V V V中的坐标点和变化参数 θ \theta θ,计算出 U U U中的坐标点(公式(2.6))。
    • 这里是因为 V V V的大小是先定义好的,当然可以得到 V V V的所有坐标点,而填充 V V V中每个坐标点的像素值的时候,要从 U U U中去取,所以根据 V V V中每个坐标点和变化参数 θ \theta θ进行运算,得到一个坐标。
    • 在sampler中就是根据这个坐标去 U U U中找到像素值,这样子来填充 V V V
  • S a m p l e r \color{gray}Sampler Sampler(像素的采集):

    根据Grid generator得到的一系列坐标和原图 U U U(因为像素值要从 U U U中取)来填充,因为计算出来的坐标可能为小数,要用另外的方法来填充,比如双线性插值。

Spatial Transformer Networks(STN)-论文笔记_第3张图片

3.2 基本结构与前向传播

Spatial Transformer Networks(STN)-论文笔记_第4张图片

  • DeepMind为了描述这个空间变换层,首先添加了坐标网格计算的概念,即:
    • 对应输入源特征图像素的坐标网格——Sampling Grid,保存着 ( x S o u r c e , y S o u r c e ) (x^{Source},y^{Source}) (xSource,ySource)
    • 对应输出源特征图像素的坐标网格——Regluar Grid ,保存着 ( x T a r g e t , y T a r g e t ) (x^{Target},y^{Target}) (xTarget,yTarget)
  1. L o c a l i s a t i o n    n e t \color{blue}Localisation\;net Localisationnet(参数预测):对应着初始化的6个参数。
  2. G r i d    g e n e r a t o r \color{green}Grid\;generator Gridgenerator(坐标映射):对应着图中的①②。
    T θ ( G i ) [ x t a r y t a r 1 ] = [ θ 11 ′ θ 12 ′ θ 13 ′ θ 21 ′ θ 22 ′ θ 23 ′ ] [ x t a r y t a r 1 ] = [ x s o u r y s o u r 1 ] , w h e r e    i = 1 , 2 , 3 , 4.. , H ∗ W (3.1) \mathcal{T}_{\theta}(G_i)\begin{bmatrix} x^{tar} \\ y^{tar} \\ 1 \end{bmatrix} = \begin{bmatrix} \theta'_{11} & \theta'_{12} & \theta'_{13} \\ \theta' _{21} & \theta' _{22} & \theta' _{23} \end{bmatrix}\begin{bmatrix} x^{tar} \\ y^{tar} \\ 1 \end{bmatrix}=\begin{bmatrix} x^{sour} \\ y^{sour} \\1 \end{bmatrix}, where\;i=1,2,3,4..,H∗W\tag{3.1} Tθ(Gi)xtarytar1=[θ11θ21θ12θ22θ13θ23]xtarytar1=xsourysour1,wherei=1,2,3,4..,HW(3.1)
  3. S a m p l e r \color{gray}Sampler Sampler(像素的采集):对应着图中的③④。

3.3 梯度流动与反向传播

添加空间变换层之后,梯度流动变得有趣,如图:
Spatial Transformer Networks(STN)-论文笔记_第5张图片

  1. 后流(①):
    E r r o r    G r a d i e n t Error\;Gradient ErrorGradient → … … → ∂ N e x t ∂ V i c \rightarrow \ldots \ldots \rightarrow \frac{\partial N e x t}{\partial V_{i}^{c}} VicNext
    这是Back Propagation从后层继承的动力源泉,没有它,你就不可能完成Back Propagation。
  2. 里流(②):
    { ∂ V i c ∂ x i S → ∂ x i S ∂ θ ∂ V i c ∂ y i S → ∂ y i S ∂ θ (3.3) \left\{\begin{aligned} \frac{\partial V_{i}^{c}}{\partial x_{i}^{S}} \rightarrow \frac{\partial x_{i}^{S}}{\partial \theta} \\ \frac{\partial V_{i}^{c}}{\partial y_{i}^{S}} \rightarrow \frac{\partial y_{i}^{S}}{\partial \theta} \end{aligned}\right.\tag{3.3} xiSVicθxiSyiSVicθyiS(3.3)
  • 个人对这股流的最好描述就是: 一江春水流进了小黑屋。
  • 是的,你没有看错,这股流根本就没有流到网络开头,而是在定位网络处就断流了。 由此来看,定位网络就好像是在主网络旁侧偷建的小黑屋,是一个违章湕筑。
  • 所以也无怪乎作者说,定位网络直接変成了一个回归模型,因为更新完参数,流就断了,独立于主网络。
  1. 前流(③):
    ∂ V i c ∂ U n m i → ∂ U n m i ∂  Previous  (3.4) \frac{\partial V_{i}^{c}}{\partial U_{n m}^{i}} \rightarrow \frac{\partial U_{n m}^{i}}{\partial \text { Previous }}\tag{3.4} UnmiVic Previous Unmi(3.4)
    这是Back Propagation传宗接代的根本保障,没有它,Back Propagation就断子绝孙了。

3.4 局部梯度

论文中多次出现[局部梯度] (Sub-Gradient) 的概念。采样核函数,是不连续的,不能如下直接求导:
g = ∂ V i c ∂ θ (3.5) g=\frac{\partial V_{i}^{c}}{\partial \theta}\tag{3.5} g=θVic(3.5)
而应该是分两步,先对 x i S 、 x i S x_{i}^{S} 、 x_{i}^{S} xiSxiS 求局部梯度: ∂ V i c ∂ x i c 、 ∂ V i c ∂ y i c \frac{\partial V_{i}^{c}}{\partial x_{i}^{c}} 、 \frac{\partial V_{i}^{c}}{\partial y_{i}^{c}} xicVicyicVic ,后有:
{ g = ∂ V i c ∂ x i S ⋅ ∂ x i S ∂ θ g = ∂ V i c ∂ y i S ⋅ ∂ y i S ∂ θ (3.6) \left\{\begin{aligned} g=\frac{\partial V_{i}^{c}}{\partial x_{i}^{S}} \cdot \frac{\partial x_{i}^{S}}{\partial \theta} \\ g=\frac{\partial V_{i}^{c}}{\partial y_{i}^{S}} \cdot \frac{\partial y_{i}^{S}}{\partial \theta} \end{aligned}\right.\tag{3.6} g=xiSVicθxiSg=yiSVicθyiS(3.6)
有趣的是,对于Theano这种目动求导的 Tools,局部梯度可以直接被忽视。
因为Theano的Tensor机制,会聪明地讨论并且解离非连续函数,追踪每一个可导子式,即便你用了作者们的优雅的采样函数, Tensor.grad函数也能精确只对许出的4个点求导,所以在Theano里讨论非连续函数和局部梯度,是会贻笑大方的。


4. 实验

4.1 Distorted MNIST

这个试验的数据集 是 MNIST,不过与原版的MNIST 不同,这个数据集对图片上的数字做了各种形变操作,比如平移,扭曲,放缩,旋转等。

  • 不同形变操作的简写表示:
    • 旋转:rotation ( R),
    • 旋转+缩放+平移:rotation, scale and translation (RTS),
    • 投影变换:projective transformation ( P),
    • 弹性变形:elastic warping (E) – note that elastic warping is destructive and can not be inverted in some cases.
  • 文章将 Spatial Transformer 模块嵌入到 两种主流的分类网络,FCN和CNN中(ST-FCN 和 ST-CNN )。Spatial Transformer 模块嵌入位置在图片输入层与后续分类层之间。
  • 试验也测试了不同的变换函数对结果的影响:
    • 仿射变换:affine transformation (Aff),
    • 投影变换:projective transformation (Proj),
    • 薄板样条变换:16-point thin plate spline transformation (TPS)

其中CNN的模型与 LeNet是一样的,包含两个池化层。为了公平,所有的网络变种都只包含 3 个可学习参数的层,总体网络参数基本一致,训练策略也相同。
Spatial Transformer Networks(STN)-论文笔记_第6张图片

  • 左侧:不同的形变策略以及不同的 Spatial Transformer网络变种与 baseline的对比;
  • 右侧:一些CNN分错,但是ST-CNN分对的样本
    - (a ):输入
    - (b ):Spatial Transformer层 的 源坐标(Tθ(G) )可视化结果
    - (c ):Spatial Transformer层输出
  • 很明显:ST-CNN优于CNN, ST-FCN优于FCN,说明Spatial Transformer确实增加了 空间不变性
  • FCN中由于没有 池化层,所以FCN的空间不变性不如CNN,所以FCN效果不如CNN
  • ST-FCN效果可以达到CNN程度,说明Spatial Transformer确实增加了 空间不变性
  • ST-CNN效果优于ST-FCN,说明 池化层 确实对 增加 空间不变性很重要
  • 在 Spatial Transformer 中使用 plate spline transformation (TPS) 变换效果是最好的
  • Spatial Transformer 可以将歪的数字扭正
  • Spatial Transformer 在输入图片上确定的attention区域很明显利于后续分类层分类,可以更加有效地减少分类损失

4.2 Street View House Numbers

Street View House Numbers是一个真实的 街景门牌号 数据集,共200k张图片,每张图片包含1-5个数字 ,数字都有形变。

  • baseline character sequence CNN model :11层,5个softmax层输出对应位置的预测序列
  • STCNN Single :在输入层添加一个Spatial Transformer
  • ST-CNN Multi :前四层,每一层都添加一个Spatial Transformer 见下面 tabel 2 右侧
  • localisation networks 子网络:两层32维的全连接层
  • 使用仿射变换和双线性插值
    Spatial Transformer Networks(STN)-论文笔记_第7张图片

结果:

参考

  1. arleyzhang:基础DL模型-STN-Spatial Transformer Networks-论文笔记
  2. Spatial Transformer Networks笔记
  3. 详细解读Spatial Transformer Networks(STN)-一篇文章让你完全理解STN了
  4. Spatial Transformer Networks
  5. 论文笔记:空间变换网络(Spatial Transformer Networks)

你可能感兴趣的