"""Minimal demo of ``nn.InstanceNorm2d`` applied to a random 4-D input batch."""
import torch
import torch.nn as nn

# Seed the RNG so the random input (and therefore the result) is reproducible.
torch.manual_seed(0)

# Instance normalization layer over 64 channels. affine=True gives the layer a
# learnable per-channel scale (weight) and shift (bias); per the PyTorch docs
# these start at 1 and 0, so the initial output is the plain normalized input.
inst_norm = nn.InstanceNorm2d(num_features=64, affine=True)

# Input tensor laid out as (batch_size=8, channels=64, height=32, width=32).
x = torch.randn(8, 64, 32, 32)

# Apply instance normalization: each (sample, channel) plane is normalized
# with its own mean and variance, computed over height and width.
output = inst_norm(x)
print(output.shape)  # Shape is unchanged: torch.Size([8, 64, 32, 32])