While learning to write GAN code in PyTorch, I noticed that implementations differ slightly in the details of the training loop. Some call detach() to cut off the gradient flow. Others do not use detach() at all and instead call backward(retain_graph=True) on a loss during backpropagation. In this article I use three GAN code examples to explain what these two mechanisms do, and analyze how the different update strategies affect efficiency.
GAN implementations use two different training orders:
- Train the discriminator first, then the generator. This is the algorithm in the original paper, Generative Adversarial Networks.
- Train the generator first, then the discriminator.
There is already plenty of material about GANs online, so I won't repeat it here. Readers who want to learn more about how GANs work can refer to my dedicated article: Neural network structure: Generative Adversarial Network (GAN).
Background needed:
detach(): cuts off gradient flow at a node during backpropagation. It returns a new tensor that shares data with the original but is cut off from the computation graph and does not require gradients (requires_grad=False). When backpropagation reaches this node, the gradient stops there and does not flow to the nodes before it.
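Here is a minimal sketch of what detach() does, using a scalar example of my own. There are two paths from x to the output: one goes through the graph and one goes through a detached copy. Only the first path contributes to x.grad.

```python
import torch

# A leaf tensor we want gradients for.
x = torch.tensor(3.0, requires_grad=True)

y = x * 2            # y depends on x: dy/dx = 2
y_cut = y.detach()   # same value as y, but cut off from the graph

print(y_cut.requires_grad)  # False: the detached tensor needs no gradient

# z uses y twice: once through the graph, once through the detached copy.
z = y + y_cut * 5
z.backward()

# Only the non-detached path contributes: dz/dx = 2, not 2 + 5 * 2 = 12.
print(x.grad)  # tensor(2.)
```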
Update strategies
Let's go straight to the topic of this article: in PyTorch, what are detach and retain_graph for? I will use three GAN implementations as examples to show how they work.
Train the discriminator first, then the generator
Strategy one
Let's analyze the code for one step of the training loop:
# Inside the training loop; generator, discriminator, adversarial_loss, the optimizers and device are defined earlier
valid = torch.ones(imgs.size(0), 1, device=device)   # real labels, all 1
fake = torch.zeros(imgs.size(0), 1, device=device)   # fake labels, all 0

# ########################
# Train the discriminator
# ########################
real_imgs = imgs.to(device)                          # real images
z = torch.randn(imgs.shape[0], 100, device=device)   # noise
gen_imgs = generator(z)                              # generate fake data from noise
pred_gen = discriminator(gen_imgs)                   # discriminator output on fake data
pred_real = discriminator(real_imgs)                 # discriminator output on real data

optimizer_D.zero_grad()                              # zero the gradients of all discriminator parameters
real_loss = adversarial_loss(pred_real, valid)       # discriminator loss on real samples
fake_loss = adversarial_loss(pred_gen, fake)         # discriminator loss on fake samples
d_loss = (real_loss + fake_loss) / 2                 # average of the two losses
# The next line is the focus of this article:
# retain_graph=True is essential, otherwise the graph's buffers are freed after this backward pass
d_loss.backward(retain_graph=True)
optimizer_D.step()                                   # update the discriminator parameters

# ########################
# Train the generator
# ########################
g_loss = adversarial_loss(pred_gen, valid)           # generator loss
optimizer_G.zero_grad()                              # zero the generator gradients
g_loss.backward()                                    # backpropagate the generator loss (through D into G)
optimizer_G.step()                                   # update the generator parameters
Code explanation
The discriminator loss d_loss is made up of real_loss and fake_loss, and fake_loss is computed from noise that has passed through the generator. So when d_loss is backpropagated, it produces gradients not only for the discriminator but also for the generator, even though optimizer_D.step() only updates the discriminator's parameters. That is why, before updating the generator's parameters, we first zero the generator's gradients (optimizer_G.zero_grad()). This prevents the gradients that flowed back from the discriminator loss from affecting the generator update.
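To see this leak concretely, here is a small sketch with hypothetical toy models: a linear "generator" and a linear-plus-sigmoid "discriminator", chosen only for illustration. A discriminator loss built from non-detached fake data writes gradients into the generator's parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy networks standing in for the real generator/discriminator.
generator = nn.Linear(4, 2)
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()

z = torch.randn(8, 4)
real_imgs = torch.randn(8, 2)
valid = torch.ones(8, 1)
fake = torch.zeros(8, 1)

gen_imgs = generator(z)  # NOT detached, as in strategy one
d_loss = (adversarial_loss(discriminator(real_imgs), valid)
          + adversarial_loss(discriminator(gen_imgs), fake)) / 2
d_loss.backward()

# The discriminator loss has also filled in the generator's gradients,
# which is why strategy one zeroes them before the generator update.
print(generator.weight.grad is not None)  # True
```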
The generator loss likewise reaches the generator only by backpropagating through the discriminator. There is only one forward pass from the input noise to the discriminator output, but there are two backward passes. The first one, on the discriminator loss, must therefore be called as backward(retain_graph=True) to keep the graph from being freed. By default PyTorch lets a graph be backpropagated only once: after the backward pass, the buffers it saved are released, and this parameter prevents that. As a consequence, when the generator loss is backpropagated, gradients for the discriminator's parameters are computed too. This time the discriminator is not updated; only the generator is, via optimizer_G.step(). Notice also that the next step begins by resetting the discriminator's gradients to zero (optimizer_D.zero_grad()). This ensures that the discriminator gradients accumulated during the generator loss's backward pass do not affect the next discriminator update.
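The following minimal sketch shows the retain_graph behavior described above. It uses a tiny graph of my own rather than a GAN. Two losses share a tanh node, and the first backward pass retains the graph. Once a backward pass runs without retain_graph, the shared buffers are gone.

```python
import torch

x = torch.randn(3, requires_grad=True)
h = torch.tanh(x)          # tanh saves its output for the backward pass
loss_a = h.sum()
loss_b = (h * 2).sum()

# First backward pass: keep the graph so that it can be traversed again.
loss_a.backward(retain_graph=True)
# Second backward pass through the shared part (tanh) works.
loss_b.backward()

# A third pass fails: the previous call did not retain the graph,
# so the buffers saved by tanh have been freed.
try:
    loss_a.backward()
    failed = False
except RuntimeError:
    failed = True
print(failed)  # True
```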
In summary, completing one step of parameter updates takes two backward passes. The first updates the discriminator's parameters, but it also computes redundant gradients for the generator. The second updates the generator's parameters, but it also computes gradients for the discriminator. So the discriminator's gradients have to be cleared at the start of the next step.
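If this still isn't clear, you can simply follow the structure of the code above as a template. One caveat: from PyTorch 1.5 onward, autograd checks whether tensors saved for the backward pass have been modified in place, and optimizer_D.step() updates the discriminator's weights in place. Because g_loss.backward() still needs those saved weights to propagate through the discriminator, this exact ordering raises a RuntimeError ("one of the variables needed for gradient computation has been modified by an inplace operation") on current versions.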
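On recent PyTorch versions, strategy one as written fails: optimizer_D.step() modifies, in place, discriminator weights that g_loss.backward() still needs. Below is a minimal sketch of one possible adjustment that keeps the single forward pass. It assumes PyTorch 1.8 or later (for the inputs argument of Tensor.backward) and uses hypothetical toy models. Both backward passes run before either optimizer step, and inputs= keeps each loss from writing gradients into the other network. This reordering is my own, not the code above. One consequence is that the generator's gradient is computed against the discriminator as it was before its update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy models; sizes are arbitrary.
generator = nn.Linear(4, 2)
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)
optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=0.1)

imgs = torch.randn(8, 2)  # stand-in for a batch of real data
valid = torch.ones(imgs.size(0), 1)
fake = torch.zeros(imgs.size(0), 1)

def train_step_single_forward(z):
    """Strategy one, reordered so that both backward passes run
    before any parameter is modified in place."""
    gen_imgs = generator(z)
    pred_gen = discriminator(gen_imgs)   # the only forward pass of the fake data
    pred_real = discriminator(imgs)

    d_loss = (adversarial_loss(pred_real, valid) + adversarial_loss(pred_gen, fake)) / 2
    g_loss = adversarial_loss(pred_gen, valid)

    optimizer_D.zero_grad()
    optimizer_G.zero_grad()
    # inputs= restricts gradient accumulation to the listed tensors, so d_loss
    # does not write into the generator and g_loss does not write into D.
    d_loss.backward(retain_graph=True, inputs=list(discriminator.parameters()))
    g_loss.backward(inputs=list(generator.parameters()))

    # Only now are the weights modified in place.
    optimizer_D.step()
    optimizer_G.step()
    return d_loss.item(), g_loss.item()

d, g = train_step_single_forward(torch.randn(8, 4))
print(d, g)
```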
Strategy two
I have come across this strategy quite often. It also trains the discriminator first and then the generator.
In the discriminator training phase, noise is fed into the generator to produce fake data, which is then detached once and fed into the discriminator together with the real data. The discriminator loss is computed and the discriminator's parameters are updated. In the generator training phase, the fake data that has not been detached is fed into the discriminator, the generator loss is computed, and its gradient is backpropagated to update the generator's parameters. This strategy computes discriminator gradients twice and generator gradients once. I find it fits the habit of updating the discriminator first. The drawback is that the generator's forward graph must be kept in memory until the discriminator update is finished, and only then released.
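Here is a small sketch of strategy two with hypothetical toy models. It shows the two properties described above. First, the detached discriminator phase never touches the generator. Second, the generator's graph stays intact until the generator phase uses it, so no retain_graph is needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

generator = nn.Linear(4, 2)                                   # hypothetical toy generator
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # hypothetical toy discriminator
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)
optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=0.1)

real_imgs = torch.randn(8, 2)
valid, fake = torch.ones(8, 1), torch.zeros(8, 1)
gen_imgs = generator(torch.randn(8, 4))

# Discriminator phase: the fake batch is detached, so backward stops at D's input.
optimizer_D.zero_grad()
d_loss = (adversarial_loss(discriminator(real_imgs), valid)
          + adversarial_loss(discriminator(gen_imgs.detach()), fake)) / 2
d_loss.backward()                         # no retain_graph needed
optimizer_D.step()
g_grad_after_d = generator.weight.grad    # still None: nothing reached the generator

# Generator phase: a fresh, non-detached pass through the (updated) discriminator.
# The generator part of gen_imgs' graph has not been traversed yet, so it is intact.
optimizer_G.zero_grad()
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()
optimizer_G.step()

print(g_grad_after_d is None, generator.weight.grad is not None)  # True True
```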
# Inside the training loop; generator, discriminator, adversarial_loss, the optimizers and device are defined earlier
valid = torch.ones(imgs.size(0), 1, device=device)   # real labels, all 1
fake = torch.zeros(imgs.size(0), 1, device=device)   # fake labels, all 0

# ########################
# Train the discriminator
# ########################
real_imgs = imgs.to(device)                          # real images
z = torch.randn(imgs.shape[0], 100, device=device)   # noise
gen_imgs = generator(z)                              # generate fake data from noise
pred_gen = discriminator(gen_imgs.detach())          # fake data detached: discriminator output on fake data
pred_real = discriminator(real_imgs)                 # discriminator output on real data

optimizer_D.zero_grad()                              # zero the gradients of all discriminator parameters
real_loss = adversarial_loss(pred_real, valid)       # discriminator loss on real samples
fake_loss = adversarial_loss(pred_gen, fake)         # discriminator loss on fake samples
d_loss = (real_loss + fake_loss) / 2                 # average of the two losses
# No retain_graph here: because of detach(), this backward pass never enters the generator's graph
d_loss.backward()
optimizer_D.step()                                   # update the discriminator parameters

# ########################
# Train the generator
# ########################
pred_gen = discriminator(gen_imgs)                   # non-detached fake data through the updated discriminator
g_loss = adversarial_loss(pred_gen, valid)           # generator loss
optimizer_G.zero_grad()                              # zero the generator gradients
g_loss.backward()                                    # backpropagate the generator loss (through D into G)
optimizer_G.step()                                   # update the generator parameters
Train the generator first, then the discriminator
Strategy three
Let's analyze the code for one step of the training loop:
# Inside the training loop; generator, discriminator, adversarial_loss, the optimizers, device and opt are defined earlier
valid = torch.ones(imgs.size(0), 1, device=device)               # labels for real samples, all 1
fake = torch.zeros(imgs.size(0), 1, device=device)               # labels for generated samples, all 0
z = torch.randn(imgs.shape[0], opt.latent_dim, device=device)    # noise
real_imgs = imgs.to(device)                                      # real images

# ########################
# Train the generator
# ########################
optimizer_G.zero_grad()                                   # zero the generator gradients
gen_imgs = generator(z)                                   # generate fake samples from noise
g_loss = adversarial_loss(discriminator(gen_imgs), valid) # real labels + fake samples: generator loss
g_loss.backward()   # backprop passes through the discriminator, so its parameters get gradients too
optimizer_G.step()  # update the generator; the discriminator has gradients but is not updated here

# ########################
# Train the discriminator
# ########################
optimizer_D.zero_grad()   # clear the discriminator gradients left over from g_loss.backward()
real_loss = adversarial_loss(discriminator(real_imgs), valid)          # real samples + real labels
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)   # fake samples + fake labels
d_loss = (real_loss + fake_loss) / 2                                   # total discriminator loss
d_loss.backward()         # backpropagate the discriminator loss
optimizer_D.step()        # update the discriminator parameters
To update the generator's parameters, we compute the gradient of the generator loss and backpropagate it. The discriminator lies on this path, so by the chain rule we also have to compute gradients for the discriminator's parameters, even though they are not updated in this step. After the backward pass, the forward graph from the noise through the fake image to the discriminator output is freed and no longer exists.
Next we update the discriminator's parameters. Its input now has two parts: the real data, and the generator's output, i.e. the fake data. Pay attention to one detail: in the discriminator's forward pass, the fake data is detached. detach means the data is "decoupled" from the graph that produced it, so the gradient stops when it reaches this data and does not propagate any further. In fact it could not propagate further anyway, because the generator's graph was freed after the first backward pass; without detach, backward would raise an error when it tried to enter that freed graph. So the discriminator's backward pass stays entirely within the discriminator.
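The following sketch runs one strategy-three step with hypothetical toy models. It checks two claims from the paragraphs above. First, the generator's backward pass leaves gradients on the discriminator, which is why optimizer_D.zero_grad() comes first. Second, without detach, the discriminator's backward pass would try to re-enter the already freed generator graph and raise a RuntimeError.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

generator = nn.Linear(4, 2)                                   # hypothetical toy generator
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # hypothetical toy discriminator
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)
optimizer_D = torch.optim.SGD(discriminator.parameters(), lr=0.1)

real_imgs = torch.randn(8, 2)
valid, fake = torch.ones(8, 1), torch.zeros(8, 1)

# Generator step first.
optimizer_G.zero_grad()
gen_imgs = generator(torch.randn(8, 4))
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
g_loss.backward()      # frees the graph from the noise to D's output
optimizer_G.step()
d_grad_leaked = discriminator[0].weight.grad is not None  # D received gradients too

# Without detach, backward walks back into the freed generator graph.
optimizer_D.zero_grad()
try:
    adversarial_loss(discriminator(gen_imgs), fake).backward()
    needs_detach = False
except RuntimeError:
    needs_detach = True

# Discriminator step as in strategy three.
optimizer_D.zero_grad()   # also discards anything the failed attempt accumulated
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = (real_loss + fake_loss) / 2
d_loss.backward()         # stays inside the discriminator
optimizer_D.step()

print(d_grad_leaked, needs_detach)  # True True
```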
Therefore, compared with the first strategy, this one skips one computation of the gradients of all the generator's parameters. It also does not need to retain the graph, which would otherwise hold memory unnecessarily.
Note, however, that in the first strategy the noise goes from the generator's input to the discriminator's output in a single forward pass. The discriminator's output is then used twice: once to compute the discriminator loss and once to compute the generator loss.
In this strategy, by contrast, the noise goes from the generator's input to the discriminator's output, the generator loss is computed and backpropagated, the generator's parameters are updated, and the graph is freed. When the discriminator's parameters are updated next, the generator's output is detached and passed through the discriminator again. So the generator's output goes through the discriminator twice and produces the same output both times, since neither the input nor the discriminator's parameters change in between. Clearly, this is also redundant.
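This redundancy can be checked directly. In the sketch below, which uses hypothetical toy models, the fake batch goes through the discriminator once for the generator step and once more (detached) for the discriminator step. Neither the batch nor the discriminator's weights change in between, so the two outputs agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

generator = nn.Linear(4, 2)                                   # hypothetical toy generator
discriminator = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())  # hypothetical toy discriminator
adversarial_loss = nn.BCELoss()
optimizer_G = torch.optim.SGD(generator.parameters(), lr=0.1)

# Generator step (strategy three): first pass of the fake batch through D.
optimizer_G.zero_grad()
gen_imgs = generator(torch.randn(8, 4))
pred_for_g = discriminator(gen_imgs)
adversarial_loss(pred_for_g, torch.ones(8, 1)).backward()
optimizer_G.step()  # changes G's weights, but not the already computed gen_imgs

# Discriminator step: second pass of the same, now detached, fake batch through D.
# D has not been updated yet, so the result matches the first pass.
pred_for_d = discriminator(gen_imgs.detach())
print(torch.allclose(pred_for_g.detach(), pred_for_d))  # True
```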
Summary
In summary, the first and third code examples each have their own advantages and disadvantages:
The first code example: its advantage is that the noise goes through only one forward pass. Its disadvantages are that updating the discriminator also computes the generator's gradients unnecessarily, and that the graph must be retained during the discriminator update so that it is still available for the generator loss.
The third code example: its advantage is that, because the generator is updated first, the forward graph can be freed right after use, so no graph has to be kept in memory. In addition, updating the discriminator does not compute redundant generator gradients as the first example does. Its disadvantage is that the generator's output goes through two forward passes in the discriminator, and the second pass builds a new graph (though a smaller one than the first).
So one approach computes the generator gradient one extra time, and the other runs one extra discriminator forward pass; overall, the difference between them is small. If the discriminator is more complex than the generator, the first strategy should be preferred. If the discriminator is simpler than the generator, the third strategy should be preferred. The discriminator is usually simpler than the generator, so when the results are about the same, prefer the third strategy.
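A rough per-step cost accounting makes this trade-off concrete. Write F_G and F_D for one forward pass, and B_G and B_D for one backward pass, through the generator and the discriminator on one batch. This is my own simplified model. It ignores the loss computations and the optimizer steps, and it assumes the real and fake batches are the same size. Counting the passes in the three code examples gives:

```latex
\begin{aligned}
C_{\text{one}}   &= F_G + 2F_D + 3B_D + 2B_G \\
C_{\text{two}}   &= F_G + 3F_D + 3B_D + B_G \\
C_{\text{three}} &= F_G + 3F_D + 3B_D + B_G \\
C_{\text{one}} - C_{\text{three}} &= B_G - F_D
\end{aligned}
```

Under this model, strategy three is cheaper than strategy one whenever one generator backward pass costs more than one discriminator forward pass, which matches the rule of thumb above. Strategies two and three cost the same amount of computation. They differ in memory, because strategy two keeps the generator's graph alive across the discriminator update.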
Still, updating the generator first and the discriminator second always feels a bit odd. The generator relies on the discriminator to provide an accurate loss and gradient, so if the discriminator has not been updated yet, isn't the generator updating blindly?
Strategy three, however, frees each graph right after using it. All things considered, strategy three is the best, strategy two comes second, and strategy one is the worst. The difference lies in the extra generator gradient computation, which usually costs more than one extra discriminator forward pass. Therefore, detach is necessary.