Pytorch从零构建风格迁移(Style Transfer)

Pytorch构建风格迁移

  • 前言
  • 风格迁移示例
  • Pytorch实战
    • 获取原始内容图片与风格图片并进行预处理
    • 搭建网络框架
    • **构建内容损失与风格损失**
    • 构建优化器进行最终训练

前言

艺术创作可以看做两个重要因素的联合,即画什么和怎么画(内容与风格)。而风格迁移(Style Transfer)在图像处理中被广泛用于风格再创作,即基于所定内容按照指定的艺术风格进行绘画。复现本文需要用到Pytorch库,可参考本人之前的博客Pytorch安装详解。

风格迁移示例

风格迁移算法由内容图片(C)与风格图片(S)构成,算法必须生成一个具有内容图片内容与风格图片风格的新图片(O)。如下图所示:

其中内容图片(C)为塞纳河畔一处风景,风格图片(S)为梵高所做星夜,将其二者结合训练即可得到具有星夜风格的塞纳河畔景色(O)。

Pytorch实战

我们要实现上述风格迁移过程,需要以下步骤:

  • 获取原始内容图片与风格图片并进行预处理
  • 搭建网络框架
  • 构建内容损失与风格损失
  • 构建优化器进行最终训练

下面基于Pytorch实现上述风格迁移示例

获取原始内容图片与风格图片并进行预处理

  1. 原始图片可根据自己需要准备风格与内容图片,用函数image_loader进行加载。
  2. 由于训练网络框架采用VGG19,需要将输入图片调整至适应VGG19的数据格式,即采用prep函数进行前处理。
  3. 在最终训练完成后需将张量格式后处理转换为PILImage格式,需要用到postb函数,其中包含postpa后处理函数。
  4. 最终将我们想得到的优化图片初始化为内容图片,requires_grad设置为True
imsize = 512
prep = transforms.Compose([transforms.Resize(imsize),
                          transforms.ToTensor(),
                          
                          
                          transforms.Lambda(lambda x:
                                            x[torch.LongTensor([2,1,0])]),
                              transforms.Normalize([0.40760392, 0.45795686,
                                                    0.48501961], [1, 1, 1]),
                              transforms.Lambda(lambda x: x.mul_(255))])

postpa = transforms.Compose([transforms.Lambda(lambda x: x.mul_(1/255)),
                             transforms.Normalize([-0.40760392, -0.45795686,
                                                    -0.48501961], [1, 1, 1]),
                             transforms.Lambda(lambda x:
                                            x[torch.LongTensor([2,1,0])]),])

postpb = transforms.Compose([transforms.ToPILImage()])

def postb(tensor):
    t = postpa(tensor)
    t[t>1] = 1
    t[t<0] = 0
    im = postpb(t)
    return im

def image_loader(image_name):
    image = Image.open(image_name)
    image = prep(image)
    image = image.unsqueeze(0)
    return image

style_image = image_loader('NightSky.jpg')
content_image = image_loader('River.jpg')

opt_image = content_image.data.clone().requires_grad_(True)

搭建网络框架

  • 由于我们采用VGG19框架,所以加载其features层即可,同时需将其参数固定,不需要优化,因此requires_grad设置为False
vgg = models.vgg19(pretrained=True).features
for param in vgg.parameters():
    param.requires_grad_(False)

构建内容损失与风格损失

  • 内容损失我们采用特定层上的均方根误差MSELoss
  • 风格损失我们采用跨多层计算的特征平面格拉姆(Gram)矩阵均方误差GramMatrix类,之后构成StyleLoss
  • 在具体实现过程中需要提取VGG19网络的特定层特征,即LayerActivations类,之后用extract_layers函数进行调用
  • 由于我们不需要更新风格与内容图片,所以提取特征后需将输出与原始图片解绑,用到detach()函数
  • 最终将风格与内容损失加入到一个列表中
class GramMatrix(nn.Module):
    def forward(self, inp):
        b, c, h, w = inp.size()
        features = inp.view(b, c, h*w)
        gram_matrix = torch.bmm(features, features.transpose(1,2))
        gram_matrix.div_(h*w)
        return gram_matrix
    
class StyleLoss(nn.Module):
    def forward(self, inputs, targets):
        out = nn.MSELoss()(GramMatrix()(inputs), targets)
        return out
    
class LayerActivations():
    features = []
    def __init__(self, model, layer_nums):
        self.hooks = []
        for layer_num in layer_nums:
            self.hooks.append(model[layer_num].register_forward_hook(self.hook_fn))
            
    def hook_fn(self, module, inp, outp):
        self.features.append(outp)
        
    def remove(self):
        for hook in self.hooks:
            hook.remove()
            
def extract_layers(layers, img, model=None):
    la = LayerActivations(model, layers)
    la.features = []
    out = model(img)
    la.remove()
    return la.features

style_layers = [1, 6, 11, 20, 25]
content_layers = [21]
loss_layers = style_layers + content_layers
style_weights = [1*10**3/n**2 for n in [64, 128, 256, 512, 512]]
conten_weights = [1]
weights = style_weights + conten_weights
content_target = extract_layers(content_layers, content_image, model = vgg)
style_target = extract_layers(style_layers, style_image, model = vgg)
content_target = [t.detach() for t in content_target]
style_target = [GramMatrix()(t).detach() for t in style_target]
target = style_target + content_target

loss_fn = [StyleLoss()]*len(style_layers) + [nn.MSELoss()] * len(content_layers)

构建优化器进行最终训练

由于我们要对优化图片进行训练,仅给该图片变量提供参数进行训练即可,因此创建优化该变量的优化器LBFGS

optimizer = torch.optim.LBFGS([opt_image])

最终将上述所构建的各个模块串联起来即构成主程序:

max_iter = 500
show_iter = 10
n_iter = [0]
while n_iter[0] <= max_iter:
    def closure():
        optimizer.zero_grad()
        out = extract_layers(loss_layers, opt_image, model = vgg)
        layer_losses = [weights[a] * loss_fn[a](A, target[a]) for a, A in enumerate(out)]
        loss = sum(layer_losses)
        loss.backward()
        n_iter[0] += 1
        if n_iter[0] % show_iter == (show_iter - 1):
            print('Iteration: ', n_iter[0],'\nLoss: ', loss.data)
        return loss
    optimizer.step(closure)


opt_img = opt_image.squeeze(0)
opt_img = postb(opt_img)
opt_img.save('new_style.jpg')

在这里迭代了500次,最终得到的结果相当不错,即本文开始时举例演示部分的结果:

你可能感兴趣的