import os
import kaggle
import tempfile
import requests
import multiprocessing
import pandas as pd
from bs4 import BeautifulSoup
from concurrent.futures import ThreadPoolExecutor


def _generate_sources() -> pd.DataFrame:
    """Generate a dataset containing the URLs to retrieve data from."""
    dataset = pd.DataFrame({'type': [], 'name': [], 'url': []})
    with tempfile.TemporaryDirectory() as temp_dir:
        kaggle.api.dataset_download_files('rohanrao/formula-1-world-championship-1950-2020', path=temp_dir, unzip=True)

        # Circuits
        df = pd.read_csv(temp_dir + '/circuits.csv')
        # remove all columns except 'name' and 'url'
        df = df[['name', 'url']]
        df['type'] = 'circuit'
        dataset = pd.concat([dataset, df], ignore_index=True)

        # Drivers
        df = pd.read_csv(temp_dir + '/drivers.csv')
        # remove all columns except 'forename', 'surname' and 'url'
        df = df[['forename', 'surname', 'url']]
        # join 'forename' and 'surname' into a single 'name' column
        df['name'] = df['forename'] + ' ' + df['surname']
        df = df[['name', 'url']]
        df['type'] = 'driver'
        dataset = pd.concat([dataset, df], ignore_index=True)

        # Constructors
        df = pd.read_csv(temp_dir + '/constructors.csv')
        # remove broken links
        df = df[(df['url'] != 'http://en.wikipedia.org/wiki/Turner_(constructor)') & (df['url'] != 'http://en.wikipedia.org/wiki/Hall_(constructor)')]
        # remove all columns except 'name' and 'url'
        df = df[['name', 'url']]
        df['type'] = 'constructor'
        dataset = pd.concat([dataset, df], ignore_index=True)

        # Races
        df = pd.read_csv(temp_dir + '/races.csv')
        # build a unique name from the race name, year and round, then keep only 'name' and 'url'
        df['name'] = df['name'] + " " + df['year'].astype(str) + "-" + df['round'].astype(str)
        df = df[['name', 'url']]
        df['type'] = 'race'
        dataset = pd.concat([dataset, df], ignore_index=True)

        # Seasons
        df = pd.read_csv(temp_dir + '/seasons.csv')
        # keep only 'year' and 'url', then build the name from the year
        df = df[['year', 'url']]
        df['name'] = 'Year ' + df['year'].astype(str)
        df = df[['name', 'url']]
        df['type'] = 'season'
        dataset = pd.concat([dataset, df], ignore_index=True)
    return dataset


def _extract_paragraphs(url):
    """Download a page and return the text of its <p> paragraphs."""
    response = requests.get(url)
    html = response.text
    soup = BeautifulSoup(html, "html.parser")
    pars = soup.find_all("p")
    pars = [p.get_text() for p in pars]
    return pars


def generate_trainset(persist: bool = True, persist_path: str = './datasets', filename='train.csv') -> pd.DataFrame:
    """
    Generate the dataset used to train the model.

    Parameters:
        persist (bool): Whether to save the generated dataset to a file.
        persist_path (str): The directory where the generated dataset will be saved.
        filename (str): The name of the file to save the dataset.

    Returns:
        pd.DataFrame: The generated DataFrame.
    """
    if os.path.exists(f"{persist_path}/{filename}"):
        return pd.read_csv(f"{persist_path}/{filename}")

    sources = _generate_sources()
    num_threads = multiprocessing.cpu_count()
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        paragraphs = list(executor.map(_extract_paragraphs, sources['url']))
    paragraphs = [" ".join(p[0:5]).strip("\n") for p in paragraphs]  # keep only the first 5 paragraphs
    sources['description'] = paragraphs
    df = sources[['type', 'name', 'description']]
    if persist:
        os.makedirs(persist_path, exist_ok=True)
        df.to_csv(f"{persist_path}/{filename}", index=False)
    return df


def generate_ragset(persist: bool = True, persist_path: str = './datasets', filename='rag.csv') -> pd.DataFrame:
    """
    Generate the dataset used for Retrieval-Augmented Generation.

    Parameters:
        persist (bool): Whether to save the generated dataset to a file.
        persist_path (str): The directory where the generated dataset will be saved.
        filename (str): The name of the file to save the dataset.

    Returns:
        pd.DataFrame: The generated DataFrame.
    """
    if os.path.exists(f"{persist_path}/{filename}"):
        return pd.read_csv(f"{persist_path}/{filename}")

    sources = _generate_sources()
    num_threads = multiprocessing.cpu_count()
    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        paragraphs = list(executor.map(_extract_paragraphs, sources['url']))
    paragraphs = [" ".join(p).strip("\n") for p in paragraphs]  # keep all the paragraphs
    sources['description'] = paragraphs
    df = sources[['type', 'name', 'description']]
    if persist:
        os.makedirs(persist_path, exist_ok=True)
        df.to_csv(f"{persist_path}/{filename}", index=False)
    return df
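

# A minimal usage sketch, not part of the original module: it shows one way the two
# generators might be invoked together. It assumes Kaggle API credentials are configured
# (e.g. ~/.kaggle/kaggle.json) and that Wikipedia is reachable over the network.
if __name__ == "__main__":
    # training set: only the first few paragraphs of each page
    train_df = generate_trainset(persist=True, persist_path='./datasets', filename='train.csv')
    # RAG corpus: the full text of each page
    rag_df = generate_ragset(persist=True, persist_path='./datasets', filename='rag.csv')
    print(f"train set: {len(train_df)} rows, rag set: {len(rag_df)} rows")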