
PyTorch Accelerate: Multi-GPU Training and Inference

  • With two GPUs, training is roughly twice as fast (see the table below).
  • Mixed-precision (fp16) training is easy to enable and gives a further speedup.
| # GPUs | fp16 | Batch size per GPU | Seconds per epoch |
| ------ | ---- | ------------------ | ----------------- |
| 1      | no   | 256                | 60                |
| 2      | no   | 256                | 33                |
| 2      | no   | 128                | 41                |
| 2      | yes  | 128                | 35                |
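
A minimal sketch of how the fp16 rows above were produced (standard Accelerate API, shown here outside the full script): mixed precision can be requested when constructing the Accelerator, or equivalently with the --mixed_precision fp16 flag of accelerate launch. Accelerate then runs the forward pass under autocast and applies loss scaling inside accelerator.backward.

from accelerate import Accelerator

# Equivalent to launching with: accelerate launch --mixed_precision fp16 ...
# Accelerate handles autocast and gradient scaling automatically.
accelerator = Accelerator(mixed_precision="fp16")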

This post covers:

  • Converting a single-GPU script to multi-GPU training and inference with Accelerate
  • Accelerate arguments and mixed-precision training
  • Printing from the main process only when using multiple GPUs
  • Setting the correct batch size for multi-GPU training (see the sketch right after this list)
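
A minimal sketch of the last two points (variable names here are illustrative, not from the script below): under accelerate launch every process runs the whole script and pulls its own batches, so the effective batch size is the per-GPU batch size times the number of processes. accelerator.print only emits on the main process, so log lines are not duplicated once per GPU.

from accelerate import Accelerator

accelerator = Accelerator()

# Keep the effective batch size constant when scaling out:
# each process consumes a full batch per step.
effective_bs = 256
per_gpu_bs = effective_bs // accelerator.num_processes  # 128 with 2 GPUs

# A bare print() would run once per process; this prints once, on rank 0.
accelerator.print(f"processes: {accelerator.num_processes}, per-GPU batch size: {per_gpu_bs}")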

Single-GPU

import pandas as pd
import numpy as np
import cv2

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from tqdm import tqdm


def get_transform(image_size, train=True):
    if train:
        return A.Compose([
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.5),
                    A.RandomBrightnessContrast(p=0.2),
                    A.Resize(image_size, image_size, interpolation=cv2.INTER_LANCZOS4),
                    A.Normalize(0.1310, 0.30854),
                    ToTensorV2(),
                ])
    else:
        return A.Compose([
                    A.Resize(image_size, image_size, interpolation=cv2.INTER_LANCZOS4),
                    A.Normalize(0.1310, 0.30854),
                    ToTensorV2(),
                ])
    
    
class MiniDataSet(Dataset):
    
    def __init__(self, images, labels=None, transform=None):
        self.images = images.astype("float32")
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        ret = {}
        img = self.images[idx]
        
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        ret["image"] = img
        
        if self.labels is not None:
            ret["label"] = self.labels[idx]
        
        return ret
    
    
class TimmModel(nn.Module):
    
    def __init__(self, backbone, num_class, pretrained=False):
        
        super().__init__()
        self.model = timm.create_model(backbone, pretrained=pretrained, in_chans=1, num_classes=num_class)
    
    def forward(self, image):
        logit = self.model(image)
        return logit


def train_fn(args, model, optimizer, dataloader):
    model.to(args.device)
    model.train()
    train_loss = []
    loss_fn = torch.nn.CrossEntropyLoss()
    
    for batch in tqdm(dataloader):
        logits = model(batch["image"].to(args.device))
        optimizer.zero_grad()
        loss = loss_fn(logits, batch["label"].to(args.device))
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    
    return np.mean(train_loss)


@torch.no_grad()
def predict_fn(args, model, dataloader):
    model.to(args.device)
    model.eval()
    predictions = []
    
    for step, batch in enumerate(dataloader):
        output = model(batch["image"].to(args.device))
        prediction = torch.argmax(output, 1)
        predictions.append(prediction.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0)
    return predictions


def fit_model(args, model, optimizer, train_dl, val_dl):
    best_score = 0.
    for ep in range(args.ep):
        train_loss = train_fn(args, model, optimizer, train_dl)
        val_pred = predict_fn(args, model, val_dl)
        val_acc = np.mean(val_pred == val_dl.dataset.labels)
        print(f"Epoch {ep+1}, train loss {train_loss:.4f}, val acc {val_acc:.4f}")
        if val_acc > best_score:
            best_score = val_acc
            torch.save(model.state_dict(), "model.bin")
    model.load_state_dict(torch.load("model.bin"))
    return model
    
    

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ep", default=5, type=int)
    parser.add_argument("--lr", default=0.00005, type=float)
    parser.add_argument("--bs", default=256, type=int)
    parser.add_argument("--device", default=0, type=int)
    parser.add_argument("--model", default="convnext_small")
    parser.add_argument("--image_size", default=56, type=int)
    args = parser.parse_args()
    
    train = pd.read_csv("digit-recognizer/train.csv")
    train_images = train.iloc[:, 1:].values.reshape(-1, 28, 28)
    train_labels = train.iloc[:, 0].values
    
    test = pd.read_csv("digit-recognizer/test.csv")
    test_images = test.values.reshape(-1, 28, 28)
    
    submission = pd.read_csv("digit-recognizer/sample_submission.csv")

    
    train_transform = get_transform(args.image_size, True)
    valid_transform = get_transform(args.image_size, False)
    train_ds = MiniDataSet(train_images[:40000], train_labels[:40000], train_transform)
    val_ds = MiniDataSet(train_images[40000:], train_labels[40000:], valid_transform)
    test_ds = MiniDataSet(test_images, transform=valid_transform)
    
    train_dl = DataLoader(
    train_ds, 
    batch_size=args.bs, 
    num_workers=2, 
    shuffle=True, 
    drop_last=True)

    val_dl = DataLoader(
        val_ds, 
        batch_size=args.bs * 2, 
        num_workers=2, 
        shuffle=False, 
        drop_last=False)

    test_dl = DataLoader(
        test_ds, 
        batch_size=args.bs * 2, 
        num_workers=2, 
        shuffle=False, 
        drop_last=False)


    model = TimmModel(args.model, 10, pretrained=True)
    optimizer = Adam(model.parameters(), lr=args.lr)
    model = fit_model(args, model, optimizer, train_dl, val_dl)
    
    test_pred = predict_fn(args, model, test_dl)
    submission["Label"] = test_pred
    submission.to_csv("./submission.csv", index=False)
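
Assuming the script above is saved as train.py (filename assumed) next to the digit-recognizer data, it runs on a single GPU with:

python train.py --ep 5 --bs 256 --device 0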

Modifying the code for multiple GPUs

Five changes are needed relative to the single-GPU version: drop the manual .to(device) calls, replace loss.backward() with accelerator.backward(loss), gather per-GPU predictions with accelerator.gather_for_metrics, guard progress bars, printing, and file output with is_main_process, and pass the model, optimizer, and dataloaders through accelerator.prepare. The modified script keeps the replaced lines as comments so the diff is easy to follow.

import warnings

import pandas as pd
import numpy as np
import cv2

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import timm
from tqdm import tqdm

from accelerate import Accelerator

warnings.filterwarnings("ignore")

def get_transform(image_size, train=True):
    if train:
        return A.Compose([
                    A.HorizontalFlip(p=0.5),
                    A.VerticalFlip(p=0.5),
                    A.RandomBrightnessContrast(p=0.2),
                    A.Resize(image_size, image_size, interpolation=cv2.INTER_LANCZOS4),
                    A.Normalize(0.1310, 0.30854),
                    ToTensorV2(),
                ])
    else:
        return A.Compose([
                    A.Resize(image_size, image_size, interpolation=cv2.INTER_LANCZOS4),
                    A.Normalize(0.1310, 0.30854),
                    ToTensorV2(),
                ])
    
    
class MiniDataSet(Dataset):
    
    def __init__(self, images, labels=None, transform=None):
        self.images = images.astype("float32")
        self.labels = labels
        self.transform = transform
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        ret = {}
        img = self.images[idx]
        
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        ret["image"] = img
        
        if self.labels is not None:
            ret["label"] = self.labels[idx]
        
        return ret
    
    
class TimmModel(nn.Module):
    
    def __init__(self, backbone, num_class, pretrained=False):
        
        super().__init__()
        self.model = timm.create_model(backbone, pretrained=pretrained, in_chans=1, num_classes=num_class)
    
    def forward(self, image):
        logit = self.model(image)
        return logit
    
    
def train_fn(args, model, optimizer, dataloader):
#     removed .to(device): accelerator.prepare already moves the model to the right device
#     model.to(args.device)
    model.train()
    train_loss = []
    loss_fn = torch.nn.CrossEntropyLoss()

#     added disable=not args.accelerator.is_main_process so the progress bar
#     only shows on the main process instead of once per GPU
#     for batch in tqdm(dataloader):
    for batch in tqdm(dataloader, disable=not args.accelerator.is_main_process):
#         logits = model(batch["image"].to(args.device))
        logits = model(batch["image"])
        optimizer.zero_grad()
#         loss = loss_fn(logits, batch["label"].to(args.device))
        loss = loss_fn(logits, batch["label"])
#     changed loss.backward() to accelerator.backward(loss)
#         loss.backward()
        args.accelerator.backward(loss)
        optimizer.step()
        train_loss.append(loss.item())
    
    return np.mean(train_loss)


@torch.no_grad()
def predict_fn(args, model, dataloader):
#     removed .to(device): handled by accelerator.prepare
#     model.to(args.device)
    model.eval()
    predictions = []
    
    for step, batch in enumerate(dataloader):
#         output = model(batch["image"].to(args.device))
        output = model(batch["image"])
        prediction = torch.argmax(output, 1)
#         use accelerator.gather_for_metrics(prediction) to collect the predictions from all GPUs
        prediction = args.accelerator.gather_for_metrics(prediction)
        predictions.append(prediction.cpu().numpy())
    
    predictions = np.concatenate(predictions, axis=0)
    return predictions


def fit_model(args, model, optimizer, train_dl, val_dl):
    best_score = 0.
    for ep in range(args.ep):
        train_loss = train_fn(args, model, optimizer, train_dl)
        val_pred = predict_fn(args, model, val_dl)
        val_acc = np.mean(val_pred == val_dl.dataset.labels)
        if args.accelerator.is_main_process:
            print(f"Epoch {ep+1}, train loss {train_loss:.4f}, val acc {val_acc:.4f}")
        if val_acc > best_score:
            best_score = val_acc
            # save only on the main process so several workers do not
            # write model.bin at the same time
            if args.accelerator.is_main_process:
                torch.save(model.state_dict(), "model.bin")
    # wait until the checkpoint is fully written before loading it back
    args.accelerator.wait_for_everyone()
    model.load_state_dict(torch.load("model.bin", map_location="cpu"))
    return model
    
    

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--ep", default=5, type=int)
    parser.add_argument("--lr", default=0.00005, type=float)
    parser.add_argument("--bs", default=256, type=int)
    parser.add_argument("--device", default=0, type=int)
    parser.add_argument("--model", default="convnext_small")
    parser.add_argument("--image_size", default=56, type=int)
    args = parser.parse_args()
    
    
    train = pd.read_csv("digit-recognizer/train.csv")
    train_images = train.iloc[:, 1:].values.reshape(-1, 28, 28)
    train_labels = train.iloc[:, 0].values
    
    test = pd.read_csv("digit-recognizer/test.csv")
    test_images = test.values.reshape(-1, 28, 28)
    
    submission = pd.read_csv("digit-recognizer/sample_submission.csv")

    
    train_transform = get_transform(args.image_size, True)
    valid_transform = get_transform(args.image_size, False)
    train_ds = MiniDataSet(train_images[:40000], train_labels[:40000], train_transform)
    val_ds = MiniDataSet(train_images[40000:], train_labels[40000:], valid_transform)
    test_ds = MiniDataSet(test_images, transform=valid_transform)
    
    train_dl = DataLoader(
    train_ds, 
    batch_size=args.bs, 
    num_workers=2, 
    shuffle=True, 
    drop_last=True)

    val_dl = DataLoader(
        val_ds, 
        batch_size=args.bs * 2, 
        num_workers=2, 
        shuffle=False, 
        drop_last=False)

    test_dl = DataLoader(
        test_ds, 
        batch_size=args.bs * 2, 
        num_workers=2, 
        shuffle=False, 
        drop_last=False)
    
    model = TimmModel(args.model, 10, pretrained=True)
    optimizer = Adam(model.parameters(), lr=args.lr)
    
    # initialize the Accelerator
    accelerator = Accelerator()
    
    # prepare the model, optimizer, and dataloaders for multi-GPU training
    model, optimizer, train_dl, val_dl, test_dl = accelerator.prepare(model, optimizer, train_dl, val_dl, test_dl)
    args.accelerator = accelerator
    
    model = fit_model(args, model, optimizer, train_dl, val_dl)
    
    test_pred = predict_fn(args, model, test_dl)
    if accelerator.is_local_main_process:
        submission["Label"] = test_pred
        submission.to_csv("./submission.csv", index=False)
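
Launch the script with the Accelerate CLI (train_multi.py is an assumed filename; --num_processes and --mixed_precision are standard accelerate launch flags):

accelerate launch --num_processes 2 train_multi.py
accelerate launch --num_processes 2 --mixed_precision fp16 train_multi.py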
This post is licensed under CC BY 4.0 by the author.