Pretrained Models: Vision Transformers (ViTs)

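Vision Transformers (ViTs) are a newer architecture that brings the transformer, originally developed for language, to computer vision. The image is cut into fixed-size patches (16x16 pixels for `vit_b_16`), and each patch is treated like a word token. ViTs have been shown to be very effective at image classification. The torchvision ViT weights used here were trained on ImageNet-1K, so we can reuse the same 1000 class labels from the PyTorch Hub repository.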
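To make the patch-to-token idea concrete, here is a minimal sketch of a ViT forward pass using the ViT-B/16 sizes: 16x16 patches, 768-dimensional embeddings, 12 heads, and 1000 output classes. It is assembled from generic PyTorch layers with random weights and only two encoder layers. It illustrates the data flow and is not torchvision's `vit_b_16`, which has 12 layers and its own encoder implementation. The class name `TinyViT` is hypothetical.

```python
import torch
import torch.nn as nn

# ViT-B/16 sizes; depth is cut to 2 layers below to keep the sketch light
IMAGE_SIZE, PATCH_SIZE, HIDDEN_DIM, NUM_HEADS, NUM_CLASSES = 224, 16, 768, 12, 1000
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE) ** 2  # 14 * 14 = 196


class TinyViT(nn.Module):
    """Illustrative ViT: patchify -> add [class] token and positions -> encoder -> linear head."""

    def __init__(self, num_layers: int = 2):
        super().__init__()
        # A conv with kernel = stride = patch size cuts the image into non-overlapping
        # 16x16 patches and linearly projects each patch to HIDDEN_DIM values.
        self.patch_embed = nn.Conv2d(3, HIDDEN_DIM, kernel_size=PATCH_SIZE, stride=PATCH_SIZE)
        self.class_token = nn.Parameter(torch.zeros(1, 1, HIDDEN_DIM))
        self.pos_embedding = nn.Parameter(torch.zeros(1, NUM_PATCHES + 1, HIDDEN_DIM))
        layer = nn.TransformerEncoderLayer(
            HIDDEN_DIM, NUM_HEADS, dim_feedforward=3072,
            activation="gelu", batch_first=True, norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.head = nn.Linear(HIDDEN_DIM, NUM_CLASSES)

    def tokens(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x)                            # (N, 768, 14, 14)
        x = x.flatten(2).transpose(1, 2)                   # (N, 196, 768): one token per patch
        cls = self.class_token.expand(x.shape[0], -1, -1)  # (N, 1, 768)
        x = torch.cat([cls, x], dim=1)                     # (N, 197, 768)
        return x + self.pos_embedding                      # add learned position information

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.encoder(self.tokens(x))
        return self.head(x[:, 0])                          # classify from the [class] token


tiny_vit = TinyViT().eval()
images = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE)         # a fake batch of two images
with torch.no_grad():
    tokens = tiny_vit.tokens(images)
    logits = tiny_vit(images)
print(tokens.shape)  # torch.Size([2, 197, 768])
print(logits.shape)  # torch.Size([2, 1000])
```

A 224x224 image therefore becomes a sequence of 196 patch tokens plus one class token. The 1000 logits come only from the class token's final embedding.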

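import urllib.request

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.models import vit_b_16, ViT_B_16_Weights

# use the GPU if one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)  # 'cuda' on the machine used for this example
# load a sample image (car.jpg is assumed to be in the working directory)
img = Image.open("car.jpg").convert("RGB")
# `pretrained=True` is deprecated in recent torchvision; IMAGENET1K_V1 selects the same weights
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1).eval().to(device)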
# resize to the 224x224 input size and normalize with the ImageNet channel statistics
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
])
# pass the image through the model
img_tensor = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    logits = model(img_tensor)
    probs = torch.nn.functional.softmax(logits, dim=1)
# Download the txt file with human-readable labels
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
filename = "imagenet_classes.txt"
urllib.request.urlretrieve(url, filename)

# Load labels
with open(filename) as f:
    labels = [line.strip() for line in f.readlines()] # read all labels line by line
top5 = torch.topk(probs, 5)  # top 5 probabilities and their class indices
for i in range(5):
    class_id = top5.indices[0][i].item()  # class index of the i-th best prediction
    score = top5.values[0][i].item()      # its softmax probability
    print(f"{i+1}. {labels[class_id]} ({score:.4f})")
1. beach wagon (0.2748)
2. pickup (0.2295)
3. grille (0.1784)
4. convertible (0.1343)
5. cab (0.0189)

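The ImageNet-1K dataset is limited to 1000 classes, and only a handful of them describe vehicles or vehicle parts (for example beach wagon, pickup, grille, convertible, and cab, all of which appear in the top 5 above). The model can therefore offer only coarse guesses. Its confidence is also spread across several similar classes: the top prediction receives just 0.27 of the probability. To get more specific labels, we will need a pretrained model that was trained on the dataset we are interested in, or on a bigger dataset with a richer label set.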