How Can We Be So Dense? The Benefits of Using Highly Sparse Representations 论文复现

How Can We Be So Dense? The Benefits of Using Highly Sparse Representations

论文复现。

阅读本文前请先阅读论文,本文简化了论文提供的代码,并对论文结果进行了复现和测试。

 

论文描述了使用稀疏表征的网络对于噪音的鲁棒性。

具体而言,论文使用一个具有一个卷积层+两个全连接层的网络在原始mnist上进行训练,然后在添加了噪音的测试集上进行测试。当使用稀疏连接的卷积层和全连接层时,相比于原始网络,稀疏网络表现出强大的对噪音的鲁棒性。

 

K_winner

K_winner是一个激活函数,用于替换relu。relu 是把大于固定阈值(一般是0)的输入原样输出;K_winner 则是将输入排序,只输出其中最大的k个,其余置0。换种方式看,它相当于设定了一个动态的阈值,使得恰好有k个输入大于该阈值而得以输出。通过这种方式,可以控制网络输出的稀疏程度:在本例中,卷积层经池化后的4320(30×12×12)个激活值中仅有100个可以输出。

其反向传播的实现机制也和relu相同:对于被选中的k个输出,梯度原样回传(导数为1);其余位置的梯度为0。

 

import torch

class k_winners(torch.autograd.Function):
    """
    K-winners-take-all autograd function for 2-D inputs indexed as
    (batch, units).

    For every row of the input only the ``k`` units with the largest
    (optionally boosted) values are kept; all other units are set to zero.

    .. seealso::
         Function :class:`k_winners2d`
    """

    @staticmethod
    def forward(ctx, x, dutyCycles, k, boostStrength):
        """
        Select the top ``k`` units of each row of ``x``.

        :param ctx: autograd context; the winning indices are saved on it
        :param x: input tensor indexed as (batch, units)
        :param dutyCycles: per-unit duty cycles; only read when
                           ``boostStrength > 0`` (must then broadcast against
                           ``x``)
        :param k: number of units kept per row
        :param boostStrength: boosting is applied only when this is > 0
        :return: tensor shaped like ``x`` holding the original (un-boosted)
                 values of the winners and zeros everywhere else
        """
        if boostStrength > 0.0:
            targetDensity = float(k) / x.size(1)
            boostFactors = torch.exp((targetDensity - dutyCycles) * boostStrength)
            boosted = x.detach() * boostFactors
        else:
            boosted = x.detach()

        # Winners are chosen on the boosted values, but the output carries the
        # un-boosted values of x at the winning positions.
        _, indices = boosted.topk(k, dim=1, sorted=False)
        res = torch.zeros_like(x)
        res.scatter_(1, indices, x.gather(1, indices))

        ctx.save_for_backward(indices)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        """
        Pass the incoming gradient through unchanged for the winning units and
        block it (gradient 0) for all other units.

        :return: gradient w.r.t. ``x`` followed by ``None`` for the three
                 non-differentiable arguments of :meth:`forward`
        """
        indices, = ctx.saved_tensors
        grad_x = torch.zeros_like(grad_output)
        grad_x.scatter_(1, indices, grad_output.gather(1, indices))

        return grad_x, None, None, None


class k_winners2d(torch.autograd.Function):
    """
    A K-winner take all autograd function for CNN 2D inputs (batch, Channel, H, W).

    The ``k`` winners are selected per sample across all channels and spatial
    positions together.

    .. seealso::
         Function :class:`k_winners`
    """

    @staticmethod
    def forward(ctx, x, dutyCycles, k, boostStrength):
        """
        Keep the ``k`` largest (optionally boosted) activations of each sample.

        :param ctx: autograd context; the winning flat indices are saved on it
        :param x: input tensor indexed as (batch, channel, height, width)
        :param dutyCycles: duty cycles; only read when ``boostStrength > 0``
                           (must then broadcast against ``x``)
        :param k: number of activations kept per sample
        :param boostStrength: boosting is applied only when this is > 0
        :return: tensor shaped like ``x`` holding the original (un-boosted)
                 values of the winners and zeros everywhere else
        """
        batchSize = x.shape[0]
        if boostStrength > 0.0:
            targetDensity = float(k) / (x.shape[1] * x.shape[2] * x.shape[3])
            boostFactors = torch.exp((targetDensity - dutyCycles) * boostStrength)
            boosted = x.detach() * boostFactors
        else:
            boosted = x.detach()

        # Flatten every sample so that the top-k search runs over all of its
        # activations at once; winners are chosen on the boosted values while
        # the output carries the un-boosted values of x.
        boosted = boosted.reshape((batchSize, -1))
        xr = x.reshape((batchSize, -1))
        _, indices = boosted.topk(k, dim=1, sorted=False)
        res = torch.zeros_like(boosted)
        res.scatter_(1, indices, xr.gather(1, indices))
        res = res.reshape(x.shape)

        ctx.save_for_backward(indices)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        """
        Pass the incoming gradient through unchanged for the winning units and
        block it (gradient 0) for all other units.

        :return: gradient w.r.t. ``x`` followed by ``None`` for the three
                 non-differentiable arguments of :meth:`forward`
        """
        batchSize = grad_output.shape[0]
        indices, = ctx.saved_tensors

        g = grad_output.reshape((batchSize, -1))
        grad_x = torch.zeros_like(g)
        grad_x.scatter_(1, indices, g.gather(1, indices))
        grad_x = grad_x.reshape(grad_output.shape)

        return grad_x, None, None, None

 

 

稀疏卷积和全连接层

SparseWeightNet

 

方法很简单,为了减少网络的连接,我们关闭输入到输出之间的部分连接,这个比例是50%。具体来说将这些关闭的连接在每次正向传播时的权重置0.

这与dropout有本质区别,dropout切断连接是随机的,每次都不同。SparseWeightNet

是真正的稀疏网络:被关闭的连接在初始化时随机选定,此后自始至终保持不变。

论文和测试结果显示稀疏的卷积层并没有太影响性能。

 

 

import abc
import math
import numpy as np
import torch
import torch.nn as nn

def rezeroWeights(m):
  """
  Re-apply the zero mask of ``m`` if it is a sparse-weights wrapper that is
  currently in training mode.

  Any other module, and any wrapper in evaluation mode, is left untouched.

  :param m: the module to inspect
  """
  if not isinstance(m, SparseWeightsBase):
    return
  if not m.training:
    return
  m.rezeroWeights()

def normalizeSparseWeights(m):
  """
  Re-initialise the weight and bias of a sparse-weights wrapper with uniform
  bounds computed from the number of inputs per unit that stay connected.

  The bounds follow the Kaiming-uniform recipe (leaky-ReLU gain with negative
  slope ``sqrt(5)``), but the fan-in is scaled by ``m.weightSparsity``, the
  fraction of weights per unit that is not zeroed.

  True division is required here: with floor division the standard deviation
  and both bounds collapse to 0.0 and every weight is initialised to zero.

  Modules that are not :class:`SparseWeightsBase` instances are ignored.

  :param m: the module to (possibly) re-initialise
  """
  if not isinstance(m, SparseWeightsBase):
    return

  weight = m.module.weight
  # Inputs feeding one output unit: the product of all weight dimensions
  # except the first, so linear and convolutional weights are both handled.
  inputSize = weight.numel() // weight.shape[0]
  # Clamp to 1 so that a very small, very sparse layer cannot divide by zero.
  fan = max(1, int(inputSize * m.weightSparsity))
  gain = nn.init.calculate_gain('leaky_relu', math.sqrt(5))
  std = gain / math.sqrt(fan)
  bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
  nn.init.uniform_(weight, -bound, bound)
  if m.module.bias is not None:
    bound = 1.0 / math.sqrt(fan)
    nn.init.uniform_(m.module.bias, -bound, bound)



# An abstract base class, so that the weight-sparsity wrapper can be
# specialised for different kinds of layers.
class SparseWeightsBase(nn.Module):
  """
  Wraps a module and keeps a fixed subset of its weights at zero.

  The positions to zero are computed once at construction time by
  :meth:`computeIndices`, stored in the ``zeroWts`` buffer, and re-applied by
  :meth:`rezeroWeights` before every forward pass made in training mode.

  NOTE: ``__metaclass__`` is the Python 2 spelling. Python 3 ignores this
  attribute, so there the abstract methods are guarded only by the
  ``NotImplementedError`` they raise.

  :param module: the module whose weights are made sparse
  :param weightSparsity: fraction of weights per unit that stay non-zero;
                         must lie strictly between 0 and 1
  """
  __metaclass__ = abc.ABCMeta

  def __init__(self, module, weightSparsity):
    super(SparseWeightsBase, self).__init__()
    assert 0 < weightSparsity < 1
    self.module = module
    self.weightSparsity = weightSparsity
    # Registered as a buffer so the mask is saved with the module state.
    self.register_buffer("zeroWts", self.computeIndices())
    self.rezeroWeights()

  def forward(self, x):
    """
    Re-apply the zero mask (training mode only) and run the wrapped module.
    """
    if self.training:
      self.rezeroWeights()
    # Call the module itself rather than its ``forward`` method so that hooks
    # registered on the wrapped module still run.
    return self.module(x)

  @abc.abstractmethod
  def computeIndices(self):
    """
    For each unit, decide which weights are going to be zero

    :return: tensor indices of all weights to be zeroed.
             See :meth:`rezeroWeights`
    """
    raise NotImplementedError

  @abc.abstractmethod
  def rezeroWeights(self):
    """
    Set the previously selected weights to zero. See :meth:`computeIndices`
    """
    raise NotImplementedError



class SparseWeights(SparseWeightsBase):
  """
  Sparse-weights wrapper for modules whose weight is a 2-D
  (outputSize, inputSize) matrix, such as ``nn.Linear``.
  """

  def __init__(self, module, weightSparsity):
    """
    Example::

      model = nn.Linear(784, 10)
      model = SparseWeights(model, 0.4)

    :param module: the module to wrap; its ``weight`` must be 2-D
    :param weightSparsity: fraction of weights per output unit that stay
                           non-zero
    """
    super(SparseWeights, self).__init__(module, weightSparsity)

  def computeIndices(self):
    """
    For each output unit, randomly pick the input positions to be zeroed.

    :return: int64 tensor of shape (2, outputSize * numZeros); row 0 holds
             the output-unit index and row 1 the input index of every weight
             that must be zero
    """
    outputSize, inputSize = self.module.weight.shape
    numZeros = int(round((1.0 - self.weightSparsity) * inputSize))

    # One independent random permutation of the inputs per output unit; its
    # first numZeros entries are the inputs disconnected from that unit.
    inputIndices = np.array([np.random.permutation(inputSize)[:numZeros]
                             for _ in range(outputSize)], dtype=np.int64)

    # Pair every selected input index with the output unit it belongs to.
    outputIndices = np.repeat(np.arange(outputSize, dtype=np.int64), numZeros)
    zeroIndices = np.stack([outputIndices, inputIndices.reshape(-1)])
    return torch.from_numpy(zeroIndices)

  def rezeroWeights(self):
    """
    Set the weights selected by :meth:`computeIndices` to zero, in place.
    """
    zeroIdx = (self.zeroWts[0], self.zeroWts[1])
    self.module.weight.data[zeroIdx] = 0.0


class SparseWeights2d(SparseWeightsBase):
  """
  Sparse-weights wrapper for 2-D convolution modules.

  Each output channel is treated as one unit whose inputs are all the weights
  of its filter, flattened.
  """

  def __init__(self, module, weightSparsity):
    """
    :param module: the convolution module to wrap
    :param weightSparsity: fraction of weights per output channel that stay
                           non-zero
    """
    super(SparseWeights2d, self).__init__(module, weightSparsity)

  def computeIndices(self):
    """
    For each output channel, randomly pick the filter weights to be zeroed.

    :return: int64 tensor of shape (2, outChannels * numZeros); row 0 holds
             the output-channel index and row 1 the index into the flattened
             filter of every weight that must be zero
    """
    outChannels = self.module.out_channels

    # Number of weights in one flattened filter. Taken from the weight tensor
    # itself so that it always matches the view used in rezeroWeights; for an
    # ungrouped convolution this equals in_channels * kernel_h * kernel_w.
    inputSize = self.module.weight.numel() // outChannels
    numZeros = int(round((1.0 - self.weightSparsity) * inputSize))

    # One independent random permutation per output channel; its first
    # numZeros entries are the filter weights to disconnect.
    inputIndices = np.array([np.random.permutation(inputSize)[:numZeros]
                             for _ in range(outChannels)], dtype=np.int64)

    # Pair every selected weight index with the channel it belongs to.
    outputIndices = np.repeat(np.arange(outChannels, dtype=np.int64), numZeros)
    zeroIndices = np.stack([outputIndices, inputIndices.reshape(-1)])
    return torch.from_numpy(zeroIndices)

  def rezeroWeights(self):
    """
    Set the weights selected by :meth:`computeIndices` to zero, in place.
    """
    zeroIdx = (self.zeroWts[0], self.zeroWts[1])
    # View the weights as (out_channels, flattened filter) so that the stored
    # 2-D indices address them directly.
    self.module.weight.data.view(self.module.out_channels, -1)[zeroIdx] = 0.0

 

给数据添加噪音


from __future__ import print_function
import numpy as np

class RandomNoise(object):
  """
  Image transform that overwrites a random subset of the values of a tensor
  with a constant "white" value.

  The tensor is modified in place and the same object is returned.

  :param noiselevel: fraction of the tensor's elements to overwrite
  :param whiteValue: value written to the selected elements
  """

  def __init__(self, noiselevel=0.0, whiteValue=0.1307 + 2*0.3081):
    self.iteration = 0
    self.whiteValue = whiteValue
    self.noiseLevel = noiselevel

  def __call__(self, image):
    """
    Corrupt ``image`` in place and return it.

    :param image: tensor to corrupt; it is flattened with ``view(-1)``
    :return: the same ``image`` object
    """
    # Counts how many images have been processed by this transform.
    self.iteration += 1

    # A flat view shares storage with image, so writes show up in image.
    pixels = image.view(-1)
    total = pixels.shape[0]
    count = int(total * self.noiseLevel)

    # Distinct random positions: the first `count` entries of a permutation.
    chosen = np.random.permutation(total)[:count]
    pixels[chosen] = self.whiteValue
    return image

 

 

我们使用以下代码进行测试。

 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch_mnist.k_winners import k_winners2d
from torch_mnist import sparse_weights
import skimage
from torch_mnist.image_transforms import RandomNoise

# Mini-batch size used by the data loaders.
batch_size = 64

# Fractions of noise applied to the test images, stored as strings.
NOISE_VALUES = [
    "0.0", "0.05", "0.1", "0.15", "0.2", "0.25",
    "0.3", "0.35", "0.4", "0.45", "0.5", "0.55",
    "0.6", "0.65", "0.7", "0.75", "0.8",
]

# Training images are converted to tensors and normalised; no noise is added.
transform_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(
    root='./data/',
    train=True,
    transform=transform_,
    download=True,
)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

class SparseNet(nn.Module):
    """
    One convolutional layer followed by two fully connected layers, with
    optional sparse weights and an optional k-winners activation.

    The first linear layer expects 4320 inputs, i.e. 30 channels of 12 x 12
    values, which is what a 1 x 28 x 28 input produces after the 5 x 5
    convolution and the 2 x 2 max-pooling.

    :param sparseCNN: wrap the convolution in ``SparseWeights2d`` (50% of the
                      weights kept)
    :param k_winner: use ``k_winners2d`` instead of ReLU after pooling
    :param sparseLinear: wrap the first linear layer in ``SparseWeights`` (50%
                         of the weights kept)
    :param k: number of activations per sample kept by the k-winners step
    """

    def __init__(self, sparseCNN=False, k_winner=False, sparseLinear=False,
                 k=100):
        super(SparseNet, self).__init__()
        self.sparseCNN = sparseCNN
        self.k_winner = k_winner
        self.sparseLinear = sparseLinear
        self.k = k
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=30, kernel_size=5)
        if self.sparseCNN:
            self.conv1 = sparse_weights.SparseWeights2d(self.conv1, weightSparsity=0.5)
        self.mp = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(4320, 300)
        if self.sparseLinear:
            self.fc1 = sparse_weights.SparseWeights(self.fc1, weightSparsity=0.5)
        self.fc2 = nn.Linear(300, 10)

    def forward(self, x):
        """
        :param x: batch of images; the layer sizes assume (batch, 1, 28, 28)
        :return: log-probabilities of shape (batch, 10)
        """
        in_size = x.size(0)
        x = self.mp(self.conv1(x))
        if self.k_winner:
            # Boosting is disabled (strength 0), so the duty-cycle argument is
            # never read and None is passed for it.
            x = k_winners2d.apply(x, None, self.k, 0.0)
        else:
            x = F.relu(x)
        x = x.view(in_size, -1)  # flatten the tensor (equivalent to a reshape)
        # NOTE(review): there is no non-linearity between fc1 and fc2; kept
        # as in the original listing.
        x = self.fc1(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)  # (batch, 10)

# The network under test: sparse convolution, k-winners activation and a
# sparse first linear layer are all switched on.
model = SparseNet(
    sparseCNN=True,
    k_winner=True,
    sparseLinear=True,
)

# Plain SGD with momentum over all parameters of the model.
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.5,
)

def train(epoch):
    """
    Train the module-level ``model`` for one pass over ``train_loader``.

    The loss is printed every 100 batches.

    :param epoch: epoch number, used only in the progress message
    """
    # Make the training mode explicit: the sparse-weight wrappers only
    # re-apply their zero mask while the model is in training mode.
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()   # clear the gradients of all parameters
        output = model(data)
        loss = F.nll_loss(output, target)
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

        loss.backward()         # back-propagate to compute the gradients
        optimizer.step()        # gradient-descent update of the parameters

def test(noise_idx):
    """
    Evaluate the module-level ``model`` on the MNIST test set with noise added.

    The noise fraction is ``float(NOISE_VALUES[noise_idx])``. The accuracy in
    percent is printed.

    :param noise_idx: index into ``NOISE_VALUES``
    :return: tuple ``(accuracy in percent, mean loss per sample)``
    """
    noise = float(NOISE_VALUES[noise_idx])
    # Same preprocessing as for training, followed by the noise transform.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        RandomNoise(noise, whiteValue=0.1307 + 2 * 0.3081),
    ])
    test_dataset = datasets.MNIST(root='./data/',
                                  train=False,
                                  transform=transform)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False)
    test_loss = 0.0
    correct = 0
    # NOTE: the model is deliberately not switched to evaluation mode. The
    # sparse-weight wrappers re-apply their zero mask only in training mode,
    # and the last optimizer step may have moved masked weights away from 0.
    with torch.no_grad():  # no gradients are needed for evaluation
        for data, target in test_loader:
            output = model(data)
            # The model already returns log-probabilities, so sum up the
            # batch loss with the negative log-likelihood.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Index of the largest log-probability is the predicted class.
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    accuracy = 100.0 * correct / num_samples
    print(accuracy)
    return accuracy, test_loss


# Train for one epoch, then evaluate once for every configured noise level.
for epoch in range(1):
    train(epoch)
    for i in range(len(NOISE_VALUES)):
        test(i)

 

 

我们对测试集添加0-80%不等的噪音,获得了以下结果。

稀疏网络显示出对噪音惊人的容忍程度。

 

 

 

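| 噪音 | denseNet | sparseNet |
| --- | --- | --- |
| 0 | 96.84 | 95.38 |
| 0.05 | 96.55 | 95.13 |
| 0.1 | 96.2 | 94.55 |
| 0.15 | 95.56 | 94.02 |
| 0.2 | 94.9 | 93.84 |
| 0.25 | 92.64 | 92.72 |
| 0.3 | 89.22 | 92.25 |
| 0.35 | 83.86 | 91.06 |
| 0.4 | 77.36 | 90.04 |
| 0.45 | 69.36 | 88.89 |
| 0.5 | 59.37 | 87.58 |
| 0.55 | 48.66 | 85.35 |
| 0.6 | 38.23 | 82.98 |
| 0.65 | 28.23 | 80.12 |
| 0.7 | 21.49 | 75.18 |
| 0.75 | 15.63 | 69.1 |
| 0.8 | 12.19 | 60.07 |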

 

 

你可能感兴趣的