G

Directly-denoising diffusion model — the simplest deep-learning example

public
Guest Jun 01, 2025 Never 29
Clone
Python diff.py 301 lines (229 loc) | 8.1 KB
1
#!/usr/bin/env python
2
# coding: utf-8
3
4
'''Directly denoising Diffusion model, the Simplest deep learning.'''
5
'''Version 1.0'''
6
'''python diff.py'''
7
8
__credits__ = ["iBoxDB", "Bruce Yang CL-N", "2025-5"]
9
10
11
# 3rd parts
12
#https://pytorch.org CPU version only
13
import torch
14
15
import torch.nn as nn
16
import torch.utils.data as tdata
17
from torch.optim import Adam,AdamW
18
19
from torchvision.datasets import MNIST, CIFAR10
20
from torchvision.utils import save_image, make_grid
21
22
import torchvision.transforms as transforms
23
import torchvision.transforms.functional as TF
24
25
import matplotlib.pyplot as plt
26
import numpy as np
27
28
import math
29
import os
30
31
# Short aliases used throughout the rest of the script.
th = torch
DataLoader = tdata.DataLoader


# If over-trained, remove the *.pt files, set 'epochs' to 20,
# and re-train several times until the results look best.
# Random noise is used, so results vary between runs.
epochs = 60  # 0

train_batch_size = 36
th.set_default_dtype(th.float32)
th.set_num_threads(4)


dataset_path = '~/datasets'  # download to here
print(os.path.expanduser(dataset_path))

dataset = 'MNIST'
img_size = (1, 28, 28)
img_size = (1, 16, 16)  # overrides the line above: the model trains on 16x16

transform = transforms.Compose([
    transforms.Resize((img_size[1], img_size[2]), transforms.InterpolationMode.BILINEAR),
    transforms.ToTensor(),
])

train_dataset = MNIST(dataset_path, transform=transform, train=True, download=True)
#test_dataset = MNIST(dataset_path, transform=transform, train=False, download=True)


print(len(train_dataset))
# Fixed seed so the 1% training subset is the same on every run.
generator1 = torch.Generator().manual_seed(69981)
train_dataset, _ = tdata.random_split(train_dataset, [0.01, 0.99], generator1)

'''
train_dataset = [x for x in train_dataset if x[1] == 9 ]
train_dataset = train_dataset[0:train_batch_size*2]
'''

# Materialize the lazy Subset into a plain list of (image, label) pairs.
train_dataset = [x for x in train_dataset]
print(len(train_dataset))
def draw_sample_image(x, postfix, block=True):
    """Show a batch of images *x* as a roughly square grid titled *postfix*.

    When *block* is False the window is shown non-blocking and closed
    automatically after a 3-second pause.
    """
    plt.close('all')
    plt.figure(figsize=(5, 5))
    plt.axis("off")
    plt.title(postfix)
    grid = make_grid(x.detach().cpu(),
                     nrow=int(math.sqrt(len(x))),
                     scale_each=True, normalize=False)
    # NOTE(review): resizing the grid to its own current size looks like a
    # no-op — confirm whether a different target size was intended.
    grid = TF.resize(grid, (grid.size(1), grid.size(2)))
    plt.imshow(np.transpose(grid, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show(block=block)
    if not block:
        plt.pause(3)
        plt.close('all')
92
93
class Denoiser(nn.Module):
    """Convolutional encoder/decoder that predicts a residual to subtract.

    forward() returns ``x - unet(x)``: the network estimates the noise
    component and the module outputs the (partially) denoised image, with
    the same shape as its input.
    """

    def __init__(self):
        super(Denoiser, self).__init__()
        channels, _height, _width = img_size
        # Kept as a single Sequential named "unet" so checkpoint keys
        # ("unet.<idx>.…") in diff.pt remain compatible.
        self.unet = nn.Sequential(
            nn.Conv2d(channels, 32, 4, 2, 1),
            nn.InstanceNorm2d(32),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.InstanceNorm2d(64),
            nn.Conv2d(64, 128, 1),
            nn.SiLU(),
            nn.Conv2d(128, 128, 2),
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, 2),
            nn.LeakyReLU(),

            nn.ConvTranspose2d(256, 128, 2),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(128, 64, 2),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(32, channels, 4, 2, 1),
        )

    def forward(self, x):
        return self.bruce_forward(x)

    def bruce_forward(self, x):
        # Residual formulation: subtract the predicted noise from the input.
        predicted_residual = self.unet(x)
        return x - predicted_residual
126
127
class Diffusion(nn.Module):
    """Directly-denoising diffusion wrapper around a Denoiser network.

    Training pairs come from bruce_noisy(); sample() runs the reverse
    process, sweeping the signal ratio ``alpha`` upward from 0.20.

    Fixes vs. the original: removed the dead single-iteration
    ``for _ in range(1):`` wrapper in bruce_noisy() and the no-op
    ``if alpha == 20: pass`` branch in sample(). Behavior is unchanged.
    """

    def __init__(self):
        super(Diffusion, self).__init__()
        # Attribute name "model" is load-bearing: checkpoint keys in
        # diff.pt are prefixed "model.".
        self.model = Denoiser()

    def scale_to_minus_one_to_one(self, x):
        """Map [0, 1] pixel values to [-1, 1]."""
        return x * 2 - 1

    def reverse_scale_to_zero_to_one(self, x):
        """Map [-1, 1] values back to [0, 1]."""
        return (x + 1) / 2

    def bruce_noisy(self, x_zeros, ranLen=31):
        """Build ``ranLen`` (noisy_sample, epsilon) training pairs for one image.

        x_zeros: a single image in [0, 1]; it is detached and rescaled to
        [-1, 1] first. Returns two lists of equal length: the noisy inputs
        and the residual targets the model should predict.
        """
        x_zeros = x_zeros.detach()
        x_zeros = self.scale_to_minus_one_to_one(x_zeros)

        rs = []
        es = []

        # First pair: the sample is pure uniform noise; the target residual
        # is whatever must be subtracted to recover alpha * x_zeros.
        target = torch.rand_like(x_zeros)
        target = self.scale_to_minus_one_to_one(target)

        alpha = 20 / 100
        epsilon = target - x_zeros * alpha
        rs.append(target)
        es.append(epsilon)

        # Remaining pairs: blend the clean image with fresh uniform noise at
        # a randomly drawn signal ratio alpha in [0.21, 0.99].
        for _ in range(ranLen - 1):
            alpha = th.randint(21, 100, (1,)).item() / 100
            epsilon = torch.rand_like(x_zeros)
            epsilon = self.scale_to_minus_one_to_one(epsilon)

            epsilon = epsilon * (1 - alpha)
            noisy_sample = x_zeros * alpha + epsilon
            rs.append(noisy_sample)
            es.append(epsilon)

        return rs, es

    @th.no_grad()
    def sample(self, time=64):
        """Generate ``time`` progressively denoised frames from pure noise.

        Returns a (time, C, H, W) tensor of frames scaled to [0, 1];
        the last frame is the most denoised estimate.
        """
        target = torch.rand(img_size)
        target = self.scale_to_minus_one_to_one(target)

        rs = []

        target = target.unsqueeze(0)  # add batch dim for the model
        for alpha in range(20, 20 + time):
            # Model output is treated as the residual to remove.
            epsilon = self.model(target).detach()

            alpha = alpha / 100
            x_zeros = (target - epsilon) / (alpha)

            frame = x_zeros.squeeze(0)
            frame = frame.clamp(-1, 1)
            frame = self.reverse_scale_to_zero_to_one(frame)
            rs.append(frame)

            # Re-noise the current estimate at a slightly higher alpha,
            # damping the injected noise by 0.5.
            alpha += 0.01
            epsilon = torch.rand_like(x_zeros)
            epsilon = self.scale_to_minus_one_to_one(epsilon)
            epsilon = epsilon * (1 - alpha)
            epsilon = epsilon * 0.5

            target = x_zeros * alpha + epsilon
            target = target.clamp(-5, 5)

        return th.stack(rs)

    def forward(self, x, y):
        """Expand a batch into noisy/residual pairs and run the denoiser.

        *y* (the labels) is ignored; only the images are used. Returns
        (model predictions, stacked residual targets).
        """
        inputX = []
        inputY = []
        for a, _ in zip(x, y):
            rs, es = self.bruce_noisy(a)
            inputX += rs
            inputY += es
        inputX = th.stack(inputX)
        inputY = th.stack(inputY)
        return self.model(inputX), inputY
212
213
# Training objective: MSE between predicted and target residuals.
denoising_loss = nn.MSELoss()
diffusion = Diffusion()
lr = 0.01
optimizer = AdamW(diffusion.parameters(), lr=lr)

# Resume from a previous checkpoint when present. th.load is now inside the
# try block: a corrupt/unreadable diff.pt is reported and ignored instead of
# crashing the script (the except clearly intends to tolerate bad files).
if os.path.exists("diff.pt"):
    try:
        a = th.load("diff.pt")
        diffusion.load_state_dict(a["d"])
        optimizer.load_state_dict(a["o"])
        print("load diff.pt")
    except Exception as e:
        print(e)


# One-off preview of the first batch: print the value ranges of the clean
# images, the noisy samples, and the residual targets.
train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True,)
for x, y in train_loader:
    x = x[0:36]
    print(th.min(x), th.max(x))
    #draw_sample_image(x,"Show")

    x = x[0]
    rs, es = diffusion.bruce_noisy(x, 36)
    x = th.stack(rs)
    print(th.min(x), th.max(x))
    x = (x + 1) / 2
    #draw_sample_image(x,"Noisy")

    x = th.stack(es)
    print(th.min(x), th.max(x))
    x *= (1 / 1.21)  # scale residuals back toward [-1, 1] for display
    x = (x + 1) / 2
    #draw_sample_image(x,"De Noisy")

    break

count_loader = len(train_loader)
def count_parameters(model):
    """Return the total number of trainable parameters in *model*."""
    total = 0
    for p in model.parameters():
        if p.requires_grad:
            total += p.numel()
    return total
# Report the model size and the number of batches per epoch.
print("Model Parameters: ", count_parameters(diffusion), count_loader)
def show_samples(time=70, block=True):
    """Draw a grid of final frames from 64 independent sampling runs."""
    finals = [diffusion.sample(time)[-1] for _ in range(64)]
    draw_sample_image(th.stack(finals), "Samples", block)
diffusion.train()
for epoch in range(epochs):
    noise_prediction_loss = 0

    # A fresh DataLoader each epoch reshuffles the (small) training subset.
    train_loader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True,)
    for batch_idx, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()

        x, y = diffusion(x, y)
        loss = denoising_loss(x.view(-1), y.view(-1))
        noise_prediction_loss += loss.item()
        loss.backward()
        optimizer.step()
        print(f"{batch_idx} / {count_loader}.", loss.item())

    noise_prediction_loss = noise_prediction_loss / count_loader
    print("Epoch", epoch + 1, f"/ {epochs} complete.", " L: ", noise_prediction_loss)

    # Checkpoint model and optimizer state after every epoch.
    a = {"d": diffusion.state_dict(),
         "o": optimizer.state_dict()}
    th.save(a, "diff.pt")
    print("save diff.pt")

    # Periodic visual progress check (non-blocking).
    if epoch % 10 == 1:
        show_samples(70, False)

    # Early stop once the average epoch loss is low enough.
    if noise_prediction_loss < 0.005:
        print(epoch+1, " Goto Eval, remove diff.pt before re-train ")
        break
diffusion.eval()

# Two grids of finished samples from independent runs.
for _ in range(2):
    show_samples()

# Then visualize full denoising trajectories for ten single runs.
for run in range(10):
    x = diffusion.sample(81)
    print(th.min(x), th.max(x))
    draw_sample_image(x, "Sample Single")


print("End.")