import os

import torch
from torch.optim import AdamW  # transformers.AdamW is deprecated; the torch optimizer is a drop-in replacement
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup

from utils.dataset import BDTtsDataset  # used unconditionally in train() below
from inference import tts  # reuse the model wrapper defined for inference
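# BDTtsDataset (from utils/dataset.py, not shown in this snapshot) is assumed to
# be a torch Dataset yielding (inputs, targets) tensor pairs. A minimal sketch
# of the expected interface, for reference only:
#
#   class BDTtsDataset(torch.utils.data.Dataset):
#       def __init__(self, root):     # index the audio/text files under `root`
#           ...
#       def __len__(self):            # number of training utterances
#           ...
#       def __getitem__(self, idx):   # -> (model_inputs, target_tensor)
#           ...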
training_config = {
    "learning_rate": 1e-4,
    "batch_size": 16,
    "warmup_steps": 1000,
    "gradient_accumulation_steps": 4,
    "mixed_precision": True,
    "save_strategy": "steps",
    "save_steps": 500,
    "eval_steps": 100,  # no evaluation loop yet; kept as a placeholder
    "num_epochs": 5,
}
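
# With batch_size=16 and gradient_accumulation_steps=4, each optimizer update
# effectively averages gradients over 16 * 4 = 64 samples; adjust one or the
# other to trade GPU memory for update noise.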
def train():
    dataset = BDTtsDataset("./data/train")
    dataloader = DataLoader(dataset, batch_size=training_config["batch_size"], shuffle=True)

    accum_steps = training_config["gradient_accumulation_steps"]
    optimizer = AdamW(tts.model.parameters(), lr=training_config["learning_rate"])
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=training_config["warmup_steps"],
        # One scheduler step per optimizer update, not per batch.
        num_training_steps=len(dataloader) * training_config["num_epochs"] // accum_steps,
    )
    scaler = torch.cuda.amp.GradScaler() if training_config["mixed_precision"] else None

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts.model.to(device)
    tts.model.train()
    os.makedirs("checkpoints", exist_ok=True)

    step = 0
    optimizer.zero_grad()
    for epoch in range(training_config["num_epochs"]):
        for batch in dataloader:
            inputs, targets = batch  # assumes tensor pairs from BDTtsDataset
            inputs, targets = inputs.to(device), targets.to(device)

            with torch.cuda.amp.autocast(enabled=scaler is not None):
                outputs = tts.model(inputs)
                # Prefer the model's own loss when it exposes one; otherwise
                # fall back to MSE against the targets.
                loss = outputs.loss if hasattr(outputs, "loss") else torch.nn.functional.mse_loss(outputs, targets)

            # Average gradients over the accumulation window.
            loss = loss / accum_steps
            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            step += 1
            if step % accum_steps == 0:
                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            if step % training_config["save_steps"] == 0:
                torch.save(tts.model.state_dict(), f"checkpoints/model_step{step}.pth")
                print(f"Saved checkpoint at step {step}")
if __name__ == "__main__":
    train()
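
# To resume training from a saved checkpoint, load the weights back into the
# same model before calling train(), e.g. (hypothetical checkpoint path):
#   tts.model.load_state_dict(torch.load("checkpoints/model_step500.pth"))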