Python softmax
import torch

def softmax_hand(x, dim):
    # Numerically stable softmax: subtracting the per-row max before
    # exponentiating leaves the result unchanged (softmax is shift
    # invariant) but prevents exp() from overflowing.
    exp_x = (x - x.max(dim=dim, keepdim=True)[0]).exp()
    sum_exp_x = exp_x.sum(dim=dim, keepdim=True)
    return exp_x / sum_exp_x

x = torch.rand((6, 6), dtype=torch.bfloat16)
dim = -1

# Two independent leaf copies of the same input, so each implementation
# accumulates its own gradient.
x_custom = x.clone().requires_grad_(True)
x_builtin = x.clone().requires_grad_(True)

custom_out = softmax_hand(x_custom, dim)
builtin_out = torch.nn.functional.softmax(x_builtin, dim=dim)

custom_out.sum().backward()
builtin_out.sum().backward()

# Largest elementwise difference between the two gradients. Both are
# analytically zero (each softmax row sums to 1, so the summed output is
# constant), so anything nonzero here is bfloat16 rounding noise.
(x_custom.grad - x_builtin.grad).abs().max()

x_custom.grad, x_builtin.grad
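Why the max subtraction matters: a minimal sketch (my addition, not part of the paste) contrasting the stable scheme against a hypothetical softmax_naive that exponentiates directly. With float32 logits in the hundreds, the naive exp() overflows to inf and the output degenerates to nan, while the shifted version stays finite.

import torch

def softmax_naive(x, dim):
    # Hypothetical unstable variant for contrast: no max subtraction,
    # so exp() overflows for large inputs.
    e = x.exp()
    return e / e.sum(dim=dim, keepdim=True)

def softmax_stable(x, dim):
    # Same scheme as softmax_hand in the paste above.
    e = (x - x.max(dim=dim, keepdim=True)[0]).exp()
    return e / e.sum(dim=dim, keepdim=True)

big = torch.tensor([[1000.0, 1001.0, 1002.0]])
print(softmax_naive(big, -1))   # tensor([[nan, nan, nan]]): exp(1000) is inf in float32
print(softmax_stable(big, -1))  # finite probabilities that sum to 1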
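A tighter check than eyeballing the gradient difference would be torch.autograd.gradcheck, which compares autograd's gradients against finite differences. A sketch, assuming softmax_hand from the paste is in scope; gradcheck expects float64 inputs for its default tolerances.

import torch

x64 = torch.rand((6, 6), dtype=torch.float64, requires_grad=True)
# Raises if the analytic and numeric Jacobians disagree, returns True otherwise.
print(torch.autograd.gradcheck(lambda t: softmax_hand(t, -1), (x64,)))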