
The use of CLIP

Basic Structure

  • Pre-training architecture: for each matched (image, text) pair given as input, we want their similarity to be as large as possible, which leads to a contrastive learning objective. In short, the entries on the diagonal of the image-text similarity matrix are maximized while all other entries are minimized, and the model is trained with this objective (a minimal sketch of this symmetric loss is given below).
  • Core code
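Below is a minimal sketch of this symmetric contrastive objective for a batch of N matched (image, text) embedding pairs. The function name `clip_style_loss` and the temperature value are illustrative assumptions, not CLIP's exact training code.

```python
import torch
import torch.nn.functional as F

def clip_style_loss(image_features, text_features, temperature=0.07):
    # normalize so that the dot product is a cosine similarity
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # N x N similarity matrix; entry (i, j) compares image i with text j
    logits = image_features @ text_features.t() / temperature

    # the matched pair of row i sits on the diagonal, so the target of row i is i
    targets = torch.arange(logits.shape[0], device=logits.device)
    loss_i = F.cross_entropy(logits, targets)      # images -> texts direction
    loss_t = F.cross_entropy(logits.t(), targets)  # texts -> images direction
    return (loss_i + loss_t) / 2
```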

Image encoder

  • The image encoder is a ViT (in short: the image is split into patches of equal size, each patch is treated as a token, and the rest works like a standard Transformer).
  • A worked example (a quick shape check is given right after the code below):
    1. Patch embedding: with a 224x224 input image and 16x16 patches, each image yields 224x224/(16x16) = 196 patches, i.e. an input sequence of length 196. Each patch has dimension 16x16x3 = 768, so we get 196 tokens of dimension 768. A special [CLS] token is prepended, giving a final shape of 197x768.
    2. Positional embedding: the 197 tokens are then given absolute position embeddings, as in BERT.

```python
class VisionTransformer(nn.Module):
    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # take the first token (the [CLS] token) along the sequence dimension; x.shape = [batch_size, width]
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        # shape = [batch_size, output_dim]
        return x
```
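As a quick sanity check of the shape arithmetic in step 1 (a hedged sketch, independent of the class above), the convolution-based patchify can be run on a dummy input; the numbers assume 16x16 patches and width 768 as in the worked example.

```python
import torch

# 224x224 image, 16x16 patches -> 196 patch tokens of dimension 768
# (the [CLS] token is prepended separately).
x = torch.randn(1, 3, 224, 224)
patchify = torch.nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16, bias=False)
tokens = patchify(x).flatten(2).transpose(1, 2)
print(tokens.shape)  # torch.Size([1, 196, 768])
```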
    
# CLIP core implementation
* Here the forward pass of CLIP is modified slightly: instead of the logits, it returns the tensors produced by the image and text encoders, with the corresponding dimensions annotated in the code; a hedged sketch of this kind of wrapper is given below.
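As a rough illustration (not necessarily the author's exact modification), such a wrapper could look like the following sketch; the class name `CLIPEncoders` is made up for this example.

```python
import torch

class CLIPEncoders(torch.nn.Module):
    # returns the encoded image/text tensors instead of the similarity logits
    def __init__(self, clip_model):
        super().__init__()
        self.model = clip_model

    def forward(self, images, texts):
        image_features = self.model.encode_image(images)  # [batch_size, embed_dim]
        text_features = self.model.encode_text(texts)     # [batch_size, embed_dim], embed_dim = 512 for ViT-B/32
        return image_features, text_features
```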
# Finetuning and CLIP (use of CLIP)

On many tasks, it is now common to initialize a network's weights from a similar architecture trained on other data, modifying only the last layers. An example of how to do this is https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

The network used as a starting point is often designated as the "backbone", the output of the backbone is called the "embedding", and the part doing the prediction is the "projection head".
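As a hedged illustration of this vocabulary, here is a minimal sketch in the spirit of the linked tutorial: a pretrained torchvision backbone whose final layer is replaced by a new projection head. The choice of resnet18 and the 2-class head (e.g. ants vs. bees) are assumptions for the example, not the only option.

```python
import torch
import torchvision

# Load a pretrained backbone and freeze its weights
backbone = torchvision.models.resnet18(pretrained=True)
for p in backbone.parameters():
    p.requires_grad = False

# The output of the frozen backbone plays the role of the "embedding";
# the new last layer is the "projection head", trained from scratch.
n_features = backbone.fc.in_features
backbone.fc = torch.nn.Linear(n_features, 2)
```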
 
Last year a model named CLIP got a lot of attention. It is trained using a very simple cosine-similarity matching objective, trying to reduce the distance between the embeddings of images and the embeddings of their captions. A pre-trained version is available at https://github.com/openai/CLIP

This sparked a lot of derivative works, one of them being Stable Diffusion.


Though not necessary to answer the questions, you may be interested in a [video explanation](https://www.youtube.com/watch?v=T9XSU0pKX2E) made by a researcher independent of the authors.



**Q6.** CLIP claims to perform zero-shot classification by comparing the likelihoods of several (image, description) embedding pairs. Adapt the zero-shot example of the repository to perform zero-shot predictions on the Hymenoptera dataset. Do the same on Cifar-100.


**Q7.** Adapt the finetuning code of Q5 such that the network model from CLIP can now be used as the backbone for a fine tuning on Cifar-100 and 

**Q8.** Perform a training run, trying to tune the learning rate with an SGD optimizer while freezing/unfreezing the CLIP network. Also try different learning rates for SGD on the frozen parameters and on the added layers. Can you find some recommendations?

**Q9.** On the Cifar-100 dataset, create embeddings of images and of the corresponding texts using zero-shot prediction (which is precisely the setting given in the example). Plot a few of the images and their corresponding texts using a 2D PCA computed on their common embedding space, with a color depending on whether the point comes from an image or from a text. Do you have any comment? Are the text and image spaces organized similarly?

**Q10.** Does the training of Q7 improve the situation detected in Q9?

**Q11.** Try to improve the performance on Cifar-100 following the recommendations of http://karpathy.github.io/2019/04/25/recipe/ and/or by adding a multi-head attention layer instead of a single fully connected layer on top of CLIP.

# Install CLIP

```
!conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.2
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
```

* Here we use CIFAR100 as an example.
```python
import os
import clip
import torch
from torchvision.datasets import CIFAR100, ImageFolder
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

```

```python
# Load the model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)

# Download CIFAR dataset
cifar100 = CIFAR100(root='./data', download=True, train=False, transform=preprocess)

# download Hymenoptera dataset at https://download.pytorch.org/tutorial/hymenoptera_data.zip
!wget -P './data' https://download.pytorch.org/tutorial/hymenoptera_data.zip
!unzip -q './data/hymenoptera_data.zip' -d './data'
```

Q6: Zero-shot predictions

CIFAR100

  • First, we perform zero-shot classification on a single sample and show the model's top-5 predictions.

```python
input_index = 1096
text_all_class = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

with torch.no_grad():
    text_features = model.encode_text(text_all_class)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    image, class_id = cifar100[input_index]
    image_input = image.unsqueeze(0).to(device)
    image_features = model.encode_image(image_input)
    image_features /= image_features.norm(dim=-1, keepdim=True)

similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

print(f"The true label is {cifar100.classes[class_id]}")
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
```

* We can see that the predicted caption matches the input image well (the correct class gets more than 90% probability).

* Now, we compute the average prediction accuracy on the whole Cifar100 dataset.
```python
def evaluate(model, dataloader, text_all_class):
    n_correct = 0
    total = 0
    with torch.no_grad():
        text_features = model.encode_text(text_all_class)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        for images, labels in dataloader:
          images = images.to(device)
          image_features = model.encode_image(images)
          image_features /= image_features.norm(dim=-1, keepdim=True)

          similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
          _, indices = similarity.topk(1)
          indices = indices.reshape(-1,)
          n_correct += indices.eq(labels.to(device)).sum().item()
          total += labels.size(0)
    accuracy = n_correct / total
    return accuracy
dataloader = torch.utils.data.DataLoader(cifar100, batch_size=64,
                                        shuffle=True, num_workers=2)
text_all_class = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
accuracy = evaluate(model, dataloader, text_all_class)
print(f"For CIFAR100, The accuracy of zero-shot inference is {accuracy*100:.2f}%")
  • Using ViT-B/32, we get an accuracy of about 62% with top-1 similarity, whereas the original CLIP paper reports 65.1% zero-shot accuracy on CIFAR100.
  • There is some discussion about this gap: it seems one has to ensemble several prompt templates to reach the paper's accuracy, rather than using a single template; a minimal sketch of prompt ensembling is given after this list.
  • This suggests that the quality of the prompts is a major factor in model performance.
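For reference, here is a hedged sketch of prompt ensembling: the text embedding of each class is averaged over several templates before being compared to the image embeddings. The templates below are illustrative, not the exact set used in the CLIP paper; `model`, `cifar100` and `device` are the objects defined earlier.

```python
templates = [
    "a photo of a {}.",
    "a blurry photo of a {}.",
    "a low resolution photo of a {}.",
    "a close-up photo of a {}.",
]

with torch.no_grad():
    class_embeddings = []
    for c in cifar100.classes:
        tokens = torch.cat([clip.tokenize(t.format(c)) for t in templates]).to(device)
        emb = model.encode_text(tokens)              # [n_templates, embed_dim]
        emb = emb / emb.norm(dim=-1, keepdim=True)   # normalize each template embedding
        emb = emb.mean(dim=0)                        # average over templates
        class_embeddings.append(emb / emb.norm())    # re-normalize the mean
    text_features_ensembled = torch.stack(class_embeddings)  # [n_classes, embed_dim]
```

This matrix can then be used in place of the single-template text features when computing the image-text similarities.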

Hymenoptera dataset

```python
# refer to the tutorial at https://pytorch.org/tutorials/beginner/transfer_learning_tutorial.html
data_dir = './data/hymenoptera_data'
image_datasets = ImageFolder(os.path.join(data_dir, 'train'), preprocess)
text_all_class_hyme = torch.cat([clip.tokenize(f"a photo of a {c}") for c in ['ant', 'bee']]).to(device)
input_index = 24

with torch.no_grad():
    text_features = model.encode_text(text_all_class_hyme)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    image_input_, class_id = image_datasets[input_index]
    image_input = image_input_.unsqueeze(0).to(device)
    image_features = model.encode_image(image_input)
    image_features /= image_features.norm(dim=-1, keepdim=True)

similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(2)

print(f"The true label is {['ant', 'bee'][class_id]}")
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{['ant', 'bee'][index]:>16s}: {100 * value.item():.2f}%")
```

* On a randomly chosen sample, CLIP correctly predicts the class.
```python
dataloader_hyme = torch.utils.data.DataLoader(image_datasets, batch_size=64,
                                        shuffle=True, num_workers=2)
text_all_class_hyme = torch.cat([clip.tokenize(f"a photo of a {c}") for c in ['ant', 'bee'] ]).to(device)
accuracy = evaluate(model, dataloader_hyme, text_all_class_hyme)
print(f"For Hymenoptera dataset, The accuracy of zero-shot inference is {accuracy*100:.2f}%")

Q7: Fine tuning on CIFAR

  • Fine-tune CLIP (ViT-B/32: Vision Transformer Base with patch size 32).

```python
model_finetune, preprocess = clip.load('ViT-B/32', device, jit=False)

# Load the dataset
root = './data'
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
trainloader = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True, num_workers=2)
```

* Adapt the fine-tuning code from the [discussion](https://github.com/openai/CLIP/issues/83). CLIP uses mixed-precision training: float32 for the optimizer step and float16 for the forward/backward passes, as discussed [here](https://github.com/openai/CLIP/issues/57).
* We therefore need to convert the parameters and gradients to float32 before the optimizer step, then convert them back to float16, as shown in the following code.
```python
text_all_class = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
loss_img = torch.nn.CrossEntropyLoss()
loss_txt = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_finetune.parameters(), lr=5e-5)
train_epoch = 3

model_finetune.train()
for epoch in range(train_epoch):  
    for inputs, labels in trainloader:
        texts = torch.cat([clip.tokenize(f"a photo of a {cifar100.classes[i]}") for i in np.array(labels)]).to(device)
        images = inputs.to(device) 

        optimizer.zero_grad()
        logits_per_image, logits_per_text  = model_finetune(images, texts) 
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)

        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()
        # probs = logits_per_image.softmax(dim=-1)

        if device == "cpu":
          optimizer.step()
        else:
          # convert the weights and gradients to float32
          model_finetune.float()
          optimizer.step()
          # convert back to float16
          clip.model.convert_weights(model_finetune)

    train_acc = evaluate(model_finetune, trainloader, text_all_class)    
    test_acc = evaluate(model_finetune, testloader, text_all_class)
    print(f"Epoch {epoch+1}: training accuracy is {train_acc*100:.3f}% || test accuracy is {test_acc*100:.3f}%")
  • Check the total number of trainable parameters:
```python
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"There are {n_params:,} parameters in CLIP ViT-B/32 model.")
print(f"There are {len(train)} training samples.")
```
    

Q8: Freezing CLIP and adding layers

  • Here, we'd like to add some fully connected layers on top of CLIP.

  • A common way to do this is to first remove the last layer (or replace it with Identity()), then add fully connected layers as a "classifier head". With CLIP, however, it is not that simple.

  • In the forward function of CLIP's source code, the blocks are not used sequentially: self.visual and self.transformer are called in parallel, in encode_image and encode_text respectively. So simply appending fully connected layers at the end of the model would not work here.

  • Inspired by the source code, we instead apply fully connected layers to the image features and to the text features, and then compute their similarities, just as the source code does.
    class MyCLIP(torch.nn.Module):
      def __init__(self, model_CLIP):
          super().__init__()
          self.model = model_CLIP
          self.fc1 = torch.nn.Linear(512, 512, dtype=torch.float16)
          self.fc2 = torch.nn.Linear(512, 512, dtype=torch.float16)
    
      def forward(self, imgs, text):
          # x1, x2 = self.model(imgs, text)  # verified in a Colab experiment: x1, x2 match the result of multiplying the matrices from model.encode_image / model.encode_text and then applying softmax; x1, x2 here also need a softmax, otherwise the differences are very small
          # x1 = self.fc1(x1)
          # x1 = self.fc2(x1)
          # x2 = x1.t()
          image_features = self.model.encode_image(imgs)
          image_features = self.fc1(image_features)
          image_features = image_features / image_features.norm(dim=1, keepdim=True)
    
          text_features = self.model.encode_text(text)
          text_features = self.fc2(text_features)
          text_features = text_features / text_features.norm(dim=1, keepdim=True)
    
          logit_scale = self.model.logit_scale.exp()
          logits_per_image = logit_scale * image_features @ text_features.t()
          logits_per_text = logits_per_image.t()
          return logits_per_image, logits_per_text
        
      def evaluate(self, dataloader, text_all_class):
          n_correct = 0
          total = 0
          with torch.no_grad():
              text_features = self.model.encode_text(text_all_class)
              text_features = self.fc2(text_features)
              text_features /= text_features.norm(dim=-1, keepdim=True)
              for images, labels in dataloader:
                images = images.to(device)
                image_features = self.model.encode_image(images)
                image_features = self.fc1(image_features)
                image_features /= image_features.norm(dim=-1, keepdim=True)
    
                similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
                _, indices = similarity.topk(1)
                indices = indices.reshape(-1,)
                n_correct += indices.eq(labels.to(device)).sum().item()
                total += labels.size(0)
          accuracy = n_correct / total
          return accuracy
    
  • Load the model:

```python
model, preprocess = clip.load('ViT-B/32', device, jit=False)
myModel = MyCLIP(model)
myModel = myModel.to(device)

# Load the dataset
batch_size = 64
root = './data'
train = CIFAR100(root, download=True, train=True, transform=preprocess)
test = CIFAR100(root, download=True, train=False, transform=preprocess)
trainloader = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(test, batch_size=batch_size, shuffle=True, num_workers=2)
```

```python
# we freeze the CLIP network
for para in myModel.model.parameters():
    para.requires_grad = False
# Alternatively, we could use the @torch.no_grad() decorator on the frozen forward passes
train_epoch = 5
optimizer = torch.optim.SGD(myModel.parameters(), lr=1e-3)
loss_img = torch.nn.CrossEntropyLoss()
loss_txt = torch.nn.CrossEntropyLoss()
text_all_class = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
for epoch in range(train_epoch):  
    for inputs, labels in trainloader:
        texts = torch.cat([clip.tokenize(f"a photo of a {cifar100.classes[i]}") for i in np.array(labels)]).to(device)
        images = inputs.to(device) 

        optimizer.zero_grad()
        logits_per_image, logits_per_text  = myModel(images, texts)
        ground_truth = torch.arange(len(images), dtype=torch.long, device=device)
        total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
        total_loss.backward()
        optimizer.step()

    train_acc = myModel.evaluate(trainloader, text_all_class)    
    test_acc = myModel.evaluate(testloader, text_all_class)
    print(f"Epoch {epoch+1}: training accuracy is {train_acc*100:.3f}% || test accuracy is {test_acc*100:.3f}%")
  • Using the SGD optimizer with learning rate 1e-3 and training only the linear layers added on top of CLIP, we obtained, after 5 epochs, an accuracy similar to the one obtained by fine-tuning CLIP on CIFAR100.

  • We would therefore recommend adding a few layers as a classifier head on top of CLIP, freezing CLIP, and training only the added layers rather than the whole network, since this is much faster. When the backbone is unfrozen, it is common to give it a much smaller learning rate than the freshly added layers; a minimal sketch is given below.
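As a hedged sketch of the second point in Q8 (different learning rates for the backbone and the added layers), one can give SGD two parameter groups. The values 1e-5 and 1e-3 are illustrative assumptions, not tuned recommendations, and on GPU the float16/float32 conversion around optimizer.step() from Q7 would still be needed.

```python
# Unfreeze the CLIP backbone inside MyCLIP
for p in myModel.model.parameters():
    p.requires_grad = True

# One optimizer, two parameter groups with different learning rates
optimizer = torch.optim.SGD(
    [
        {"params": myModel.model.parameters(), "lr": 1e-5},  # small LR for pretrained CLIP weights
        {"params": list(myModel.fc1.parameters()) + list(myModel.fc2.parameters()), "lr": 1e-3},  # larger LR for the new layers
    ],
    momentum=0.9,
)
```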

References

  • The Annotated CLIP
  • Related work on prompts