import os
os.environ["TRANSFORMERS_NO_TF"] = "1"
from transformers import GPT2Tokenizer
from transformers import AutoImageProcessor, AutoModel
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2PreTrainedModel
# from encoder_service import RadDINOEncoder, GPT2WithImagePrefix
from huggingface_hub import hf_hub_download

class RadDINOEncoder(nn.Module):
    """Wrap the Rad-DINO vision encoder and return the CLS embedding for a single image."""

    def __init__(self, model_name="microsoft/rad-dino"):
        super().__init__()
        self.processor = AutoImageProcessor.from_pretrained(model_name, use_fast=True)
        self.encoder = AutoModel.from_pretrained(model_name)

    def forward(self, image):
        inputs = self.processor(images=image, return_tensors="pt")
        outputs = self.encoder(**inputs)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # CLS token
        return cls_embedding.squeeze(0)  # Shape: (768,)
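
# Usage sketch (illustrative only; `pil_image` stands for any PIL image the processor accepts):
#   encoder = RadDINOEncoder()
#   cls_vec = encoder(pil_image)  # torch.Tensor of shape (768,)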


class GPT2WithImagePrefix(nn.Module):
    """GPT-2 conditioned on an image by prepending a learned prefix projected from the image embedding."""

    def __init__(self, gpt2_model, prefix_length=10, embed_dim=768):
        super().__init__()
        self.gpt2 = gpt2_model
        self.prefix_length = prefix_length

        # Project image embedding to GPT2 embedding space
        self.image_projector = nn.Linear(embed_dim, prefix_length * gpt2_model.config.n_embd)

    def forward(self, image_embeds, input_ids, attention_mask, labels=None):
        batch_size = input_ids.size(0)

        # Project image embedding to prefix tokens
        prefix = self.image_projector(image_embeds).view(batch_size, self.prefix_length, -1).to(input_ids.device)

        # Get GPT2 token embeddings
        token_embeds = self.gpt2.transformer.wte(input_ids)

        # Concatenate image prefix with token embeddings
        inputs_embeds = torch.cat((prefix, token_embeds), dim=1)

        # Extend attention mask
        extended_attention_mask = torch.cat([
            torch.ones((batch_size, self.prefix_length), dtype=attention_mask.dtype, device=attention_mask.device),
            attention_mask
        ], dim=1)

        # Feed to GPT2
        outputs = self.gpt2(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            labels=labels
        )
        return outputs
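
# Illustrative shape check for the prefix model (a minimal sketch, not part of the service flow;
# the batch size and sequence length below are arbitrary examples):
#   image_embeds:  (batch, 768)             -> image_projector -> (batch, prefix_length, n_embd)
#   token_embeds:  (batch, seq_len, n_embd)
#   inputs_embeds: (batch, prefix_length + seq_len, n_embd) after concatenation
# e.g.
#   demo = GPT2WithImagePrefix(GPT2LMHeadModel.from_pretrained("gpt2"), prefix_length=10)
#   out = demo(torch.randn(2, 768), torch.randint(0, 100, (2, 16)), torch.ones(2, 16, dtype=torch.long))
#   out.logits.shape  # -> (2, 26, 50257)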

# CHECKPOINT_PATH = "checkpoints/gpt2_with_prefix_epoch_56.pt"
# TEST_CSV = "D:/GP/Rad-Dino_yarab efregha/IU_XRay/csv/testing_set.csv"
IMAGE_DIR = "D:/GP/Rad-Dino_yarab efregha/IU_XRay/images"
MAX_LENGTH = 128
BATCH_SIZE = 1
PREFIX_LENGTH = 10
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_CSV = "generated_vs_groundtruth.csv"

# -------------------- Load Processor, Tokenizer, Encoder ----------------
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
processor = AutoImageProcessor.from_pretrained("microsoft/rad-dino")
# -------------------- Rebuild the Model --------------------
gpt2 = GPT2LMHeadModel.from_pretrained("gpt2")
gpt2.resize_token_embeddings(len(tokenizer))
model = GPT2WithImagePrefix(gpt2, prefix_length=PREFIX_LENGTH).to(DEVICE)



# Checkpoint location on the Hugging Face Hub (repo and filename overridable via environment variables)
CHECKPOINT_REPO = os.getenv("CHECKPOINT_REPO", "TransformingBerry/Raddino-vision-language-gpt2-CHEXMED")
CHECKPOINT_FILENAME = os.getenv("CHECKPOINT_FILENAME", "Gpt2_checkpoint.pt")
CHECKPOINT_PATH = hf_hub_download(repo_id=CHECKPOINT_REPO, filename=CHECKPOINT_FILENAME, cache_dir="/app/cache")

try:
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint["model_state_dict"])
except FileNotFoundError:
    raise FileNotFoundError(f"Checkpoint file not found at {CHECKPOINT_PATH}")

image_encoder = RadDINOEncoder()
model.eval()


def generate_report_serviceFn(image):
    """Generate a radiology report for a single input image (e.g. a PIL image accepted by the Rad-DINO processor)."""
    model.eval()
    image_encoder.eval()
    with torch.no_grad():
        # Process the image
        image_embeds = image_encoder(image).to(DEVICE)

        # Prepare empty input for generation
        empty_input_ids = tokenizer.encode("", return_tensors="pt").to(DEVICE).long()
        empty_attention_mask = torch.ones_like(empty_input_ids).to(DEVICE)

        # Generate report
        prefix = model.image_projector(image_embeds).view(1, model.prefix_length, -1)
        token_embeds = model.gpt2.transformer.wte(empty_input_ids)
        inputs_embeds = torch.cat((prefix, token_embeds), dim=1)

        extended_attention_mask = torch.cat([
            torch.ones((1, model.prefix_length), device=DEVICE),
            empty_attention_mask
        ], dim=1)

        generated_ids = model.gpt2.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask,
            max_length=model.prefix_length + 60,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

        generated_text = tokenizer.decode(generated_ids[0][model.prefix_length:], skip_special_tokens=True)

        return generated_text
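

# Example local invocation (a minimal sketch; the image path below is a placeholder and
# PIL is assumed to be available, since the Rad-DINO processor accepts PIL images):
if __name__ == "__main__":
    from PIL import Image

    sample_image = Image.open("sample_chest_xray.png").convert("RGB")  # placeholder path
    print(generate_report_serviceFn(sample_image))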