非常简洁的图像复原新方法:退化分类预训练,已中ICLR2025

Universal Image Restoration Pre-training via Degradation Classification

论文地址:https://openreview.net/forum?id=PacBhLzeGO

代码地址:https://github.com/MILab-PKU/dcpt

背景

图像复原是利用模型将低质量(LQ)图像改进为高质量(HQ)图像的任务,在深度学习时代,图像复原任务可以被进一步理解为:以低质量图像为条件生成高质量图像

通用图像复原(Universal Image Restoration, UIR)任务是图像复原的一项重要的子任务。UIR 试图创造一种方法,使得模型能够自主的应对不同退化,并生成语义、细节纹理一致的高质量图像。可以简单地认为,一个合格的UIR模型应当包含以下两种能力:

  • 退化判别:用于提升模型对输入低质量图像的退化的鉴别能力,使得模型能够“自如”地使用自身参数进行自适应复原(这种解释的正确性有待商榷,但已经有大量文献证明退化判别能力的引入有助于图像复原性能增长)
  • 生成高质量图像:生成prior将有助于复原能力的提升,尤其在输入图像退化极其严重的情况下。在干净、高质量图像数据集下训练的生成模型,能够促进复原模型恢复出干净、高质量的图像。

这导向了两种不同的通用图像复原方法设计思路:(1)促进退化判别;(2)引入生成Prior。其中前者已经被得到广泛的研究。流行的方法使用输入图像的退化表征作为判别提示,如:梯度、频率、附加参数和经神经网络压缩的抽象特征等等。虽然这些方法通过使用精确有效的退化提示获得了很高的复原性能,但它们未能利用复原模型本身所蕴含的潜在先验信息

DCPT的诞生来源于对复原模型自我退化判别能力的分析。

发现

我们对复原模型自身的退化判别能力进行了分析,并得到三个有趣的发现:

  1. 随机初始化模型显示出对退化进行分类的内在能力;
  2. 在一体化(All-in-one)复原任务中训练的模型表现出辨别未知退化的能力;
  3. 在修复模型的早期训练中,有一个退化理解步骤。

我们进行了一项简单的预实验来说明这三点:我们提取了复原训练过程中网络复原头之前的输出特征,训练过程中,模型仅见到雾霾、雨、高斯噪声三种退化。根据该特征, kNN 分类器将对五种退化类型(包括雾霾、雨天、高斯噪声、运动模糊和弱光)进行分类。

预实验结果如下:

Methods
NAFNet
SwinIR
Restormer
PromptIR
Acc. on Random initialized (%)
52 ± 1
64 ± 4
71 ± 4
55 ± 3
Acc. on 3D all-in-one trained 200k iterations (%)
90 ± 5
92 ± 6
93 ± 3
93 ± 5
Acc. on 3D all-in-one trained 400k iterations (%)
94 ± 4
95 ± 4
95 ± 4
95 ± 4
Acc. on 3D all-in-one trained 600k iterations (%)
94 ± 5
95 ± 4
97 ± 2
95 ± 4

可以看到四种网络在网络初始化时就表现出52%~71%的分类准确率,且在复原训练过程早期(前200k次迭代)快速收敛到90%以上的分类准确率。

  • 当退化数量进一步增多,…

    遗憾的是,我们发现复原模型对未知退化的辨别能力会随着退化种类的增多而逐渐减弱。我们将在后续工作中对此进行更充分的讨论。

动机

由于图像复原的核心任务还是以低质量图像为条件生成高质量图像,我们不希望在复原训练过程中出现与该任务存在潜在冲突的其他训练子任务,例如退化分类。于是,我们选择将显式地将该训练阶段提前为“预训练”,并进一步创造了DCPT。

方法

Degradation Classification Pre-Training (DCPT) 是一个简单且有效的方法,可见下图。

在单次迭代中,它包含两个阶段:退化分类阶段、生成阶段,这两个阶段交替进行。其中,

  • 退化分类阶段:通过提取复原网络的深层特征,并将其输入一个轻量级分类器,以对输入图像的退化种类进行分类。
  • 生成阶段:我们利用最原始的Autoencoder手段对复原模型的生成能力进行保留。

实现代码也非常简洁:

### train to generate the clean image
encoder.train()
decoder.eval()
optimizer_encoder.zero_grad()
pix_output = encoder(gt, hook=False)
l_total = 0
# pixel loss
if cri_pixel:
 l_pix = cri_pixel(pix_output, gt)
 l_total += l_pix

### train to classify the degradation
decoder.train()
optimizer_decoder.zero_grad()
hook_outputs = encoder(lq, hook=True)
cls_output = decoder(lq, hook_outputs[::-1])
# classification loss
if cri_cls:
 l_cls = cri_cls(cls_output, dataset_idx)
 l_total += l_cls

l_total.backward()
optimizer_encoder.step()
optimizer_decoder.step()

需要注意,在预训练结束后,仍需要进行复原任务上的fine-tune。

实验结果

5D All-in-one image restoration

Method
Dehazing
Deraining
Denoising
Deblurring
Low-Light
Average

on SOTS
on Test100L
on BSD68
on GoPro
on LOL


PSNR / SSIM
PSNR / SSIM
PSNR / SSIM
PSNR / SSIM
PSNR / SSIM
PSNR / SSIM
AirNet
21.04 / 0.884
32.98 / 0.951
30.91 / 0.882
24.35 / 0.781
18.18 / 0.735
25.49 / 0.846
IDR
25.24 / 0.943
35.63 / 0.965
31.60 / 0.887
27.87 / 0.846
21.34 / 0.826
28.34 / 0.893
InstructIR
27.10 / 0.956
36.84 / 0.973
31.40 / 0.887
29.40 / 0.886
23.00 / 0.836
29.55 / 0.907
SwinIR
21.50 / 0.891
30.78 / 0.923
30.59 / 0.868
24.52 / 0.773
17.81 / 0.723
25.04 / 0.835
DCPT-SwinIR 28.67

 / 0.973
35.70

 / 0.974
31.16

 / 0.882
26.42

 / 0.807
20.38

 / 0.836
28.47

 / 0.894
NAFNet
25.23 / 0.939
35.56 / 0.967
31.02 / 0.883
26.53 / 0.808
20.49 / 0.809
27.76 / 0.881
DCPT-NAFNet 29.47

 / 0.971
35.68

 / 0.973
31.31

 / 0.886
29.22

 / 0.883
23.52

 / 0.855
29.84

 / 0.914
Restormer
24.09 / 0.927
34.81 / 0.962
31.49 / 0.884
27.22 / 0.829
20.41 / 0.806
27.60 / 0.881
DCPT-Restormer 29.86

 / 0.973
36.68

 / 0.975
31.46

 / 0.888
28.95

 / 0.879
23.26

 / 0.842
30.04

 / 0.911
PromptIR
25.20 / 0.931
35.94 / 0.964
31.17 / 0.882
27.32 / 0.842
20.94 / 0.799
28.11 / 0.883
DCPT-PromptIR 30.72

 / 0.977
37.32

 / 0.978
31.32

 / 0.885
28.84

 / 0.877
23.35

 / 0.840
30.31

 / 0.911

可以看出,无论是 CNN 网络还是 Transformer 网络,无论是直线网络还是类 UNet 网络,DCPT在 5D All-in-one image restoration 任务上的平均性能提升始终保持在 2.08 dB 及以上。

我们也展示一些可视化数据,以证明DCPT也确实能提升输出图像的视觉感观。

10D All-in-one image restoration

我们选取了十种退化进行试验,并绘制了雷达图。

可以看到使用DCPT预训练后,NAFNet的PSNR与SSIM指标都有比较显著的提升。具体数值指标如下:

Method
Average
AirNet
26.41 / 0.842
TransWeather
22.83 / 0.779
WeatherDiff
24.60 / 0.793
PromptIR
27.93 / 0.851
DiffUIR-L
28.75 / 0.869
NAFNet
27.17 / 0.837
+ DACLIP
27.42 (+ 0.25) / 0.798 (- 0.039)
+ Instruct
28.30 (+ 1.13) / 0.862 (+ 0.025)
DCPT (Ours)
29.72 (+ 2.55) / 0.888 (+ 0.051)

Mixed degradation

我们也在混合退化场景下进行研究,我们使用了CDD数据集,结果如下:

Transfer learning

众所周知,图像复原模型的过拟合现象严重。在A退化任务下训练的复原模型极难泛化到B退化任务。我们发现DCPT中的退化分类器有助于模型跨任务泛化。

为此,我们首先设计了DC-guided training,如下图所示:

我们冻结在DCPT阶段训练的退化分类解码器,通过让模型对退化、非退化图像进行二分类,让模型知晓自身在处理何种退化(何种复原任务),从而增强任务间泛化能力。泛化结果如下:

更多其他实验结果请关注我们的文章(https://openreview.net/forum?id=PacBhLzeGO)。

讨论

  • Q: 在复原训练中,仅仅只是存在退化分类嘛?

A. 否定。或许可以在实验结果中窥见一二。在5D All-in-one任务中,SwinIR的去雾结果展示出明显的“区域性”。我们猜测,对于退化在全局图像中不均匀出现的情况,复原模型也会对输入图像的退化进行“分割”。

  • 复原中隐藏着(退化)辨别

之前的研究(https://arxiv.org/abs/2108.00406)已经调查了超分辨率模型在复原过程中区分不同类型退化的能力。DCPT的初步实验也表明,随机初始化模型能够对退化进行分类。此外,一体化复原训练也能增强了模型的退化分类能力,并赋予复原模型在退化分类任务上的泛化能力。这些结果表明,复原中隐藏着(退化)辨别

DCPT的实验结果凸显了判别先验在图像复原预训练中的有效性。这些结果表明,在训练前将足够的判别信息纳入模型可以显著提高其性能。我们假设,在复原模型中加入卓越的降解感知判别信息,并最大限度地提高其判别能力,将进一步提高模型的复原性能。预计这一假设将为通用复原领域开发大量新型预训练方法铺平道路。


(文:GiantPandaCV)

欢迎分享

发表评论