import os

import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated; torch's AdamW is the recommended replacement
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from dataset import BDTtsDataset  # NOTE: assumed import path; point this at wherever BDTtsDataset is defined
from inference import tts
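
# Fine-tuning hyperparameters. "save_strategy" and "eval_steps" are declared here
# but are not consumed by the training loop below.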
training_config = {
    "learning_rate": 1e-4,
    "batch_size": 16,
    "warmup_steps": 1000,
    "gradient_accumulation_steps": 4,
    "mixed_precision": True,
    "save_strategy": "steps",
    "save_steps": 500,
    "eval_steps": 100,
    "num_epochs": 5,
}


def train():
    dataset = BDTtsDataset("./data/train")
    dataloader = DataLoader(dataset, batch_size=training_config["batch_size"], shuffle=True)

    accum_steps = training_config["gradient_accumulation_steps"]
    optimizer = AdamW(tts.model.parameters(), lr=training_config["learning_rate"])
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=training_config["warmup_steps"],
        # The scheduler advances once per optimizer step, so account for gradient accumulation.
        num_training_steps=(len(dataloader) * training_config["num_epochs"]) // accum_steps,
    )

    scaler = torch.cuda.amp.GradScaler() if training_config["mixed_precision"] else None
    device = next(tts.model.parameters()).device  # train on whatever device the model was loaded onto
    os.makedirs("checkpoints", exist_ok=True)  # make sure the checkpoint directory exists before the first save

    tts.model.train()  # the inference module may have left the model in eval mode
    step = 0
    optimizer.zero_grad()
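
    # Training loop: forward pass under autocast, gradient accumulation over
    # `accum_steps` micro-batches, then optimizer/scheduler steps and periodic checkpoints.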
    for epoch in range(training_config["num_epochs"]):
        for batch in dataloader:
            # Assumes the dataset yields (input, target) tensor pairs.
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)

            with torch.cuda.amp.autocast(enabled=scaler is not None):
                outputs = tts.model(inputs)
                loss = outputs.loss if hasattr(outputs, "loss") else torch.nn.functional.mse_loss(outputs, targets)
                # Scale the loss so the accumulated gradient averages over the micro-batches.
                loss = loss / accum_steps

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            step += 1

            # Step the optimizer (and scheduler) only every `gradient_accumulation_steps` micro-batches.
            if step % accum_steps == 0:
                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            if step % training_config["save_steps"] == 0:
                torch.save(tts.model.state_dict(), f"checkpoints/model_step{step}.pth")
                print(f"Saved checkpoint at step {step}")


if __name__ == "__main__":
    train()
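
# Example (assumption: a checkpoint produced by this script exists at the path below)
# of reloading saved weights before further fine-tuning or inference:
#   state = torch.load("checkpoints/model_step500.pth", map_location="cpu")
#   tts.model.load_state_dict(state)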