Directly denoising Diffusion model the Simplest deep learning
public
Jun 01, 2025
Never
29
1 #!/usr/bin/env python 2 # coding: utf-8 3 4 '''Directly denoising Diffusion model, the Simplest deep learning.''' 5 '''Version 1.0''' 6 '''python diff.py''' 7 8 __credits__ = ["iBoxDB", "Bruce Yang CL-N", "2025-5"] 9 10 11 # 3rd parts 12 #https://pytorch.org CPU version only 13 import torch 14 15 import torch.nn as nn 16 import torch.utils.data as tdata 17 from torch.optim import Adam,AdamW 18 19 from torchvision.datasets import MNIST, CIFAR10 20 from torchvision.utils import save_image, make_grid 21 22 import torchvision.transforms as transforms 23 import torchvision.transforms.functional as TF 24 25 import matplotlib.pyplot as plt 26 import numpy as np 27 28 import math 29 import os 30 31 th = torch 32 DataLoader = tdata.DataLoader 33 34 35 ''' 36 if over trained, remove the *.pt files. 37 set 'epochs' to 20 38 and re-train several times till best results. 39 it using random noisy, results varying. 40 ''' 41 epochs = 60 #0 42 43 train_batch_size = 36 44 th.set_default_dtype(th.float32) 45 th.set_num_threads(4) 46 47 48 dataset_path = '~/datasets' #download to here 49 print(os.path.expanduser(dataset_path)) 50 51 dataset = 'MNIST' 52 img_size = (1, 28, 28) 53 img_size = (1, 16, 16) 54 55 transform = transforms.Compose([ 56 transforms.Resize( (img_size[1],img_size[2]) , transforms.InterpolationMode.BILINEAR ), 57 transforms.ToTensor(), 58 ]) 59 60 train_dataset = MNIST(dataset_path, transform=transform, train=True, download=True) 61 #test_dataset = MNIST(dataset_path, transform=transform, train=False, download=True) 62 63 64 print(len(train_dataset)) 65 generator1 = torch.Generator().manual_seed(69981) 66 train_dataset , _ = tdata.random_split(train_dataset,[0.01,0.99], generator1) 67 68 ''' 69 train_dataset = [x for x in train_dataset if x[1] == 9 ] 70 train_dataset = train_dataset[0:train_batch_size*2] 71 ''' 72 73 train_dataset = [x for x in train_dataset] 74 print(len(train_dataset)) 75 76 def draw_sample_image(x, postfix, block=True): 77 plt.close('all') 78 plt.figure(figsize=(5,5)) 79 plt.axis("off") 80 plt.title(postfix) 81 im = make_grid(x.detach().cpu(), 82 nrow=int(math.sqrt(len(x))), 83 scale_each=True, normalize=False) 84 im = TF.resize(im,(im.size(1),im.size(2))) 85 im = np.transpose(im, (1, 2, 0)) 86 plt.imshow(im) 87 plt.show(block=block) 88 if not block: 89 plt.pause(3) 90 plt.close('all') 91 92 93 class Denoiser(nn.Module): 94 95 def __init__(self): 96 super(Denoiser, self).__init__() 97 C,H,W = img_size 98 self.unet = nn.Sequential( 99 nn.Conv2d(C,32,4,2,1), 100 nn.InstanceNorm2d(32), 101 nn.Conv2d(32,64,4,2,1), 102 nn.InstanceNorm2d(64), 103 nn.Conv2d(64,128,1), 104 nn.SiLU(), 105 nn.Conv2d(128,128,2), 106 nn.LeakyReLU(), 107 nn.Conv2d(128,256,2), 108 nn.LeakyReLU(), 109 110 nn.ConvTranspose2d(256,128,2), 111 nn.LeakyReLU(), 112 nn.ConvTranspose2d(128,64,2), 113 nn.LeakyReLU(), 114 nn.ConvTranspose2d(64,32,4,2,1), 115 nn.LeakyReLU(), 116 nn.ConvTranspose2d(32,C,4,2,1), 117 ) 118 119 def forward(self,x): 120 return self.bruce_forward(x) 121 122 def bruce_forward(self,x): 123 res = self.unet(x) 124 return x - res 125 126 127 class Diffusion(nn.Module): 128 129 def __init__(self): 130 super(Diffusion, self).__init__() 131 self.model = Denoiser() 132 133 def scale_to_minus_one_to_one(self, x): 134 return x * 2 - 1 135 136 def reverse_scale_to_zero_to_one(self, x): 137 return (x + 1) / 2 138 139 def bruce_noisy(self, x_zeros, ranLen=31): 140 x_zeros = x_zeros.detach() 141 x_zeros = self.scale_to_minus_one_to_one(x_zeros) 142 143 rs = [] 144 es = [] 145 146 for _ in range(1): 147 target = torch.rand_like(x_zeros) 148 target = self.scale_to_minus_one_to_one(target) 149 150 alpha = 20 / 100 151 epsilon = target - x_zeros * alpha 152 rs.append(target) 153 es.append(epsilon) 154 155 for _ in range(ranLen-1): 156 alpha = th.randint(21,100,(1,)).item() / 100 157 epsilon = torch.rand_like(x_zeros) 158 epsilon = self.scale_to_minus_one_to_one(epsilon) 159 160 epsilon = epsilon * (1-alpha) 161 noisy_sample = x_zeros * alpha + epsilon 162 rs.append(noisy_sample) 163 es.append(epsilon) 164 165 return rs, es 166 167 168 @th.no_grad() 169 def sample(self,time=64): 170 target = torch.rand(img_size) 171 target = self.scale_to_minus_one_to_one(target) 172 173 rs = [] 174 175 target = target.unsqueeze(0) 176 for alpha in range(20,20+time): 177 178 epsilon = self.model(target).detach() 179 if alpha == 20: 180 pass 181 182 alpha = alpha / 100 183 x_zeros = (target - epsilon) / (alpha) 184 185 a = x_zeros.squeeze(0) 186 a = a.clamp(-1,1) 187 a = self.reverse_scale_to_zero_to_one(a) 188 rs.append(a) 189 190 alpha += 0.01 191 epsilon = torch.rand_like(x_zeros) 192 epsilon = self.scale_to_minus_one_to_one(epsilon) 193 epsilon = epsilon * (1-alpha) 194 epsilon = epsilon * 0.5 195 196 target = x_zeros * alpha + epsilon 197 target = target.clamp(-5,5) 198 199 return th.stack(rs) 200 201 def forward(self,x,y): 202 inputX = [] 203 inputY = [] 204 for a,_ in zip(x,y): 205 rs,es = self.bruce_noisy(a) 206 inputX += rs 207 inputY += es 208 inputX = th.stack(inputX) 209 inputY = th.stack(inputY) 210 return self.model(inputX), inputY 211 212 213 denoising_loss = nn.MSELoss() 214 diffusion = Diffusion() 215 lr = 0.01 216 optimizer = AdamW(diffusion.parameters(), lr=lr) 217 218 if os.path.exists("diff.pt"): 219 a = th.load("diff.pt") 220 try: 221 diffusion.load_state_dict(a["d"]) 222 optimizer.load_state_dict(a["o"]) 223 print("load diff.pt") 224 except Exception as e: 225 print(e) 226 227 228 train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True,) 229 for x,y in train_loader: 230 x = x[0:36] 231 print(th.min(x), th.max(x)) 232 #draw_sample_image(x,"Show") 233 234 x = x[0] 235 rs,es = diffusion.bruce_noisy(x,36) 236 x = th.stack(rs) 237 print(th.min(x), th.max(x)) 238 x = (x+1)/2 239 #draw_sample_image(x,"Noisy") 240 241 x = th.stack(es) 242 print(th.min(x), th.max(x)) 243 x *= (1/1.21) 244 x = (x+1)/2 245 #draw_sample_image(x,"De Noisy") 246 247 break 248 249 count_loader = len(train_loader) 250 def count_parameters(model): 251 return sum(p.numel() for p in model.parameters() if p.requires_grad) 252 print("Model Parameters: ", count_parameters(diffusion), count_loader) 253 254 def show_samples(time=70,block=True): 255 es = [] 256 for l in range(64): 257 x = diffusion.sample(time)[-1] 258 es.append(x) 259 x = th.stack(es) 260 draw_sample_image(x,"Samples",block) 261 262 diffusion.train() 263 for epoch in range(epochs): 264 noise_prediction_loss = 0 265 266 train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True,) 267 for batch_idx, (x, y) in enumerate(train_loader): 268 optimizer.zero_grad() 269 270 x,y = diffusion(x,y) 271 loss = denoising_loss(x.view(-1), y.view(-1)) 272 noise_prediction_loss += loss.item() 273 loss.backward() 274 optimizer.step() 275 print(f"{batch_idx} / {count_loader}.", loss.item()) 276 277 noise_prediction_loss = noise_prediction_loss / count_loader 278 print("Epoch", epoch + 1, f"/ {epochs} complete.", " L: ", noise_prediction_loss) 279 a = {"d":diffusion.state_dict(), 280 "o":optimizer.state_dict()} 281 th.save(a, "diff.pt") 282 print("save diff.pt") 283 if epoch % 10 == 1: 284 show_samples(70,False) 285 286 if noise_prediction_loss < 0.005 : 287 print(epoch+1, " Goto Eval, remove diff.pt before re-train ") 288 break 289 290 diffusion.eval() 291 292 for _ in range(2): 293 show_samples() 294 295 for l in range(10): 296 x = diffusion.sample(81) 297 print(th.min(x), th.max(x)) 298 draw_sample_image(x,"Sample Single") 299 300 301 print("End.")