# Untitled
#
# public
# kawabata030223 Dec 10, 2024 Never 25
# Clone
# Python paste1.py 12 lines (9 loc) | 349 Bytes
# Minimal demo of torch.nn.InstanceNorm2d applied to a random 4-D batch.
import torch
import torch.nn as nn

# InstanceNorm2d layer for 64-channel input. affine=True gives the layer
# learnable per-channel scale (weight) and shift (bias) parameters.
inst_norm = nn.InstanceNorm2d(num_features=64, affine=True)

# Random input laid out as (batch_size=8, channels=64, height=32, width=32);
# its channel dimension (64) matches num_features above.
x = torch.randn(8, 64, 32, 32)

# Apply instance normalization: each (sample, channel) plane is normalized
# with its own mean and variance computed over height x width.
output = inst_norm(x)

# Normalization does not change the tensor's layout, so the shape is
# still (8, 64, 32, 32).
print(output.shape)