Post

Pytorch Accelerate multi-GPU Training Inference

Pytorch Accelerate multi-GPU Training Inference

  • speeds up training by roughly 2×
  • makes it easy to enable mixed-precision (fp16) training for additional acceleration
| # GPUs | fp16 | Batch size per GPU | Seconds per epoch |
|--------|------|--------------------|-------------------|
| 1      | no   | 256                | 60                |
| 2      | no   | 256                | 33                |
| 2      | no   | 128                | 41                |
| 2      | yes  | 128                | 35                |

Using Accelerate to convert single-GPU code into multi-GPU training and inference

Configuring Accelerator parameters and mixed-precision training

Printing information only from the main process in multi-GPU runs

Multi-GPU training and choosing the right batch size

Single-GPU

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import pandas as pd
import numpy as np
import cv2

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, Subset
from torch.nn import CrossEntropyLoss

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from tqdm import tqdm


def get_transform(image_size, train=True):
    """Build the albumentations pipeline.

    Training adds random flips and brightness/contrast jitter; both modes
    resize to (image_size, image_size), normalize, and convert to a tensor.
    """
    ops = []
    if train:
        # Augmentations are applied only during training.
        ops += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
        ]
    ops += [
        A.Resize(image_size, image_size, interpolation=cv2.INTER_LANCZOS4),
        A.Normalize(0.1310, 0.30854),
        ToTensorV2(),
    ]
    return A.Compose(ops)
    
    
class MiniDataSet(Dataset):
    """Wrap in-memory image arrays (and optional labels) as a torch Dataset."""

    def __init__(self, images, labels=None, transform=None):
        # Cast once up front so every sample comes out as float32.
        self.images = images.astype("float32")
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        sample = {"image": image}
        # Labels are optional so the same class serves the test set.
        if self.labels is not None:
            sample["label"] = self.labels[idx]
        return sample
    
    
class TimmModel(nn.Module):
    """Thin wrapper around a timm backbone configured for 1-channel input."""

    def __init__(self, backbone, num_class, pretrained=False):
        super().__init__()
        # in_chans=1: the digit images are single-channel grayscale.
        self.model = timm.create_model(
            backbone, pretrained=pretrained, in_chans=1, num_classes=num_class
        )

    def forward(self, image):
        return self.model(image)


def train_fn(args, model, optimizer, dataloader):
    """Run one training epoch on args.device and return the mean batch loss."""
    model.to(args.device)
    model.train()
    criterion = torch.nn.CrossEntropyLoss()
    losses = []

    for batch in tqdm(dataloader):
        logits = model(batch["image"].to(args.device))
        optimizer.zero_grad()
        loss = criterion(logits, batch["label"].to(args.device))
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

    return np.mean(losses)


@torch.no_grad()
def predict_fn(args, model, dataloader):
    """Predict class indices for every batch; returns a 1-D numpy array."""
    model.to(args.device)
    model.eval()
    batch_preds = []

    for batch in dataloader:
        logits = model(batch["image"].to(args.device))
        batch_preds.append(logits.argmax(dim=1).cpu().numpy())

    return np.concatenate(batch_preds, axis=0)


def fit_model(args, model, optimizer, train_dl, val_dl):
    """Train for args.ep epochs, checkpointing whenever validation accuracy improves."""
    best_score = 0.
    for ep in range(args.ep):
        train_loss = train_fn(args, model, optimizer, train_dl)
        val_pred = predict_fn(args, model, val_dl)
        val_acc = np.mean(val_pred == val_dl.dataset.labels)
        print(f"Epoch {ep+1}, train loss {train_loss:.4f}, val acc {val_acc:.4f}")
        if val_acc > best_score:
            best_score = val_acc
            torch.save(model.state_dict(), "model.bin")
    # Restore the best checkpoint before returning.
    model.load_state_dict(torch.load("model.bin"))
    return model
    
    

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ep", default=5, type=int)
    parser.add_argument("--lr", default=0.00005, type=float)
    parser.add_argument("--bs", default=256, type=int)
    parser.add_argument("--device", default=0, type=int)
    parser.add_argument("--model", default="convnext_small")
    parser.add_argument("--image_size", default=56, type=int)
    args = parser.parse_args()

    # Kaggle digit-recognizer data: column 0 is the label, the remaining
    # 784 columns are 28x28 grayscale pixel values.
    train = pd.read_csv("digit-recognizer/train.csv")
    train_images = train.iloc[:, 1:].values.reshape(-1, 28, 28)
    train_labels = train.iloc[:, 0].values

    test = pd.read_csv("digit-recognizer/test.csv")
    test_images = test.values.reshape(-1, 28, 28)

    submission = pd.read_csv("digit-recognizer/sample_submission.csv")

    # First 40k samples for training, the rest for validation.
    train_transform = get_transform(args.image_size, True)
    valid_transform = get_transform(args.image_size, False)
    train_ds = MiniDataSet(train_images[:40000], train_labels[:40000], train_transform)
    val_ds = MiniDataSet(train_images[40000:], train_labels[40000:], valid_transform)
    test_ds = MiniDataSet(test_images, transform=valid_transform)

    train_dl = DataLoader(
        train_ds,
        batch_size=args.bs,
        num_workers=2,
        shuffle=True,
        drop_last=True)

    val_dl = DataLoader(
        val_ds,
        batch_size=args.bs * 2,
        num_workers=2,
        shuffle=False,
        drop_last=False)

    test_dl = DataLoader(
        test_ds,
        batch_size=args.bs * 2,
        num_workers=2,
        shuffle=False,
        drop_last=False)

    # (Removed a duplicated re-creation of train_ds/val_ds/test_ds that was
    # dead code: the DataLoaders above already hold the original datasets.)
    model = TimmModel(args.model, 10, pretrained=True)
    optimizer = Adam(model.parameters(), lr=args.lr)
    model = fit_model(args, model, optimizer, train_dl, val_dl)

    test_pred = predict_fn(args, model, test_dl)
    submission["Label"] = test_pred
    submission.to_csv("./submission.csv", index=False)

Modifying the code for multi-GPU

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
# Silence library warnings for cleaner multi-process logs.
# NOTE(review): `warnings` (and the other imports, e.g. `Accelerator` from
# accelerate) are not visible in this listing -- confirm the omitted
# preamble imports them.
warnings.filterwarnings("ignore")

def get_transform(image_size, train=True):
    """Build the albumentations pipeline.

    Training adds random flips and brightness/contrast jitter; both modes
    resize to (image_size, image_size), normalize, and convert to a tensor.
    """
    ops = []
    if train:
        # Augmentations are applied only during training.
        ops += [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.2),
        ]
    ops += [
        A.Resize(image_size, image_size, interpolation=cv2.INTER_LANCZOS4),
        A.Normalize(0.1310, 0.30854),
        ToTensorV2(),
    ]
    return A.Compose(ops)
    
    
class MiniDataSet(Dataset):
    """Wrap in-memory image arrays (and optional labels) as a torch Dataset."""

    def __init__(self, images, labels=None, transform=None):
        # Cast once up front so every sample comes out as float32.
        self.images = images.astype("float32")
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is not None:
            image = self.transform(image=image)["image"]
        sample = {"image": image}
        # Labels are optional so the same class serves the test set.
        if self.labels is not None:
            sample["label"] = self.labels[idx]
        return sample
    
    
class TimmModel(nn.Module):
    """Thin wrapper around a timm backbone configured for 1-channel input."""

    def __init__(self, backbone, num_class, pretrained=False):
        super().__init__()
        # in_chans=1: the digit images are single-channel grayscale.
        self.model = timm.create_model(
            backbone, pretrained=pretrained, in_chans=1, num_classes=num_class
        )

    def forward(self, image):
        return self.model(image)
    
    
def train_fn(args, model, optimizer, dataloader):
    """Run one training epoch under Accelerate and return the mean batch loss.

    The model and dataloader are expected to have gone through
    accelerator.prepare(), so no explicit .to(device) calls are needed.
    """
    model.train()
    train_loss = []
    loss_fn = torch.nn.CrossEntropyLoss()

    # Show the progress bar only on the main process to avoid duplicate bars.
    for batch in tqdm(dataloader, disable=not args.accelerator.is_main_process):
        logits = model(batch["image"])
        optimizer.zero_grad()
        loss = loss_fn(logits, batch["label"])
        # accelerator.backward() replaces loss.backward(): it handles
        # gradient scaling (fp16) and cross-GPU synchronization.
        args.accelerator.backward(loss)
        optimizer.step()
        train_loss.append(loss.item())

    return np.mean(train_loss)


@torch.no_grad()
def predict_fn(args, model, dataloader):
    """Predict class indices and gather results across all processes.

    accelerator.gather_for_metrics() collects per-GPU predictions and drops
    any samples duplicated by the distributed sampler, so the returned
    array lines up with the underlying dataset.
    """
    model.eval()
    predictions = []

    for batch in dataloader:
        output = model(batch["image"])
        prediction = torch.argmax(output, 1)
        prediction = args.accelerator.gather_for_metrics(prediction)
        predictions.append(prediction.cpu().numpy())

    return np.concatenate(predictions, axis=0)


# def train_fn(args, model, optimizer, dataloader):
#     # remove to(device)
#     # change loss.backward() to accelerator.backward(loss)
#     # add disable=not args.accelerator.is_main_process so the progress bar
#     # shows only on the main process, avoiding duplicate output
#     model.train()
#     train_loss = []
#     loss_fn = torch.nn.CrossEntropyLoss()
    
#     for batch in tqdm(dataloader, disable=not args.accelerator.is_main_process):
#         logits = model(batch["image"])
#         optimizer.zero_grad()
#         loss = loss_fn(logits, batch["label"])
#         args.accelerator.backward(loss)
#         optimizer.step()
#         train_loss.append(loss.item())
    
#     return np.mean(train_loss)


# @torch.no_grad()
# def predict_fn(args, model, dataloader):
#     # remove to(device)
#     # use accelerator.gather_for_metrics(prediction) to gather predictions
#     # from all GPUs
#     model.eval()
#     predictions = []
    
#     for step, batch in enumerate(dataloader):
#         output = model(batch["image"])
#         prediction = torch.argmax(output, 1)
#         prediction = args.accelerator.gather_for_metrics(prediction)
#         predictions.append(prediction.cpu().numpy())
    
#     predictions = np.concatenate(predictions, axis=0)
#     return predictions


def fit_model(args, model, optimizer, train_dl, val_dl):
    """Train for args.ep epochs under Accelerate, checkpointing on best val accuracy.

    Fixes a multi-process race in the original: every process used to write
    "model.bin" concurrently. Now only the main process saves, and all
    processes synchronize before reloading the checkpoint.
    """
    best_score = 0.
    for ep in range(args.ep):
        train_loss = train_fn(args, model, optimizer, train_dl)
        val_pred = predict_fn(args, model, val_dl)
        # gather_for_metrics in predict_fn means val_pred lines up with the
        # full validation set on every process.
        val_acc = np.mean(val_pred == val_dl.dataset.labels)
        if args.accelerator.is_main_process:
            print(f"Epoch {ep+1}, train loss {train_loss:.4f}, val acc {val_acc:.4f}")
        if val_acc > best_score:
            best_score = val_acc
            # Only the main process writes the checkpoint; concurrent writes
            # to the same path from every rank are a race condition.
            if args.accelerator.is_main_process:
                torch.save(model.state_dict(), "model.bin")
    # Barrier: make sure the checkpoint is fully written before any rank reads it.
    args.accelerator.wait_for_everyone()
    model.load_state_dict(torch.load("model.bin"))
    return model
    
    

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ep", default=5, type=int)
    parser.add_argument("--lr", default=0.00005, type=float)
    parser.add_argument("--bs", default=256, type=int)
    # NOTE(review): --device is unused in the Accelerate version (device
    # placement is handled by accelerator.prepare) -- kept for CLI compatibility.
    parser.add_argument("--device", default=0, type=int)
    parser.add_argument("--model", default="convnext_small")
    parser.add_argument("--image_size", default=56, type=int)
    args = parser.parse_args()
    
    
    # Kaggle digit-recognizer data: column 0 is the label, the remaining
    # 784 columns are 28x28 grayscale pixel values.
    train = pd.read_csv("digit-recognizer/train.csv")
    train_images = train.iloc[:, 1:].values.reshape(-1, 28, 28)
    train_labels = train.iloc[:, 0].values
    
    test = pd.read_csv("digit-recognizer/test.csv")
    test_images = test.values.reshape(-1, 28, 28)
    
    submission = pd.read_csv("digit-recognizer/sample_submission.csv")

    
    # First 40k samples for training, the rest for validation.
    train_transform = get_transform(args.image_size, True)
    valid_transform = get_transform(args.image_size, False)
    train_ds = MiniDataSet(train_images[:40000], train_labels[:40000], train_transform)
    val_ds = MiniDataSet(train_images[40000:], train_labels[40000:], valid_transform)
    test_ds = MiniDataSet(test_images, transform=valid_transform)
    
    # NOTE(review): under accelerator.prepare, args.bs is the per-process
    # batch size, so the effective global batch is bs * num_processes.
    train_dl = DataLoader(
    train_ds, 
    batch_size=args.bs, 
    num_workers=2, 
    shuffle=True, 
    drop_last=True)

    val_dl = DataLoader(
        val_ds, 
        batch_size=args.bs * 2, 
        num_workers=2, 
        shuffle=False, 
        drop_last=False)

    test_dl = DataLoader(
        test_ds, 
        batch_size=args.bs * 2, 
        num_workers=2, 
        shuffle=False, 
        drop_last=False)
    
    model = TimmModel(args.model, 10, pretrained=True)
    optimizer = Adam(model.parameters(), lr=args.lr)
    
    # Initialize the Accelerator.
    # NOTE(review): `Accelerator` (from accelerate) is not imported anywhere
    # in this listing -- confirm the omitted preamble imports it.
    accelerator = Accelerator()
    
    # Prepare for multi-GPU: wraps the model/optimizer, moves them to the
    # right device, and shards the dataloaders across processes.
    model, optimizer, train_dl, val_dl, test_dl = accelerator.prepare(model, optimizer, train_dl, val_dl, test_dl)
    args.accelerator = accelerator
    
    model = fit_model(args, model, optimizer, train_dl, val_dl)
    
    test_pred = predict_fn(args, model, test_dl)
    # Only the local main process writes the submission file.
    if accelerator.is_local_main_process:
        submission["Label"] = test_pred
        submission.to_csv("./submission.csv", index=False)
This post is licensed under CC BY 4.0 by the author.