
RealBasicVSR source code analysis


Some source code examples:
Model training code: mmedit/models/restorers/real_basicvsr.py
Data processing code: mmedit/datasets/pipelines/random_degradations.py

As specified in the configuration file, the LQ data is randomly degraded and fed into the generator network; the discriminator then tries to tell the generator's output apart from the original (ground-truth) image.
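The degradations in random_degradations.py (e.g. RandomBlur, RandomNoise, RandomJPEGCompression) are applied to the LQ clips on the fly according to the config. As a rough standalone illustration of what one such degradation pass does, here is a minimal sketch using OpenCV and NumPy; it is not the mmedit implementation, and all parameter ranges are made up for illustration:

import cv2
import numpy as np

def random_degrade(img):
    """Roughly mimic one degradation pass: blur -> resize -> noise -> JPEG.

    `img` is assumed to be a uint8 BGR image; all ranges are illustrative.
    """
    # random Gaussian blur
    ksize = int(np.random.choice([3, 5, 7]))
    img = cv2.GaussianBlur(img, (ksize, ksize), sigmaX=np.random.uniform(0.2, 3.0))

    # random downscale, then upscale back to the original size
    h, w = img.shape[:2]
    scale = np.random.uniform(0.25, 1.0)
    img = cv2.resize(img, (max(1, int(w * scale)), max(1, int(h * scale))),
                     interpolation=cv2.INTER_LINEAR)
    img = cv2.resize(img, (w, h), interpolation=cv2.INTER_LINEAR)

    # random Gaussian noise
    noise = np.random.normal(0, np.random.uniform(1, 10), img.shape)
    img = np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8)

    # random JPEG compression
    quality = int(np.random.uniform(30, 95))
    _, buf = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality])
    return cv2.imdecode(buf, cv2.IMREAD_COLOR)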

@MODELS.register_module()
class RealBasicVSR(RealESRGAN):
    """RealBasicVSR model for real-world video super-resolution. Ref: pretrained (str): Path for pretrained model. Default: None. """

    def __init__(self,
                 generator,
                 discriminator=None,
                 gan_loss=None,
                 pixel_loss=None,
                 cleaning_loss=None,
                 perceptual_loss=None,
                 is_use_sharpened_gt_in_pixel=False,
                 is_use_sharpened_gt_in_percep=False,
                 is_use_sharpened_gt_in_gan=False,
                 is_use_ema=True,
                 train_cfg=None,
                 test_cfg=None,
                 pretrained=None):

        super().__init__(generator, discriminator, gan_loss, pixel_loss,
                         perceptual_loss, is_use_sharpened_gt_in_pixel,
                         is_use_sharpened_gt_in_percep,
                         is_use_sharpened_gt_in_gan, is_use_ema, train_cfg,
                         test_cfg, pretrained)

        self.cleaning_loss = build_loss(
            cleaning_loss) if cleaning_loss else None
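        # NOTE (illustration): in the released configs cleaning_loss is typically an
        # L1 loss, e.g. cleaning_loss=dict(type='L1Loss', loss_weight=1.0,
        # reduction='mean'); it supervises the cleaned LQ frames produced by the
        # generator's image-cleaning module.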

    def train_step(self, data_batch, optimizer):
        """Train step.
		...
		        # data
        lq = data_batch['lq']
        gt = data_batch['gt']
        # generator
        fake_g_output, fake_g_lq = self.generator(lq, return_lqs=True)       ## Low quality pictures , Input to generation network , obtain fake picture .fake_g_output It should be the output of the model ,fake_g_lq Not sure what it is .
        losses = dict()
        log_vars = dict()

        # Merge the temporal dimension into the batch dimension so that the per-frame
        # losses and the discriminator see tensors of shape (n * t, c, h, w).
        fake_g_output = fake_g_output.view(-1, c, h, w)

        if (self.step_counter % self.disc_steps == 0
                and self.step_counter >= self.disc_init_steps):
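            # The generator's losses and update below run only every `disc_steps`
            # iterations and only after `disc_init_steps` warm-up iterations.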

            # Compute each configured generator loss. gt_pixel / gt_clean / gt_percep /
            # gt_gan are ground-truth variants prepared in code omitted from this excerpt.
            if self.pixel_loss:
                losses['loss_pix'] = self.pixel_loss(fake_g_output, gt_pixel)
            if self.cleaning_loss:
                losses['loss_clean'] = self.cleaning_loss(fake_g_lq, gt_clean)
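            # PerceptualLoss returns a (perceptual, style) pair computed on pretrained
            # VGG features; either element may be None when its weight is disabled.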
            if self.perceptual_loss:
                loss_percep, loss_style = self.perceptual_loss(
                    fake_g_output, gt_percep)
                if loss_percep is not None:
                    losses['loss_perceptual'] = loss_percep
                if loss_style is not None:
                    losses['loss_style'] = loss_style

            # GAN loss for the generator: train the generator so that the
            # discriminator classifies its fake output as real.
            if self.gan_loss:
                fake_g_pred = self.discriminator(fake_g_output)
                losses['loss_gan'] = self.gan_loss(
                    fake_g_pred, target_is_real=True, is_disc=False)

            # parse loss
            loss_g, log_vars_g = self.parse_losses(losses)
            log_vars.update(log_vars_g)

            # optimize
            optimizer['generator'].zero_grad()
            loss_g.backward()
            optimizer['generator'].step()

        # discriminator
        if self.gan_loss:
            set_requires_grad(self.discriminator, True)
            # real
            real_d_pred = self.discriminator(gt_gan)
            loss_d_real = self.gan_loss(
                real_d_pred, target_is_real=True, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_real=loss_d_real))
            optimizer['discriminator'].zero_grad()
            loss_d.backward()
            log_vars.update(log_vars_d)

            # fake
            fake_d_pred = self.discriminator(fake_g_output.detach())
            loss_d_fake = self.gan_loss(
                fake_d_pred, target_is_real=False, is_disc=True)
            loss_d, log_vars_d = self.parse_losses(
                dict(loss_d_fake=loss_d_fake))
            loss_d.backward()
            log_vars.update(log_vars_d)
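            # Real and fake predictions are backpropagated separately; the two
            # backward() calls accumulate gradients before the single
            # discriminator step below.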

            optimizer['discriminator'].step()

        self.step_counter += 1

        log_vars.pop('loss')  # remove the unnecessary 'loss'
        outputs = dict(
            log_vars=log_vars,
            num_samples=len(gt.data),
            results=dict(lq=lq.cpu(), gt=gt.cpu(), output=fake_g_output.cpu()))

        return outputs
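For reference, a hypothetical invocation of train_step, assuming a RealBasicVSR model built from one of the official GAN configs; the config path, learning rates, and tensor shapes below are illustrative, not taken from the repository:

import torch
from mmcv import Config
from mmedit.models import build_model

# Illustrative config path; check configs/restorers/real_basicvsr/ for the real names.
cfg = Config.fromfile('configs/restorers/real_basicvsr/realbasicvsr_reds.py')
model = build_model(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)

# Dummy batch: 2 clips of 5 frames, 64x64 LQ and 256x256 GT (x4 super-resolution).
data_batch = dict(
    lq=torch.rand(2, 5, 3, 64, 64),
    gt=torch.rand(2, 5, 3, 256, 256),
)

# train_step expects one optimizer per sub-network, keyed by name.
optimizer = dict(
    generator=torch.optim.Adam(model.generator.parameters(), lr=5e-5),
    discriminator=torch.optim.Adam(model.discriminator.parameters(), lr=1e-4),
)

outputs = model.train_step(data_batch, optimizer)
print(outputs['log_vars'])  # per-loss scalars, e.g. loss_pix / loss_clean / loss_gan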