您正在使用IE低版浏览器,为了您的雷峰网账号安全和更好的产品体验,强烈建议使用更快更安全的浏览器
此为临时链接,仅用于文章预览,将在时失效
人工智能开发者 正文
发私信给汪思颖
发送

0

从一次 CycleGAN 实现聊聊 TF

本文作者:汪思颖 2017-11-20 19:16
导语:用 TensorFlow 实现 CycleGAN 时需要注意的小技巧

雷锋网 AI科技评论按,本文作者Coldwings,该文首发于知乎专栏为爱写程序,雷锋网 AI科技评论获其授权转载。以下为原文内容,有删减。

CycleGAN是个很有趣的想法(Unpaired Image-to-Image Translationusing Cycle-Consistent Adversarial Networks [https://arxiv.org/pdf/1703.10593.pdf]),看完这篇论文之后,隐隐地觉得,这后面有更多的内容可以挖,我尽我所能做出了各种尝试,努力发掘更多的可能性。

实现过程中可以说还是略微纠结的,最初是用Keras快速实践了一下,然而其实并不『快速』,后来反倒是用TensorFlow重写以及尝试各种意外想法时才感觉,当需要处理一些比较复杂的网络结构、训练流程甚至op时,TF提供的可以细化到每个操作的体验实际上要比各种上层API都来得更好,而结合TensorBoard,可视化的训练将取得更好的效果。当然,我对Torch无感,或许用Torch能有更好的体验,但我不擅长这个;Chainer(A flexible framework for neural networks)讲道理写出来的代码会更好看,但是似乎身边用的人并不多,姑且放过。

这篇文章倒不是来介绍什么是CycleGAN的,若是不甚了解,我妻子将会将她的发表整理一下再发布出来(CycleGAN(以及DiscoGAN和DualGAN)简介 - 知乎专栏)。这一阵的尝试中,我自己也对GAN,对Generator中的图像甚至其它东西的生成,以及单纯从写代码角度来看,怎么管理TF里的变量,怎么把代码写得好看,怎么更好地利用TensorBoard都有了更多地理解,算是不小的提高吧……

所以这里也就大概提一提一些实现中需要注意的小技巧吧。(虽然我觉得大概大多数真正拿着TF搞DL研究的人都不需要研究这篇文章)

CycleGAN比较麻烦的地方

其实CycleGAN麻烦的地方不少,这是一个挺复合的模型:两个Generator,两个Discriminator,这已经是四个比较简单的网络了(是的,考虑到所有可能性,Generator和Discriminator完全可以各自都有两种不同的结构);一组Generator+Discriminator复合成一个GAN,又一层复合模型,并且GAN的训练还得控制,由于G和D的损失相反,训练G时需要控制D的变量让其不可训练;我们还要让Cycle loss作为模型loss的一部分,这个更高一层的复合模型由两个GAN组成……

良好的代码结构

TensorFlow的自由度挺高的,类比的话,有那么点DL框架里的C++的意思;Python的语言灵活度也是高得不行,两个很灵活的玩意放一起,写个简单模型自然想怎么玩就怎么玩,写个复杂一些的模型,为了保证写着方便,用着方便,改起来方便,还是需要比较好的代码结构的。

如果翻翻GitHub上一些比较热的用TF写的模型,通常都会发现大家比较习惯于把代码分成op、module和model三个部分。

op里是一些通用层或者运算的简化定义,例如写个卷积层,总是包含定义变量和定义运算。习惯于Keras这样不需要自己定义变量的玩意当然不会太纠结,但用TF时,若是写两行定义一下变量总是挺让人伤神的。

如果参照Keras的实现,通过写个类来定制op,变量管理看起来方便一点,未免太过繁琐。实际上TF提供的variable scope已经非常方便了,这一部分写成这样似乎也不错

def conv2d(input, filter, kernel, strides=1, stddev=0.02, name='conv2d'):
   with tf.variable_scope(name):
       w = tf.get_variable(
           'w',
           (kernel, kernel, input.get_shape()[-1], filter),
           initializer=tf.truncated_normal_initializer(stddev=stddev)
       )
       conv = tf.nn.conv2d(input, w, strides=[1, strides, strides, 1], padding='VALID')
       b = tf.get_variable(
           'b',
           [filter],
           initializer=tf.constant_initializer(0.0)
       )
       conv = tf.reshape(tf.nn.bias_add(conv, b), tf.shape(conv))
       return conv

这样定义几个op之后,写起代码来就更有点类似于mxnet那样的感觉了。

特别的,有些时候有些简单结构,例如ResNet中的一个block这样的玩意,我们也可以用类似的方式,用一个简单函数包装起来

def res_block(x, dim, name='res_block'):
   with tf.variable_scope(name):
       y = reflect_pad(x, name='rp1')
       y = conv2d(y, dim, 3, name='conv1')
       y = lrelu(y)
       y = reflect_pad(y, name='rp2')
       y = conv2d(y, dim, 3, name='conv2')
       y = lrelu(y)
       return tf.add(x, y)

对于重复的模块,这样的包装也方便多次使用。

这些是很常见的做法。同时我们也发现了,几乎每个这样的函数里都少不了一个variable scope的使用,一方面避免定义变量时名字的重复以及训练时变量的管理,另一方面也方便TensorBoard画图的时候能把有用的东西放到一起。但这样每个函数里带个name参数的做法写多了也会烦,加上奇怪的缩进……我会更倾向于用一个装饰器来解决这样的问题,同时也能减少『忘了用variable scope』的情况。

def scope(default_name):
   def deco(fn):
       def wrapper(*args, **kwargs):
           if 'name' in kwargs:
               name = kwargs['name']
               kwargs.pop('name')
           else:
               name = default_name
           with tf.variable_scope(name):
               return fn(*args, **kwargs)
       return wrapper
   return deco@scope('conv2d')def conv2d(input, filter, kernel, strides=1, stddev=0.02):
   w = tf.get_variable(
       'w',
       (kernel, kernel, input.get_shape()[-1], filter),
       initializer=tf.truncated_normal_initializer(stddev=stddev)
   )
   conv = tf.nn.conv2d(input, w, strides=[1, strides, strides, 1], padding='VALID')
   b = tf.get_variable(
       'b',
       [filter],
       initializer=tf.constant_initializer(0.0)
   )
   conv = tf.reshape(tf.nn.bias_add(conv, b), tf.shape(conv))
   return conv

至于module,也就是一些稍微复杂的成型结构,例如GAN里的Discriminator和Generator,讲道理这玩意其实和op大体上是类似的,就不多说了。

最后是model。通常大家都是用类来做,因为model中往往还包含了输入数据用的placeholder、训练用的op,甚至一些具体的方法等等内容。这一块的代码建议,只不过是最好先写一个抽象类,把需要的几个接口给定义一下,然后让实际的model类继承,代码会漂亮很多,也更便于利用诸如PyCharm这样的IDE来提示你哪些东西该做而没有做。

关于config/options

网上常见的代码里,模型的一些参数信息大都设计成用命令行参数来传入,更多是直接使用tf.flags来处理。但无论如何,我仍然觉得定义一个config类来管理参数是有一定必要性的,直接使用tf.flags主要是是有大段tf.flags.DEFINE_xxx,不好看,也不方便直观地反应默认参数。相对的,如果定义一个参数类,在__init__里写下默认参数,然后写个小方法自动地根据dir来添加这些tf.flags会漂亮许多。但这个只是个人观点,似乎并没有具体的优劣之分。

关于TensorBoard

不得不说TensorBoard作为TF自带的配套可视化工具,只要你不是太在意刷新频率的问题(通常不会有人在意这个吧……),用起来实在太方便。加上能够自动生成运算的各个符号的结构图,哪怕不说训练,就是检查模型结构是否符合自己所想都是个非常好用的工具。比如封面图,生成出来用来检查代码的模型逻辑,还可以根据需要点选观察依赖关系。

从一次 CycleGAN 实现聊聊 TF

顺带一提,如果生成的模型图长得非常奇怪,八成是代码有问题……

不过要用好TensorBoard,有几个小小的要点:首先是,至少,你的各个op和module里,得用上variable scope或者name scope。对于一个scope,在TensorBoard的Graph里会将其聚集成一个小块,内部结构可以展开观察,而如果不用scope,你会看到满眼都是一堆一堆的基本op,当模型复杂时,图基本没法看……

此外,对于图片处理,用好TensorBoard的ImageSummary当然是很不错的选择。但是记得一定要为添加图片的summary op定义一个喂数据的placeholder。

self.p_img = tf.placeholder(tf.float32, shape=[1, 256 * 6, 256 * 4, 3])

self.img_op = tf.summary.image('sample', self.p_img)

……

img = np.array([img])

s_img = self.sess.run(self.img_op, feed_dict={self.p_img: img})

self.writer.add_summary(s_img, count)

这样才是正确的。网上有些材料里告诉你可以直接用tf.summary.image('tag', data)来生成图片summary,这样其实每次都会构造一个新的summary,不便于图片归类,但更大的问题是,这样做会使得每次都申请一个新的变量(用来装你的图片数据),倘若你有定周期存储训练权重的习惯,会发现没几个小时就会因为权重变量总量超过2GB而使得程序跑崩……想想看晚上跑着训练的代码想着可以回家休息了,结果前脚刚进家门,程序就罢工了,大好的训练时间就给直接浪费了。

另外,这里的图片可以是重新归为0~255的整形的数据,也可以直接给浮点数据[-1, 1]。更不错的想法是,先使用matplotlib/pil/numpy来合成、拼凑甚至生成图像,然后再来添加,会让效果更令人满意,比如这样:

从一次 CycleGAN 实现聊聊 TF

最后补充一句……双显示器确实有利于提高写代码、改代码以及码字的效率……

雷峰网版权文章,未经授权禁止转载。详情见转载须知

从一次 CycleGAN 实现聊聊 TF

分享:
相关文章

编辑

关注AI学术,例如论文
当月热门文章
最新文章
请填写申请人资料
姓名
电话
邮箱
微信号
作品链接
个人简介
为了您的账户安全,请验证邮箱
您的邮箱还未验证,完成可获20积分哟!
请验证您的邮箱
立即验证
完善账号信息
您的账号已经绑定,现在您可以设置密码以方便用邮箱登录
立即设置 以后再说