Spaces:
Runtime error
Runtime error
"""
Fine-tuning Trainer
Handles the fine-tuning process with the OpenAI API.
"""
| import os | |
| import time | |
| from pathlib import Path | |
| from typing import Optional, Dict, Any | |
| from openai import OpenAI | |
class FineTuningTrainer:
    """Manages fine-tuning jobs with OpenAI.

    Uploads training files, creates / polls / cancels fine-tuning jobs via
    the OpenAI client, and keeps one local JSON record per job in
    ``jobs_dir`` (``<job_id>.json``). Successive saves for the same job are
    merged into that record rather than replacing it.
    """

    def __init__(self, api_key: Optional[str] = None, jobs_dir: Optional[str] = None):
        """
        Create the OpenAI client and make sure the job-record directory exists.

        Args:
            api_key: OpenAI API key; falls back to the OPENAI_API_KEY env var.
            jobs_dir: Directory for per-job JSON records. Defaults to
                'fine_tuning/jobs', resolved against the current working
                directory. Accepts a str or a Path.
        """
        self.client = OpenAI(api_key=api_key or os.getenv('OPENAI_API_KEY'))
        self.jobs_dir = Path(jobs_dir) if jobs_dir is not None else Path('fine_tuning/jobs')
        self.jobs_dir.mkdir(parents=True, exist_ok=True)

    def upload_training_file(self, file_path: str) -> str:
        """
        Upload training file to OpenAI

        Args:
            file_path: Path to training data file (JSONL format)

        Returns:
            File ID from OpenAI

        Raises:
            FileNotFoundError: if file_path does not exist.
        """
        print(f"π€ Uploading training file: {file_path}")
        with open(file_path, 'rb') as f:
            response = self.client.files.create(
                file=f,
                purpose='fine-tune'
            )
        file_id = response.id
        print(f"β File uploaded successfully: {file_id}")
        return file_id

    def create_fine_tuning_job(
        self,
        training_file_id: str,
        model: str = "gpt-4o-mini-2024-07-18",
        suffix: Optional[str] = None,
        hyperparameters: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Create a fine-tuning job and record it locally.

        Args:
            training_file_id: ID of uploaded training file
            model: Base model to fine-tune
            suffix: Suffix for fine-tuned model name (omitted when falsy)
            hyperparameters: Training hyperparameters (omitted when falsy)

        Returns:
            Fine-tuning job ID
        """
        print("π Creating fine-tuning job...")
        print(f" Base model: {model}")
        print(f" Training file: {training_file_id}")

        job_params: Dict[str, Any] = {
            'training_file': training_file_id,
            'model': model
        }
        # Only send optional fields when provided so the API applies its defaults.
        if suffix:
            job_params['suffix'] = suffix
        if hyperparameters:
            job_params['hyperparameters'] = hyperparameters

        response = self.client.fine_tuning.jobs.create(**job_params)
        job_id = response.id
        print(f"β Fine-tuning job created: {job_id}")

        # Save job info
        self._save_job_info(job_id, {
            'training_file_id': training_file_id,
            'model': model,
            'suffix': suffix,
            'hyperparameters': hyperparameters,
            'status': 'created'
        })
        return job_id

    def check_job_status(self, job_id: str) -> Dict[str, Any]:
        """
        Check status of fine-tuning job

        Args:
            job_id: Fine-tuning job ID

        Returns:
            Job status information. 'error' holds the API-reported error
            message, or None when the job carries no error message.
        """
        response = self.client.fine_tuning.jobs.retrieve(job_id)
        error = getattr(response, 'error', None)
        return {
            'id': response.id,
            'status': response.status,
            'model': response.model,
            'fine_tuned_model': response.fine_tuned_model,
            'created_at': response.created_at,
            'finished_at': response.finished_at,
            'trained_tokens': response.trained_tokens,
            'error': getattr(error, 'message', None) if error is not None else None,
        }

    def wait_for_completion(
        self,
        job_id: str,
        check_interval: int = 60,
        timeout: int = 3600
    ) -> Dict[str, Any]:
        """
        Poll a fine-tuning job until it finishes, fails, or the timeout expires.

        Args:
            job_id: Fine-tuning job ID
            check_interval: Seconds between status checks
            timeout: Maximum seconds to wait

        Returns:
            Final job status (when the job succeeded)

        Raises:
            RuntimeError: if the job ends 'failed' or 'cancelled'.
            TimeoutError: if the job is still running after `timeout` seconds.
        """
        print(f"β³ Waiting for fine-tuning job {job_id} to complete...")
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + timeout

        while True:
            status_info = self.check_job_status(job_id)
            status = status_info['status']
            print(f" Status: {status}")

            if status == 'succeeded':
                print("β Fine-tuning completed!")
                print(f" Fine-tuned model: {status_info['fine_tuned_model']}")
                self._save_job_info(job_id, status_info)
                return status_info

            if status in ('failed', 'cancelled'):
                print(f"β Fine-tuning {status}")
                self._save_job_info(job_id, status_info)
                message = f"Fine-tuning job {status}"
                if status_info.get('error'):
                    message += f": {status_info['error']}"
                # RuntimeError subclasses Exception, so existing handlers still match.
                raise RuntimeError(message)

            remaining = deadline - time.monotonic()
            if remaining <= 0:
                print("β° Timeout reached")
                # Persist the last observed state so the record is not stale.
                self._save_job_info(job_id, status_info)
                raise TimeoutError(f"Fine-tuning job exceeded {timeout} seconds")

            # Never sleep past the deadline.
            time.sleep(min(check_interval, remaining))

    def list_fine_tuned_models(self) -> list:
        """
        List all fine-tuned models

        Iterating the page object returned by the SDK follows pagination
        cursors, so every job is visited rather than only the first page.

        Returns:
            List of fine-tuned model information (jobs without a
            fine_tuned_model are skipped)
        """
        jobs = self.client.fine_tuning.jobs.list(limit=50)
        return [
            {
                'job_id': job.id,
                'model_id': job.fine_tuned_model,
                'base_model': job.model,
                'status': job.status,
                'created_at': job.created_at,
                'finished_at': job.finished_at
            }
            for job in jobs
            if job.fine_tuned_model
        ]

    def cancel_job(self, job_id: str) -> None:
        """
        Cancel a running fine-tuning job

        Args:
            job_id: Fine-tuning job ID
        """
        print(f"π Cancelling job {job_id}...")
        self.client.fine_tuning.jobs.cancel(job_id)
        print("β Job cancelled")

    def _save_job_info(self, job_id: str, info: Dict[str, Any]) -> None:
        """Merge `info` into the job's JSON record (creating it if absent).

        Merging keeps fields written at creation time (training file,
        suffix, hyperparameters) when later status snapshots are saved.
        Values that are not JSON-serializable are stored via str().
        """
        import json

        self.jobs_dir.mkdir(parents=True, exist_ok=True)
        job_file = self.jobs_dir / f"{job_id}.json"

        record: Dict[str, Any] = {}
        try:
            loaded = json.loads(job_file.read_text(encoding='utf-8'))
        except FileNotFoundError:
            loaded = {}
        except (OSError, ValueError):
            # Unreadable or corrupt record: start fresh rather than crash.
            loaded = {}
        if isinstance(loaded, dict):
            record = loaded

        record.update(info)
        job_file.write_text(json.dumps(record, indent=2, default=str), encoding='utf-8')
def fine_tune_agent(
    agent_name: str,
    training_file: str,
    model: str = "gpt-4o-mini-2024-07-18",
    suffix: Optional[str] = None,
    wait_for_completion: bool = True,
    check_interval: int = 60,
    timeout: int = 3600,
) -> str:
    """
    Convenience function to fine-tune an agent

    Uploads the training file, starts a fine-tuning job and optionally
    blocks until it finishes.

    Args:
        agent_name: Name of agent (nutrition, exercise, etc.); used to build
            the default model suffix
        training_file: Path to training data
        model: Base model to use
        suffix: Suffix for model name; defaults to "<agent_name>-<unix time>"
        wait_for_completion: Whether to wait for job to finish
        check_interval: Seconds between status checks while waiting
        timeout: Maximum seconds to wait before raising TimeoutError

    Returns:
        Fine-tuned model ID when waiting, otherwise the job ID
    """
    trainer = FineTuningTrainer()

    # Upload file
    file_id = trainer.upload_training_file(training_file)

    # Create job
    if suffix is None:
        suffix = f"{agent_name}-{int(time.time())}"
    job_id = trainer.create_fine_tuning_job(
        training_file_id=file_id,
        model=model,
        suffix=suffix
    )

    if not wait_for_completion:
        return job_id

    # Fine-tuning runs can exceed an hour, so the wait limits are caller-tunable.
    status = trainer.wait_for_completion(
        job_id,
        check_interval=check_interval,
        timeout=timeout
    )
    return status['fine_tuned_model']