import torch

def softmax_hand(x, dim):
    # Subtract the per-dim max before exponentiating for numerical stability;
    # this leaves the result unchanged but prevents overflow in exp().
    exp_x = (x - x.max(dim=dim, keepdim=True)[0]).exp()
    sum_exp_x = exp_x.sum(dim=dim, keepdim=True)
    return exp_x / sum_exp_x

x = torch.rand((6, 6), dtype=torch.bfloat16)
dim = -1

# Two independent leaf tensors so each implementation gets its own gradient.
x_custom = x.clone().requires_grad_(True)
x_builtin = x.clone().requires_grad_(True)

custom_out = softmax_hand(x_custom, dim)
builtin_out = torch.nn.functional.softmax(x_builtin, dim=dim)

custom_out.sum().backward()
builtin_out.sum().backward()

# Largest elementwise gap between the hand-written and built-in gradients.
print((x_custom.grad - x_builtin.grad).abs().max())

print(x_custom.grad, x_builtin.grad)
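# Added sanity check (a sketch, not part of the original paste): softmax's
# vector-Jacobian product for an upstream gradient g is
#     s * (g - (g * s).sum(dim, keepdim=True)).
# With g = ones, each slice of s sums to 1 along `dim`, so the analytic
# gradient is exactly zero; the gradients printed above are therefore
# pure bfloat16 rounding noise rather than a real discrepancy.
s = builtin_out.detach()
g = torch.ones_like(s)
analytic_grad = s * (g - (g * s).sum(dim=dim, keepdim=True))
print(analytic_grad.abs().max())  # ~0, up to bfloat16 rounding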