首页 GFPGAN源码分析-第六篇

GFPGAN源码分析-第六篇

举报
开通vip

GFPGAN源码分析-第六篇     GFPGAN源码分析第六篇          2021SC@SDUSC源码:archs\gfpganv1_clean_arch.py本篇主要分析gfpganv1_clean_arch.py下的classGFPGANv1Clean(nn.Module)类_init_()方法目录classGFPGANv1Clean(nn.Module)init()(1)channels的设置(2)调用torch.nn.Conv2d()创建了一层卷积神经网络(3)下采样(downsample)(4)上采样(upsample)(...

GFPGAN源码分析-第六篇
     GFPGAN源码分析第六篇          2021SC@SDUSC源码:archs\gfpganv1_clean_arch.py本篇主要分析gfpganv1_clean_arch.py下的classGFPGANv1Clean(nn.Module)类_init_()方法目录classGFPGANv1Clean(nn.Module)init()(1)channels的设置(2)调用torch.nn.Conv2d()创建了一层卷积神经网络(3)下采样(downsample)(4)上采样(upsample)(5)全连接层(6)创建self.stylegan_decoder(7)如果decoder_load_path不为空则读取(8)forSFT(SFTlayer)classGFPGANv1Clean(nn.Module)        继承自nn.Module类,使得我们可以使用很多现成的类,比如本类中使用的Conv2d以及RelU激活函数等等。init()参数:self,out_size,num_style_feat=512,channel_multiplier=1,decoder_load_path=None,fix_decoder=True,#forstylegandecodernum_mlp=8,input_is_latent=False,different_w=False,narrow=1,sft_half=False在classGFPGANer()-init()中被调用时:self.gfpgan=GFPGANv1Clean(out_size=512,num_style_feat=512,channel_multiplier=channel_multiplier,decoder_load_path=None,fix_decoder=False,num_mlp=8,input_is_latent=True,different_w=True,narrow=1,sft_half=True)(1)channels的设置实际调用的时候narrow=1,channels保存了经过convolution层后的输出的通道数unet_narrow=narrow*0.5channels={'4':int(512*unet_narrow),'8':int(512*unet_narrow),'16':int(512*unet_narrow),'32':int(512*unet_narrow),'64':int(256*channel_multiplier*unet_narrow),'128':int(128*channel_multiplier*unet_narrow),'256':int(64*channel_multiplier*unet_narrow),'512':int(32*channel_multiplier*unet_narrow),'1024':int(16*channel_multiplier*unet_narrow)}(2)调用torch.nn.Conv2d()搭建卷积神经网络#out_size=512,solog_size=9self.log_size=int(math.log(out_size,2))#first_out_size=512first_out_size=2**(int(math.log(out_size,2)))#channels['512']=32*2*0.5=32self.conv_body_first=nn.Conv2d(3,channels[f'{first_out_size}'],1)在这里介绍一下nn.Conv2d()的几个参数in_channels:int,#输入的通道数目【必选】out_channels:int,#输出的通道数目【必选】kernel_size:_size_2_t,#卷积核的大小,类型为int(方形边长)或者元组(长和宽)【必选】stride:_size_2_t=1,#步长padding:Union[str,_size_2_t]=0,#边界增益,可以控制输出结果的尺寸dilation:_size_2_t=1,#控制卷积核之间的间距groups:int=1,bias:bool=True,padding_mode:str='zeros',#TODO:refinethistypedevice=None,dtype=None那么可以得知self.conv_body_first=nn.Conv2d(3,channels[f'{first_out_size}'],1)#实际上是传入通道为3(RGB)的输入,使用边长为1的卷积核,最后获得通道为32的输出#由于卷积核边长为1,我们输入与输入的图片大小仍然保持一致,但增加了通道数(3)下采样(downsample)可以看到实际上是调用ResBlock做了下采样#输入图片的通道数(实际为32)in_channels=channels[f'{first_out_size}']#创建ModuleList容器self.conv_body_down=nn.ModuleList()#i从self.log_size(9)->3:7次循环foriinrange(self.log_size,2,-1):out_channels=channels[f'{2**(i-1)}']#调用ResBlock残差网络做下采样,并将该module添加到设置的ModuleListself.conv_body_down.append(ResBlock(in_channels,out_channels,mode='down'))#这一层的输出管道数作为下一层输入的管道数in_channels=out_channels介绍一下nn.ModuleList()nn.ModuleList,它是一个储存不同module,并自动将每个module的parameters添加到网络之中的容器。你可以把任意nn.Module的子类(比如nn.Conv2d,nn.Linear之类的)加到这个list里面,方法和Python自带的list一样,无非是extend,append等操作。但不同于一般的list,加入到nn.ModuleList里面的module是会自动注册到整个网络上的,同时module的parameters也会自动添加到整个网络中。#注意nn.ModuleList则没有实现内部forward函数,所以需要手动实现最后一层卷积层的搭建:#最终输出通道数为channels['4']=256,使用边长为3的卷积核,步长为1,padding为1,保证维度不变self.final_conv=nn.Conv2d(in_channels,channels['4'],3,1,1)(4)上采样(upsample)#输入通道数为channels['4']=256,即下采样的输出的通道数in_channels=channels['4']#创建ModuleList容器self.conv_body_up=nn.ModuleList()#i从3->self.log_size(9):7次循环foriinrange(3,self.log_size+1):#定义输出的通道数out_channels=channels[f'{2**i}']#调用带有上采样ResBlock残差网络,并将该module添加到设置的ModuleListself.conv_body_up.append(ResBlock(in_channels,out_channels,mode='up'))#这一层的输出管道数作为下一层输入的管道数in_channels=out_channels(5)全连接层根据传入的参数different_w,选择每个输出样本的大小,并搭建相应的全连接层。ifdifferent_w:#16*512=8192linear_out_channel=(int(math.log(out_size,2))*2-2)*num_style_featprint(linear_out_channel)else:#512linear_out_channel=num_style_feat#全连接层sizeofeachinputsample:4096,sizeofeachoutputsample:8192self.final_linear=nn.Linear(channels['4']*4*4,linear_out_channel)(6)创建self.stylegan_decoderself.stylegan_decoder=StyleGAN2GeneratorCSFT(out_size=out_size,num_style_feat=num_style_feat,num_mlp=num_mlp,channel_multiplier=channel_multiplier,narrow=narrow,sft_half=sft_half)(7)如果decoder_load_path不为空则读取ifdecoder_load_path:self.stylegan_decoder.load_state_dict(torch.load(decoder_load_path,map_location=lambdastorage,loc:storage)['params_ema'])iffix_decoder:forname,paraminself.stylegan_decoder.named_parameters():param.requires_grad=False(8)forSFT(SFTlayer)#ModuleListself.condition_scale=nn.ModuleList()self.condition_shift=nn.ModuleList()#i从3->self.log_size(9):7次循环foriinrange(3,self.log_size+1):#定义输出的通道数out_channels=channels[f'{2**i}']#输出通道数是否减半ifsft_half:sft_out_channels=out_channelselse:sft_out_channels=out_channels*2#使用nn.Sequential搭建网络,并添加到ModuleListself.condition_scale.append(nn.Sequential(#卷积核边长为3,步长为1,输出与输出保持相同维度nn.Conv2d(out_channels,out_channels,3,1,1),nn.LeakyReLU(0.2,True),nn.Conv2d(out_channels,sft_out_channels,3,1,1)))self.condition_shift.append(nn.Sequential(nn.Conv2d(out_channels,out_channels,3,1,1),nn.LeakyReLU(0.2,True),nn.Conv2d(out_channels,sft_out_channels,3,1,1)))nn.Sequential是一个有序的容器,其中传入的是构造器类(各种用来处理input的类),最终input会被Sequential中的构造器依次执行。 -全文完-
本文档为【GFPGAN源码分析-第六篇】,请使用软件OFFICE或WPS软件打开。作品中的文字与图均可以修改和编辑, 图片更改请在作品中右键图片并更换,文字修改请直接点击文字进行修改,也可以新增和删除文档中的内容。
该文档来自用户分享,如有侵权行为请发邮件ishare@vip.sina.com联系网站客服,我们会及时删除。
[版权声明] 本站所有资料为用户分享产生,若发现您的权利被侵害,请联系客服邮件isharekefu@iask.cn,我们尽快处理。
本作品所展示的图片、画像、字体、音乐的版权可能需版权方额外授权,请谨慎使用。
网站提供的党政主题相关内容(国旗、国徽、党徽..)目的在于配合国家政策宣传,仅限个人学习分享使用,禁止用于任何广告和商用目的。
下载需要: 免费 已有0 人下载
最新资料
资料动态
专题动态
个人认证用户
资教之佳
暂无简介~
格式:doc
大小:94KB
软件:Word
页数:15
分类:互联网
上传时间:2023-06-20
浏览量:6