# generic_clip_pairs.py
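"""Evaluate CLIP or ALBEF on two-way image-text matching pairs.

Each entry in the input JSON pairs one image with two candidate captions
("text1" is the correct one) or one caption with two candidate images
("image_filename" is the correct one). The script reports the fraction of
entries for which the model scores the correct pairing higher.
"""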
import os
import clip
import json
import argparse
import ruamel.yaml as yaml
from PIL import Image
import torch
import torchvision.transforms as transforms
from tqdm import tqdm
from albef.utils import *
from executor import AlbefExecutor
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, help="Path to input JSON file")
parser.add_argument("--image_root", type=str, help="Path to directory containing images")
parser.add_argument("--albef_path", type=str, default=None, help="Path to ALBEF model/config/etc. if the goal is to use ALBEF")
parser.add_argument("--albef_itc", action="store_true", help="Use ITC output of ALBEF")
parser.add_argument("--clip_model", type=str, help="CLIP model to use")
parser.add_argument("--gpu", type=int, default=-1, help="Which gpu to use")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size for running CLIP")
args = parser.parse_args()
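# Select the scoring model: ALBEF when --albef_path points to a checkpoint directory,
# otherwise the CLIP variant named by --clip_model.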
# Use the CPU when --gpu is negative; otherwise the requested CUDA device.
device = "cpu" if args.gpu < 0 else "cuda:" + str(args.gpu)
if args.albef_path is not None:
    executor = AlbefExecutor(checkpoint_path=os.path.join(args.albef_path, "checkpoint.pth"),
                             config_path=os.path.join(args.albef_path, "config.yaml"),
                             device=device)
    model = executor.models[0]
    preprocess = executor.preprocesses[0]
else:
    model, preprocess = clip.load(args.clip_model, jit=False, device=device)
    # Resize inputs directly to the model's square input resolution.
    preprocess.transforms[0] = transforms.Resize(
        (model.visual.input_resolution, model.visual.input_resolution),
        transforms.InterpolationMode.BICUBIC)
model = model.eval()
with open(args.input_path) as input_file:
    data = json.load(input_file)
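# A hypothetical entry (field names match the loop below; values are purely illustrative):
#   {"image_filename": "0001.jpg", "text1": "a cat on a sofa", "text2": "a dog on a sofa"}
# Two-image entries provide "image_filename2" instead of "text2", keeping only "text1".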
correct = 0
for i in tqdm(range(0, len(data), args.batch_size)):
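    # Assemble one batch: preprocessed image tensors and the raw candidate captions.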
    batch_images = []
    batch_text = []
    for datum in data[i:min(i + args.batch_size, len(data))]:
        img = Image.open(os.path.join(args.image_root, datum["image_filename"])).convert('RGB')
        batch_images.append(preprocess(img))
        if "text2" in datum:
            # One image, two candidate captions; CLIP gets the "a photo of" prompt prefix.
            if args.albef_path is None:
                datum["text1"] = "a photo of " + datum["text1"]
                datum["text2"] = "a photo of " + datum["text2"]
            batch_text.append(datum["text1"])
            batch_text.append(datum["text2"])
        else:
            # One caption, two candidate images.
            img2 = Image.open(os.path.join(args.image_root, datum["image_filename2"])).convert('RGB')
            batch_images.append(preprocess(img2))
            batch_text.append(datum["text1"])
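    # Move the image batch to the device and tokenize the captions for the chosen model.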
    batch_images = torch.stack(batch_images).to(device)
    if args.albef_path is None:
        batch_text = clip.tokenize(batch_text).to(device)
    else:
        modified_text = [pre_caption(txt, executor.max_words) for txt in batch_text]
        batch_text = executor.tokenizer(modified_text, padding='longest', return_tensors="pt")
        for key in batch_text:
            batch_text[key] = batch_text[key].to(batch_images.device)
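    # Score every image-text pairing. CLIP returns an image-by-text logit matrix; ALBEF
    # uses its image-text matching (ITM) head unless --albef_itc is set, in which case
    # the temperature-scaled similarity of projected image/text features is used.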
    with torch.no_grad():
        if args.albef_path is None:
            # CLIP: logits_per_image has shape (num_images, num_texts).
            logits_per_image, logits_per_text = model(batch_images, batch_text)
        else:
            if not args.albef_itc:
                # ALBEF ITM head: tile images or texts so every entry contributes two
                # (image, text) pairs, with the correct pairing first.
                if batch_images.shape[0] * 2 == batch_text.input_ids.shape[0]:
                    # Two texts per image: repeat each image twice.
                    batch_images = batch_images.unsqueeze(1).repeat(1, 2, 1, 1, 1).view(
                        batch_images.shape[0] * 2, batch_images.shape[1],
                        batch_images.shape[2], batch_images.shape[3])
                else:
                    # Two images per text: repeat each text twice.
                    assert batch_images.shape[0] == 2 * batch_text.input_ids.shape[0]
                    batch_text.input_ids = batch_text.input_ids.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
                    batch_text.attention_mask = batch_text.attention_mask.unsqueeze(1).repeat(1, 2, 1).view(batch_images.shape[0], -1)
                image_embeds = model.visual_encoder(batch_images)
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(batch_images.device)
                output = model.text_encoder(
                    batch_text.input_ids,
                    attention_mask=batch_text.attention_mask,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )
                # ITM "match" logit for each pair, reshaped to (entry, 2).
                vl_embeddings = output.last_hidden_state[:, 0, :]
                vl_output = model.itm_head(vl_embeddings)
                logits_per_image = vl_output[:, 1:2].view(-1, 2)
            else:
                # ALBEF ITC: temperature-scaled similarity between projected image and
                # text [CLS] features, shape (num_images, num_texts).
                image_embeds = model.visual_encoder(batch_images)
                image_feat = torch.nn.functional.normalize(model.vision_proj(image_embeds[:, 0, :]), dim=-1)
                text_output = model.text_encoder(batch_text.input_ids, attention_mask=batch_text.attention_mask,
                                                 return_dict=True, mode='text')
                text_embeds = text_output.last_hidden_state
                text_feat = torch.nn.functional.normalize(model.text_proj(text_embeds[:, 0, :]), dim=-1)
                sim = image_feat @ text_feat.t() / model.temp
                logits_per_image = sim
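    # An entry counts as correct when the ground-truth pairing outscores the distractor.
    # For CLIP/ITC scores, logits_per_image is an image-by-text matrix; for the ITM head
    # it is already arranged as (entry, [correct candidate, distractor]) pairs.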
    if args.albef_path is None or args.albef_itc:
        if logits_per_image.shape[0] * 2 == logits_per_image.shape[1]:
            for j in range(logits_per_image.shape[0]):
                correct += 1 if logits_per_image[j, 2 * j].item() > logits_per_image[j, 2 * j + 1].item() else 0
        else:
            assert logits_per_image.shape[0] == 2 * logits_per_image.shape[1]
            for j in range(logits_per_image.shape[1]):
                correct += 1 if logits_per_image[2 * j, j].item() > logits_per_image[2 * j + 1, j].item() else 0
    else:
        correct += (logits_per_image[:, 0] > logits_per_image[:, 1]).long().sum().item()
print("Accuracy:", correct/len(data))
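# Example invocation (paths are illustrative; "ViT-B/32" is a standard CLIP model name):
#   python generic_clip_pairs.py --input_path pairs.json --image_root images/ \
#     --clip_model "ViT-B/32" --gpu 0 --batch_size 32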