
Commit 2cb6c4d

Author: yixu.cui
Commit message: add vit
1 parent 110b267 commit 2cb6c4d

File tree: 2 files changed, +736 -0 lines


models/vit/vit_mnist.py

Lines changed: 311 additions & 0 deletions
@@ -0,0 +1,311 @@
import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
from einops import rearrange
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor

plt.ion()  # interactive mode

torch.manual_seed(42)
DOWNLOAD_PATH = "/share-global/yixu.cui/datas/mnist"
DOWNLOAD_DATA_PATH = "/share-global/yixu.cui/datas/"
BATCH_SIZE_TRAIN = 256 * 8 * 2  # 4096
BATCH_SIZE_TEST = 256 * 8 * 2 * 2  # 8192


# device: GPU. CUDA_VISIBLE_DEVICES must be set before the first CUDA call.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2, 3, 4, 5, 6, 7"
os.environ["CUDA_VISIBLE_DEVICES"] = "4, 5, 6, 7"

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

gpu_num = torch.cuda.device_count()
device_ids = [i for i in range(gpu_num)]
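# With CUDA_VISIBLE_DEVICES="4, 5, 6, 7", PyTorch renumbers the visible cards
# as cuda:0..cuda:3, so on an 8-GPU host gpu_num is 4 and device_ids becomes
# [0, 1, 2, 3].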


# Residual module, wrapped around each feed-forward and attention block
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


# LayerNorm, applied before the multi-head attention and feed-forward layers.
# As in BERT with absolute position embeddings, LayerNorm normalizes each
# token over its own channel dimension.
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)
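
# Composed as Residual(PreNorm(dim, fn)), the two wrappers above compute
# x + fn(LayerNorm(x)): the pre-norm Transformer block used in Transformer
# below.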


# Placed after multi-head attention: the matrix multiplications inside
# attention are linear transformations, so this fully connected FeedForward
# block follows to add non-linearity.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)
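
# With the config used below (dim=64, mlp_dim=128), each FeedForward holds
# (64 * 128 + 128) + (128 * 64 + 64) = 16,576 parameters.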


# Multi-head attention: several self-attention heads in parallel, computed
# from q, k, v projections.
class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)
        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions"
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, float("-inf"))
            del mask
        attn = dots.softmax(dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return out
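
# Shape walk-through for the MNIST config used below (dim=64, heads=8,
# 17 tokens per image):
#   x (b, 17, 64) -> qkv (b, 17, 192) -> q, k, v each (b, 8, 17, 8)
#   -> dots/attn (b, 8, 17, 17) -> out (b, 8, 17, 8) -> (b, 17, 64).
# Note that self.scale = dim ** -0.5 scales by the full model dim (64), not
# the per-head dim (8) that "Attention Is All You Need" uses.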


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Residual(PreNorm(dim, Attention(dim, heads=heads))),
                        Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
                    ]
                )
            )

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            # print(f"batch size: {x.shape[0]}")  # debug: images per GPU
            x = attn(x, mask=mask)
            x = ff(x)
        return x
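
# Under nn.DataParallel (enabled below with 4 visible GPUs), each replica sees
# roughly BATCH_SIZE_TRAIN / 4 = 1024 images, which the commented-out debug
# print above would report.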


# Split the image into patches, flatten them into a sequence, and feed the
# sequence to a Transformer for image classification.
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels=3,
        mlp_drop_ratio=0.0,
    ):
        super().__init__()
        assert (
            image_size % patch_size == 0
        ), "image dimensions must be divisible by the patch size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size**2
        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.Dropout(p=mlp_drop_ratio),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes),
            # nn.Dropout(p=mlp_drop_ratio),
        )

    def forward(self, img, mask=None):
        p = self.patch_size
        x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
        x = self.patch_to_embedding(x)
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x, mask)
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)
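
# Worked shapes for the MNIST config below: img (b, 1, 28, 28) with p=7 is
# rearranged to (b, 16, 49) [16 patches of 7*7*1 pixels], embedded to
# (b, 16, 64), prepended with the cls token to (b, 17, 64), and added to the
# broadcast pos_embedding (1, 17, 64); after the Transformer, the cls output
# x[:, 0] is mapped to (b, 10) logits by mlp_head.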


###
def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data.to(device)), dim=1)
        loss = F.nll_loss(output, target.to(device))
        loss.backward()
        optimizer.step()

        if i % 128 == 0:
            print(
                "["
                + "{:5}".format(i * len(data))
                + "/"
                + "{:5}".format(total_samples)
                + " ("
                + "{:3.0f}".format(100 * i / len(data_loader))
                + "%)] Loss: "
                + "{:6.4f}".format(loss.item())
            )
            loss_history.append(loss.item())
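
# F.log_softmax followed by F.nll_loss is mathematically equivalent to
# nn.CrossEntropyLoss on raw logits. Also, with BATCH_SIZE_TRAIN = 4096 the
# train loader yields ceil(60000 / 4096) = 15 batches per epoch, so the
# `i % 128 == 0` progress print fires only at i == 0.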


def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data.to(device)), dim=1)
            loss = F.nll_loss(output, target.to(device), reduction="sum")
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target.to(device)).sum()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print(
        "\nAverage test loss: "
        + "{:.4f}".format(avg_loss)
        + " Accuracy: "
        + "{:5}".format(correct_samples)
        + "/"
        + "{:5}".format(total_samples)
        + " ("
        + "{:4.2f}".format(100.0 * correct_samples / total_samples)
        + "%)\n"
    )
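
# reduction="sum" accumulates the per-sample loss across batches; dividing by
# total_samples at the end gives the exact dataset mean even when the last
# batch is smaller than BATCH_SIZE_TEST.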


if __name__ == "__main__":
    # mnist
    transform_mnist = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    train_set = torchvision.datasets.MNIST(
        DOWNLOAD_PATH, train=True, download=True, transform=transform_mnist
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True
    )

    test_set = torchvision.datasets.MNIST(
        DOWNLOAD_PATH, train=False, download=True, transform=transform_mnist
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=BATCH_SIZE_TEST, shuffle=True
    )
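
    # With BATCH_SIZE_TEST = 8192, the 10,000-image test set arrives in two
    # batches; shuffle=True on the test loader only changes evaluation order,
    # not the metrics.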

    for X, y in test_loader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    EPOCHS_NUM = 100
    # start_time = time.time()
"""
263+
patch大小为 7x7(对于 28x28 图像,这意味着每个图像 4 x 4 = 16 个patch)、10 个可能的目标类别(0 到 9)和 1 个颜色通道(因为图像是灰度)。
264+
在网络参数方面,使用了 64 个单元的维度,6 个 Transformer 块的深度,8 个 Transformer 头,MLP 使用 128 维度。
265+
"""
266+
model = ViT(
267+
image_size=28,
268+
patch_size=7,
269+
num_classes=10,
270+
channels=1,
271+
dim=64,
272+
depth=6,
273+
heads=8,
274+
mlp_dim=128,
275+
mlp_drop_ratio=0.5,
276+
)
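    # Rough parameter count for this config: patch/cls/position embeddings
    # ~4.4K, 6 Transformer blocks at 33,280 each ~199.7K, mlp_head ~9.6K,
    # for a total of roughly 0.21M parameters.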
    # print(f"No device:\n{model}")
    # Parallelize the model across the visible GPUs
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        # model = nn.parallel.DistributedDataParallel(model)

    model = model.to(device)
    # print(f"On device:\n{model}")

    loss_fn = (
        nn.CrossEntropyLoss()
    )  # Ref: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
    # Note: loss_fn is defined for reference but unused; train_epoch calls
    # F.nll_loss directly.

    optimizer = optim.Adam(model.parameters(), lr=0.008)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # lf = (
    #     lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf)
    #     + args.lrf
    # )  # cosine
    # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    train_loss_history, test_loss_history = [], []
    for epoch in range(1, EPOCHS_NUM + 1):
        print("Epoch:", epoch)
        start_time = time.time()
        train_epoch(model, optimizer, train_loader, train_loss_history)
        evaluate(model, test_loader, test_loss_history)
        print(
            "This epoch took:",
            "{:5.2f}".format(time.time() - start_time),
            "seconds",
        )

    # print("Execution time:", "{:5.2f}".format(time.time() - start_time), "seconds")
