import math
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import torchvision
from einops import rearrange
from torch import nn
from torch import optim
from torchvision import datasets
from torchvision import transforms
from torchvision.transforms import ToTensor


plt.ion()  # interactive mode

torch.manual_seed(42)
DOWNLOAD_PATH = "/share-global/yixu.cui/datas/mnist"
DOWNLOAD_DATA_PATH = "/share-global/yixu.cui/datas/"
# DOWNLOAD_DATA_PATH = "debug"  # local debugging override; uncomment to avoid the shared path
BATCH_SIZE_TRAIN = 256 * 8 * 2  # 4096 images per step, split across the visible GPUs
BATCH_SIZE_TEST = 256 * 8 * 2 * 2  # 8192 images per step


# device: GPU. CUDA_VISIBLE_DEVICES is a comma-separated list; written without
# spaces to be safe.
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"  # use all eight GPUs
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7"

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

gpu_num = torch.cuda.device_count()
device_ids = [i for i in range(gpu_num)]


# Residual connection, wrapped around each attention and feed-forward block.
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x


# LayerNorm applied before the multi-head attention and MLP layers ("pre-norm").
# As in BERT with absolute position embeddings, LayerNorm normalizes each token
# over its own channel dimension.
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


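# A minimal sketch (illustrative only; never called by the script) of how the two
# wrappers above compose: Residual(PreNorm(dim, fn)) computes fn(LayerNorm(x)) + x,
# the standard pre-norm Transformer ordering. The name and sizes are hypothetical.
def _demo_prenorm_residual():
    dim = 64
    block = Residual(PreNorm(dim, nn.Linear(dim, dim)))
    x = torch.randn(2, 16, dim)  # (batch, tokens, dim)
    assert block(x).shape == x.shape  # the residual sum preserves the shape

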
# Placed after multi-head attention: the matrix multiplications inside attention
# are linear transformations, so a fully connected feed-forward network with a
# GELU in between adds non-linearity. See the shape sketch below.
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim)
        )

    def forward(self, x):
        return self.net(x)


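# Illustrative only (never called): the feed-forward block is position-wise, so
# it maps the last dimension dim -> hidden_dim -> dim independently per token and
# leaves the token count unchanged. Sizes here are hypothetical.
def _demo_feedforward():
    ff = FeedForward(dim=64, hidden_dim=128)
    x = torch.randn(2, 16, 64)  # (batch, tokens, dim)
    assert ff(x).shape == (2, 16, 64)

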
# Multi-head attention: several self-attention heads computed in parallel from
# the q, k, v projections and concatenated back together.
class Attention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.heads = heads
        self.scale = dim**-0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, "b n (qkv h d) -> qkv b h n d", qkv=3, h=h)
        dots = torch.einsum("bhid,bhjd->bhij", q, k) * self.scale
        if mask is not None:
            # Pad the patch mask with True for the cls token, build the pairwise
            # visibility mask, and block masked pairs with -inf before softmax.
            mask = F.pad(mask.flatten(1), (1, 0), value=True)
            assert mask.shape[-1] == dots.shape[-1], "mask has incorrect dimensions"
            mask = mask[:, None, :] * mask[:, :, None]  # (b, n, n)
            dots.masked_fill_(~mask[:, None, :, :], float("-inf"))  # broadcast over heads
            del mask
        attn = dots.softmax(dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return out


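# A small shape sanity check (illustrative only; never called). With dim=64 and
# heads=8 each head works on d = 64 / 8 = 8 channels, and the block returns the
# same (batch, tokens, dim) shape it receives.
def _demo_attention_shapes():
    attn = Attention(dim=64, heads=8)
    x = torch.randn(2, 17, 64)  # 16 patch tokens + 1 cls token
    assert attn(x).shape == (2, 17, 64)

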
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Residual(PreNorm(dim, Attention(dim, heads=heads))),
                        Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
                    ]
                )
            )

    def forward(self, x, mask=None):
        for attn, ff in self.layers:
            # print(f"batch size: {x.shape[0]}")  # debug: how many images land on each GPU
            x = attn(x, mask=mask)
            x = ff(x)
        return x


# Split the image into patches and feed the resulting token sequence into a
# Transformer to perform image classification.
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        channels=3,
    ):
        super().__init__()
        assert (
            image_size % patch_size == 0
        ), "image dimensions must be divisible by the patch size"
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size**2
        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img, mask=None):
        p = self.patch_size
        # (b, c, h*p, w*p) -> (b, h*w patches, p*p*c pixels per patch)
        x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p)
        x = self.patch_to_embedding(x)
        cls_tokens = self.cls_token.expand(img.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x, mask)
        x = self.to_cls_token(x[:, 0])
        return self.mlp_head(x)


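# Quick patchify illustration (hypothetical sizes; never called by the script):
# a 28x28 grayscale image with 7x7 patches becomes 4 x 4 = 16 tokens, each of
# dimension 7 * 7 * 1 = 49, before the linear patch embedding.
def _demo_patchify():
    img = torch.randn(2, 1, 28, 28)  # (batch, channels, height, width)
    patches = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=7, p2=7)
    assert patches.shape == (2, 16, 49)

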
# --- training and evaluation loops ---
def train_epoch(model, optimizer, data_loader, loss_history):
    total_samples = len(data_loader.dataset)
    model.train()

    for i, (data, target) in enumerate(data_loader):
        optimizer.zero_grad()
        output = F.log_softmax(model(data.to(device)), dim=1)
        loss = F.nll_loss(output, target.to(device))
        loss.backward()
        optimizer.step()

        if i % 128 == 0:
            print(
                f"[{i * len(data):5}/{total_samples:5} "
                f"({100 * i / len(data_loader):3.0f}%)] Loss: {loss.item():6.4f}"
            )
            loss_history.append(loss.item())


def evaluate(model, data_loader, loss_history):
    model.eval()

    total_samples = len(data_loader.dataset)
    correct_samples = 0
    total_loss = 0

    with torch.no_grad():
        for data, target in data_loader:
            output = F.log_softmax(model(data.to(device)), dim=1)
            loss = F.nll_loss(output, target.to(device), reduction="sum")
            _, pred = torch.max(output, dim=1)

            total_loss += loss.item()
            correct_samples += pred.eq(target.to(device)).sum().item()

    avg_loss = total_loss / total_samples
    loss_history.append(avg_loss)
    print(
        f"\nAverage test loss: {avg_loss:.4f}"
        f"  Accuracy: {correct_samples:5}/{total_samples:5}"
        f" ({100.0 * correct_samples / total_samples:.2f}%)\n"
    )


if __name__ == "__main__":
    # MNIST
    transform_mnist = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    train_set = torchvision.datasets.MNIST(
        DOWNLOAD_PATH, train=True, download=True, transform=transform_mnist
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True
    )

    test_set = torchvision.datasets.MNIST(
        DOWNLOAD_PATH, train=False, download=True, transform=transform_mnist
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=BATCH_SIZE_TEST, shuffle=True
    )

    # FashionMNIST
    # Download training data from open datasets.
    training_data = datasets.FashionMNIST(
        root=DOWNLOAD_DATA_PATH,
        train=True,
        download=True,
        transform=ToTensor(),
    )

    # Download test data from open datasets.
    test_data = datasets.FashionMNIST(
        root=DOWNLOAD_DATA_PATH,
        train=False,
        download=True,
        transform=ToTensor(),
    )

    batch_size = 64

    # Create data loaders.
    train_dataloader = torch.utils.data.DataLoader(
        training_data,
        batch_size=batch_size,
    )
    test_dataloader = torch.utils.data.DataLoader(
        test_data,
        batch_size=batch_size,
    )

    # Peek at one batch from each dataset to confirm tensor shapes.
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    for X, y in test_loader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    EPOCHS_NUM = 100
    # start_time = time.time()

    """
    Patch size 7x7 (for a 28x28 image, that gives 4 x 4 = 16 patches per image),
    10 possible target classes (0 to 9), and 1 color channel (the images are
    grayscale). Network parameters: embedding dimension 64, a depth of 6
    Transformer blocks, 8 attention heads, and a 128-dimensional MLP.
    """
    model = ViT(
        image_size=28,
        patch_size=7,
        num_classes=10,
        channels=1,
        dim=64,
        depth=6,
        heads=8,
        mlp_dim=128,
    )
    # print(f"No device:\n{model}")
    # Replicate the model across GPUs (data parallelism: each GPU receives a
    # slice of every batch).
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = nn.DataParallel(model)
        # model = nn.parallel.DistributedDataParallel(model)

    model = model.to(device)
    # print(f"On device:\n{model}")

    # Unused here: train_epoch and evaluate apply F.log_softmax + F.nll_loss,
    # which is equivalent to CrossEntropyLoss on raw logits.
    # Ref: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
    # loss_fn = nn.CrossEntropyLoss()

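    # A quick equivalence check (illustrative; random values, hypothetical names):
    # composing F.log_softmax with F.nll_loss, as train_epoch does, matches
    # F.cross_entropy on raw logits, which is what nn.CrossEntropyLoss computes.
    _logits = torch.randn(8, 10)
    _labels = torch.randint(0, 10, (8,))
    assert torch.allclose(
        F.nll_loss(F.log_softmax(_logits, dim=1), _labels),
        F.cross_entropy(_logits, _labels),
    )
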
    optimizer = optim.Adam(model.parameters(), lr=0.008)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    # lf = (
    #     lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf)
    #     + args.lrf
    # )  # cosine
    # scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    train_loss_history, test_loss_history = [], []
    for epoch in range(1, EPOCHS_NUM + 1):
        print("Epoch:", epoch)
        start_time = time.time()
        train_epoch(model, optimizer, train_loader, train_loss_history)
        evaluate(model, test_loader, test_loss_history)
        print(
            "This epoch took:",
            "{:5.2f}".format(time.time() - start_time),
            "seconds",
        )

    # print("Execution time:", "{:5.2f}".format(time.time() - start_time), "seconds")