Skip to content

Commit 824d30f

Browse files
author
yixu.cui
committed
add models/vit/vit_fashion_mnist.py
1 parent f4bd8b0 commit 824d30f

File tree

1 file changed

+341
-0
lines changed

1 file changed

+341
-0
lines changed

models/vit/vit_fashion_mnist.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
import math
2+
import os
3+
import time
4+
5+
import matplotlib.pyplot as plt
6+
import numpy as np
7+
import torch
8+
import torch.nn.functional as F
9+
import torch.optim.lr_scheduler as lr_scheduler
10+
import torchvision
11+
from einops import rearrange
12+
from torch import nn
13+
from torch import optim
14+
from torchvision import datasets
15+
from torchvision import transforms
16+
from torchvision.transforms import ToTensor
17+
18+
19+
# --- Runtime / experiment configuration -------------------------------------

plt.ion()  # interactive mode so figures do not block script execution

torch.manual_seed(42)  # fixed seed for reproducible init / shuffling

DOWNLOAD_PATH = "/share-global/yixu.cui/datas/mnist"
# NOTE(review): the original first set this to "/share-global/yixu.cui/datas/"
# and immediately overwrote it with the "debug" path below; only the final
# value was ever used, so the dead assignment is dropped (original value kept
# here for reference).
DOWNLOAD_DATA_PATH = "debug"

BATCH_SIZE_TRAIN = 256 * 8 * 2  # 4096 images per training batch
BATCH_SIZE_TEST = 256 * 8 * 2 * 2  # 8192 images per evaluation batch


# device: GPU
# NOTE(review): the original assigned CUDA_VISIBLE_DEVICES twice ("0..7" then
# "4, 5, 6, 7"); only the last assignment is effective, so it is kept alone.
# Also note that setting this env var after `import torch` may be ignored if
# CUDA was already initialized — prefer setting it before launching.
os.environ["CUDA_VISIBLE_DEVICES"] = "4, 5, 6, 7"

# Prefer CUDA, then Apple MPS, then CPU.  (The original also computed
# `device = torch.device(...)` first and immediately overwrote it with this
# string form; the dead assignment is removed.)
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

gpu_num = torch.cuda.device_count()
device_ids = list(range(gpu_num))  # e.g. [0, 1, ...] for nn.DataParallel
44+
45+
46+
# Residual (skip) connection wrapper, placed around every feed-forward and
# attention sub-layer.
class Residual(nn.Module):
    """Wrap ``fn`` so the module computes ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Keyword arguments (e.g. an attention mask) pass through untouched.
        out = self.fn(x, **kwargs)
        return out + x
54+
55+
56+
# Pre-normalization wrapper: LayerNorm is applied *before* the wrapped layer
# (attention or feed-forward), normalizing across the channel dimension as in
# pre-LN Transformers / BERT-style absolute-position models.
class PreNorm(nn.Module):
    """Apply ``fn(LayerNorm(x), **kwargs)`` — normalize first, then run the layer."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
65+
66+
67+
# Position-wise MLP that follows each attention sub-layer: attention's matrix
# products are purely linear, so this fully connected block adds the
# non-linearity (GELU) between two projections.
class FeedForward(nn.Module):
    """Two-layer MLP: dim -> hidden_dim -> GELU -> dim."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
77+
78+
79+
# Multi-head self-attention: several self-attention heads computed jointly
# from a fused q/k/v projection.
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with optional mask.

    The mask, when given, is flattened and left-padded with one True entry
    (for the class token) before being expanded to a pairwise mask.
    """

    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        # NOTE(review): scale uses the full embedding dim (dim ** -0.5), not
        # the per-head dim — unusual versus 1/sqrt(d_head); confirm intent.
        self.scale = dim**-0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        head_count = self.heads
        fused = self.to_qkv(x)
        # Split the fused projection into per-head q, k, v tensors.
        q, k, v = rearrange(fused, "b n (qkv h d) -> qkv b h n d", qkv=3, h=head_count)
        scores = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if mask is not None:
            # Prepend True for the class-token position, then build the
            # pairwise (i, j) validity mask by outer product.
            padded = F.pad(mask.flatten(1), (1, 0), value=True)
            assert padded.shape[-1] == scores.shape[-1], "mask has incorrect dimensions"
            pair_mask = padded[:, None, :] * padded[:, :, None]
            # In-place fill so softmax assigns zero weight to masked pairs.
            scores.masked_fill_(~pair_mask, float("-inf"))
            del pair_mask
        weights = scores.softmax(dim=-1)
        attended = torch.einsum("bhij,bhjd->bhid", weights, v)
        merged = rearrange(attended, "b h n d -> b n (h d)")
        return self.to_out(merged)
104+
105+
106+
class Transformer(nn.Module):
    """Stack of ``depth`` encoder layers: pre-norm attention + pre-norm MLP,
    each wrapped in a residual connection."""

    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.ModuleList(
                [
                    Residual(PreNorm(dim, Attention(dim, heads=heads))),
                    Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
                ]
            )
            for _ in range(depth)
        )

    def forward(self, x, mask=None):
        # Run each (attention, feed-forward) pair in sequence; the mask only
        # applies to the attention half.
        for attention_block, feedforward_block in self.layers:
            x = attention_block(x, mask=mask)
            x = feedforward_block(x)
        return x
126+
127+
128+
# Vision Transformer: cut the image into patches, embed them as a token
# sequence, and classify from the class-token output of a Transformer encoder.
class ViT(nn.Module):
    """Minimal ViT classifier.

    Keyword-only parameters: image_size, patch_size, num_classes, dim, depth,
    heads, mlp_dim, channels (default 3). image_size must be divisible by
    patch_size.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels=3,
    ):
        super().__init__()
        assert (
            image_size % patch_size == 0
        ), "image dimensions must be divisible by the patch size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size**2
        self.patch_size = patch_size
        # Learned absolute position embedding for the cls token + all patches.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        # Learned classification token prepended to the patch sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        patch = self.patch_size
        # (b, c, H, W) -> (b, num_patches, patch_dim): flatten each patch.
        tokens = rearrange(
            img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=patch, p2=patch
        )
        tokens = self.patch_to_embedding(tokens)
        cls = self.cls_token.expand(img.shape[0], -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = tokens + self.pos_embedding
        encoded = self.transformer(tokens, mask)
        # Classify from the class-token position only.
        cls_out = self.to_cls_token(encoded[:, 0])
        return self.mlp_head(cls_out)
168+
169+
170+
###
171+
def train_epoch(model, optimizer, data_loader, loss_history):
    """Run one training epoch over ``data_loader``.

    Uses the module-global ``device``. Every 128 batches: prints progress and
    appends the current batch loss to ``loss_history`` (mutated in place).
    """
    sample_count = len(data_loader.dataset)
    model.train()

    for batch_idx, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        # log-softmax + NLL == cross-entropy on the model's raw logits.
        log_probs = F.log_softmax(model(data.to(device)), dim=1)
        loss = F.nll_loss(log_probs, target.to(device))
        loss.backward()
        optimizer.step()

        if batch_idx % 128 == 0:
            print(
                "[{:5}/{:5} ({:3.0f}%)] Loss: {:6.4f}".format(
                    batch_idx * len(data),
                    sample_count,
                    100 * batch_idx / len(data_loader),
                    loss.item(),
                )
            )
            loss_history.append(loss.item())
194+
195+
196+
def evaluate(model, data_loader, loss_history):
    """Evaluate ``model`` on ``data_loader`` without gradients.

    Uses the module-global ``device``. Appends the per-sample average NLL loss
    to ``loss_history`` (mutated in place) and prints loss + accuracy.
    """
    model.eval()

    sample_count = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            targets = target.to(device)
            log_probs = F.log_softmax(model(data.to(device)), dim=1)
            # reduction="sum" so the epoch average below is per-sample.
            batch_loss = F.nll_loss(log_probs, targets, reduction="sum")
            _, predictions = torch.max(log_probs, dim=1)

            total_loss += batch_loss.item()
            correct_samples += predictions.eq(targets).sum()

    avg_loss = total_loss / sample_count
    loss_history.append(avg_loss)
    print(
        "\nAverage test loss: {:.4f} Accuracy:{:5}/{:5} ({:14.2f}%)\n".format(
            avg_loss,
            correct_samples,
            sample_count,
            100.0 * correct_samples / sample_count,
        )
    )
225+
226+
227+
if __name__ == "__main__":
    # --- MNIST pipeline (this is what actually trains below) ---
    # mnist
    transform_mnist = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            # Standard MNIST mean/std normalization.
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    train_set = torchvision.datasets.MNIST(
        DOWNLOAD_PATH, train=True, download=True, transform=transform_mnist
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True
    )

    test_set = torchvision.datasets.MNIST(
        DOWNLOAD_PATH, train=False, download=True, transform=transform_mnist
    )
    # NOTE(review): shuffle=True on the test loader is unusual — evaluation
    # order does not matter for the metrics, but it adds needless overhead.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=BATCH_SIZE_TEST, shuffle=True
    )

    # FashionMNIST
    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root=DOWNLOAD_DATA_PATH,
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root=DOWNLOAD_DATA_PATH,
        train=False,
        download=True,
        transform=ToTensor(),
    )
    # NOTE(review): debug leftover — this exit() aborts the script right after
    # downloading FashionMNIST, so EVERYTHING below this line is dead code and
    # the model is never built or trained. Remove it to run the experiment.
    exit()

    # batch_size = 64

    # Create data loaders.
    # NOTE(review): dead code (after exit()). If reached, this would raise
    # NameError twice: `DataLoader` is never imported (only
    # torch.utils.data.DataLoader is in scope) and `batch_size` is commented
    # out above.
    train_dataloader = DataLoader(
        training_data,
        batch_size=batch_size,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
    )

    # Sanity-print one FashionMNIST batch shape (dead code, see exit() above).
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    # Sanity-print one MNIST batch shape.
    for X, y in test_loader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    EPOCHS_NUM = 100
    # start_time = time.time()

    # (Translation of the note below: patch size 7x7 — for 28x28 images that
    # is 4 x 4 = 16 patches per image — 10 target classes (0-9) and 1 color
    # channel since the images are grayscale. Network parameters: 64-dim
    # embedding, depth of 6 Transformer blocks, 8 heads, 128-dim MLP.)
    """
    patch大小为 7x7(对于 28x28 图像,这意味着每个图像 4 x 4 = 16 个patch)、10 个可能的目标类别(0 到 9)和 1 个颜色通道(因为图像是灰度)。
    在网络参数方面,使用了 64 个单元的维度,6 个 Transformer 块的深度,8 个 Transformer 头,MLP 使用 128 维度。
    """
    model = ViT(
        image_size=28,
        patch_size=7,
        num_classes=10,
        channels=1,
        dim=64,
        depth=6,
        heads=8,
        mlp_dim=128,
    )
    # print(f"No device:\n{model}")
    # Model parallelization: replicate across all visible GPUs.
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        # model = nn.parallel.DistributedDataParallel(model)

    model = model.to(device)
    # print(f"On device:\n{model}")

    # NOTE(review): loss_fn is never used — train_epoch/evaluate call
    # F.nll_loss on log-softmax outputs instead (mathematically equivalent
    # to cross-entropy on logits).
    loss_fn = (
        nn.CrossEntropyLoss()
    )  # Ref: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

    optimizer = optim.Adam(model.parameters(), lr=0.008)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # lf = (
    #     lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf)
    #     + args.lrf
    # ) # cosine
    # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    # Train on MNIST for EPOCHS_NUM epochs, timing each one.
    train_loss_history, test_loss_history = [], []
    for epoch in range(1, EPOCHS_NUM + 1):
        print("Epoch:", epoch)
        start_time = time.time()
        train_epoch(model, optimizer, train_loader, train_loss_history)
        evaluate(model, test_loader, test_loss_history)
        print(
            "This EPOCH takes time:",
            "{:5.2f}".format(time.time() - start_time),
            "seconds",
        )

    # print("Execution time:", "{:5.2f}".format(time.time() - start_time), "seconds")

0 commit comments

Comments
 (0)