0%

Pytorch常用例子

本文档记录了一些pytorch常用操作以及概念

1. 构造数据

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

# 1. 全 0 / 全 1 / 常数
torch.zeros(3, 4) # shape=(3, 4)
torch.ones(2, 3)
torch.full((2, 2), 7.0)
x = torch.randn(10,8,4)
y = torch.ones_like(x) # 复制shape

# 2. 随机数据
torch.randn(5, 10) # 标准正态分布 N(0,1)
torch.rand(3, 3) # 均匀分布 [0, 1)
torch.randint(0, 10, (2, 4)) # 整数随机数

# 3. 类似 numpy 的方式
torch.arange(0, 10, 2) # [0, 2, 4, 6, 8]
torch.linspace(0, 1, 5) # [0., 0.25, 0.5, 0.75, 1.]

# 4. 从 numpy 转换
import numpy as np
torch.from_numpy(np.array([[1, 2], [3, 4]]))

2. 常用输入数据

1
2
3
# 假数据:10 张 RGB 图片(3 通道,32x32),每张图一个标签
x = torch.randn(10, 3, 32, 32) # 图像张量
y = torch.randint(0, 5, (10,)) # 标签:5 类分类任务

3. 向量化操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
a = torch.randn(3, 4)
b = torch.randn(4, 5)

# 1. 矩阵乘法
out = a @ b # shape=(3,5)
out = torch.matmul(a, b)

# 2. 广播加法
x = torch.randn(10, 5) # 广播机制,自动扩展
bias = torch.randn(5)
x = x + bias
x = torch.arange(0, 10, 1).reshape(2, 5)
bias = torch.tensor(1.0) # 广播机制,scalar 自动扩展
y = x + bias


# 3. 拼接
torch.cat([a, a], dim=0) # 拼接行
torch.cat([a, a], dim=1) # 拼接列

# 4. reshape & transpose
x = torch.randn(2, 3, 4)
x.view(6, 4) # 改形状
x.permute(1, 0, 2) # 交换维度

# 5. 选择/掩码
x = torch.randn(10)
mask = x > 0
x[mask] # 选出正数元素

eval 和 train的切换

eval和train类似现场保护功能,开关切换switch off
model.train()
将模型设置为“训练模式”。这会启用诸如 Dropout 和 BatchNorm 这样的层的训练行为(如参数更新、随机失活等)。通常在训练阶段调用。

model.eval()
将模型设置为“评估/推理模式”。这会关闭 Dropout、BatchNorm 等层的训练特性,使用固定参数进行推理。通常在验证或测试阶段调用

1
2
3
4
5
6
7
8
9
10
model.eval()
for split in ['train', 'val']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch(split)
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
return out