DornierDo17 committed on
Commit 2f79958 · 1 Parent(s): e1c09f7

updated model

Files changed (1)
  1. fineTuning.py +2 -167
fineTuning.py CHANGED
@@ -1,10 +1,6 @@
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.amp import GradScaler, autocast
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-from transformers import RobertaTokenizerFast, get_linear_schedule_with_warmup
+from transformers import RobertaTokenizerFast
 
 from modelFineTuning import RoBERTa
 
@@ -17,161 +13,11 @@ class RoBERTaModule(nn.Module):
             vocab_size=self.tokenizer.vocab_size,
             padding_idx=self.tokenizer.pad_token_id,
             num_labels=2,
-        ) # !!
+        )
 
     def forward(self, x, attn_mask):
         return self.model(x, attn_mask)
 
-    def train_model(
-        self,
-        train_loader,
-        validation_loader,
-        num_epochs,
-        lr=2e-5,
-        optimizer=None,
-        scheduler=None,
-        scaler=None,
-        device="cuda",
-    ):
-        # device = torch.device("cuda")
-        self.model.to(device)
-
-        total_steps = len(train_loader) * num_epochs
-        # 6% of the total number of steps are warmup steps as in paper
-        warmup_steps = int(0.06 * total_steps)
-
-        if optimizer is None:
-            optimizer = torch.optim.AdamW(
-                self.model.parameters(),
-                lr=lr,
-                betas=(0.99, 0.999),
-                eps=1e-6,
-                weight_decay=0.01,
-            )
-
-        if scheduler is None:
-            scheduler = get_linear_schedule_with_warmup(
-                # linear instead of cosine as in paper
-                optimizer,
-                num_warmup_steps=warmup_steps,
-                num_training_steps=total_steps,
-            )
-
-        if scaler is None:
-            scaler = GradScaler()
-
-        writer = SummaryWriter()
-
-        # early stopping
-        patience_counter = 0
-        patience_limit = 5
-        epsilon = 1e-3
-        best_valid_loss = float("inf")
-        best_model_state = None
-
-        for epoch in range(num_epochs):
-            # train part
-            self.model.train()
-            total_loss_train = 0
-
-            for batch_idx, batch in enumerate(
-                tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")
-            ):
-                input_ids, attention_mask, labels = (
-                    batch["input_ids"].to(device),
-                    batch["attention_mask"].to(device),
-                    batch["labels"].to(device),
-                )
-
-                with autocast(device_type="cuda", dtype=torch.float16, enabled=True):
-                    output = self.model(input_ids, attention_mask)
-                    loss = F.cross_entropy(
-                        output.view(-1, output.shape[-1]),
-                        labels.view(-1),
-                        ignore_index=-100,
-                    )
-
-                scaler.scale(loss).backward()
-                scaler.unscale_(optimizer) # unscale before clipping
-                torch.nn.utils.clip_grad_norm_(
-                    self.model.parameters(), max_norm=1.0
-                ) # gradient clipping
-                scaler.step(optimizer)
-                scheduler.step()
-                scaler.update()
-                optimizer.zero_grad()
-
-                total_loss_train += loss.item()
-
-            train_loss = total_loss_train / len(train_loader)
-
-            # validation part
-
-            self.model.eval()
-
-            total_loss_valid = 0
-            total_correct = 0
-            total_tokens = 0
-
-            with torch.no_grad():
-                for batch_idx, batch in enumerate(validation_loader):
-                    input_ids, attention_mask, labels = (
-                        batch["input_ids"].to(device),
-                        batch["attention_mask"].to(device),
-                        batch["labels"].to(device),
-                    )
-
-                    with autocast(
-                        device_type="cuda", dtype=torch.float16, enabled=True
-                    ):
-                        output = self.model(input_ids, attention_mask)
-                        loss = F.cross_entropy(
-                            output.view(-1, output.shape[-1]),
-                            labels.view(-1),
-                            ignore_index=-100,
-                            label_smoothing=0.05,
-                        )
-
-                    preds = torch.argmax(output, dim=-1)
-                    correct = (preds == labels).float().sum()
-
-                    total_loss_valid += loss.item()
-                    total_correct += correct
-                    total_tokens += labels.size(0)
-
-            validation_loss = total_loss_valid / len(validation_loader)
-            validation_accuracy = total_correct / total_tokens
-
-            if validation_loss < best_valid_loss - epsilon:
-                best_valid_loss = validation_loss
-                patience_counter = 0
-                best_model_state = self.model.state_dict().copy()
-            else:
-                patience_counter += 1
-                if patience_counter >= patience_limit:
-                    self.model.load_state_dict(best_model_state)
-                    self.save_checkpoint(
-                        best_model_state,
-                        optimizer,
-                        scheduler,
-                        scaler,
-                        path="cpEarly.pt",
-                    )
-                    break
-
-            print(
-                f"Epoch {epoch + 1}, train loss: {train_loss:.4f}, "
-                f"Valid loss: {validation_loss:.4f}, "
-                f"Validation accuracy: {validation_accuracy:.4f}"
-            )
-
-        writer.close()
-
-        self.model.load_state_dict(best_model_state)
-        self.save_checkpoint(
-            self.model.state_dict(), optimizer, scheduler, path="finishedBest.pt"
-        )
-
     def inference(self, sentece):
         self.model.eval()
         sentece_tokenized = self.tokenizer(
@@ -185,17 +31,6 @@ class RoBERTaModule(nn.Module):
         preds = torch.argmax(outputs, dim=-1).item()
         return preds
 
-    def save_checkpoint(self, model, optimizer, scheduler, path="checkpoint.pt"):
-        torch.save(
-            {
-                "model_state_dict": model,
-                "optimizer_state_dict": optimizer.state_dict(),
-                "scheduler_state_dict": scheduler.state_dict(),
-            },
-            path,
-        )
-        print(f"Checkpoint saved to {path}")
-
     def load_checkpoint(self, model=None, path="finished.pt", location="cuda"):
         checkpoint = torch.load(path, map_location=location, weights_only=True)
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from transformers import RobertaTokenizerFast
 
 
 
 
4
 
5
  from modelFineTuning import RoBERTa
6
 
 
13
  vocab_size=self.tokenizer.vocab_size,
14
  padding_idx=self.tokenizer.pad_token_id,
15
  num_labels=2,
16
+ )
17
 
18
  def forward(self, x, attn_mask):
19
  return self.model(x, attn_mask)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def inference(self, sentece):
22
  self.model.eval()
23
  sentece_tokenized = self.tokenizer(
 
31
  preds = torch.argmax(outputs, dim=-1).item()
32
  return preds
33
 
 
 
 
 
 
 
 
 
 
 
 
34
  def load_checkpoint(self, model=None, path="finished.pt", location="cuda"):
35
  checkpoint = torch.load(path, map_location=location, weights_only=True)
36
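After this commit, fineTuning.py keeps only the forward pass, inference, and load_checkpoint paths; the train_model loop, mixed-precision plumbing, TensorBoard writer, and save_checkpoint helper are removed. Below is a minimal usage sketch of the trimmed RoBERTaModule; it assumes the class can be constructed without arguments and that load_checkpoint restores the weights in place, neither of which is shown in this diff.

import torch

from fineTuning import RoBERTaModule

# Assumption: a no-argument constructor; the real __init__ signature is not
# part of this diff.
module = RoBERTaModule()

# load_checkpoint is one of the two methods kept by this commit;
# "finished.pt" is its default path.
device = "cuda" if torch.cuda.is_available() else "cpu"
module.load_checkpoint(path="finished.pt", location=device)

# inference() tokenizes a single sentence and returns the argmax class index
# (the model is built with num_labels=2).
label = module.inference("an example sentence to classify")
print(label)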