| # python app.py | |
| import gradio as gr | |
| import os | |
| import pandas as pd | |
| import polars as pl  # used by convert_parquet_to_jsonl_polars | |
| import requests | |
| from pathlib import Path | |
| import ctranslate2 | |
| import time | |
| import logging | |
| import transformers | |
| import json | |
| import io | |
| from tqdm import tqdm | |
| import subprocess | |
| from huggingface_hub import snapshot_download, upload_file, HfApi, create_repo | |
| # Function to download a Parquet file from a specified URL | |
| def download_parquet(url, local_path): | |
| response = requests.get(url, stream=True) | |
| if response.status_code == 200: | |
| with open(local_path, 'wb') as file: | |
| for chunk in response.iter_content(chunk_size=1024): | |
| file.write(chunk) | |
| print("File downloaded successfully.") | |
| else: | |
| print(f"Failed to download file, status code: {response.status_code}") | |
| # Function to convert Parquet files to JSONL format | |
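| # (Polars-based variant; the main pipeline below uses the pandas-based convert_parquet_to_jsonl instead.) | |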
| def convert_parquet_to_jsonl_polars(input_file, output_dir, override=False): | |
| output_dir_path = Path(output_dir) | |
| output_dir_path.mkdir(parents=True, exist_ok=True) | |
| input_path = Path(input_file) | |
| output_file_path = output_dir_path / input_path.with_suffix(".jsonl").name | |
| if output_file_path.exists() and not override: | |
| print(f"Skipping because output exists already: {output_file_path}") | |
| else: | |
| df = pl.read_parquet(input_path) | |
| df.write_ndjson(output_file_path) | |
| print(f"Data written to {output_file_path}") | |
| def convert_parquet_to_jsonl(parquet_filename, jsonl_filename): | |
| try: | |
| # Read the parquet file | |
| df = pd.read_parquet(parquet_filename) | |
| logger.info(f"Read Parquet file {parquet_filename} successfully.") | |
| # Convert the dataframe to a JSON string and handle Unicode characters and forward slashes | |
| json_str = df.to_json(orient='records', lines=True, force_ascii=False) | |
| logger.info(f"Converted Parquet file to JSON string.") | |
| # Replace escaped forward slashes if needed | |
| json_str = json_str.replace('\\/', '/') | |
| # The jsonl_filename argument is actually an output directory; the data is written to train.jsonl inside it | |
| jsonl_filename += '/train.jsonl' | |
| logger.info(f"Attempting to save to {jsonl_filename}") | |
| with open(jsonl_filename, 'w', encoding='utf-8') as file: | |
| file.write(json_str) | |
| logger.info(f"Data saved to {jsonl_filename}") | |
| except Exception as e: | |
| logger.error(f"Failed to convert Parquet to JSONL: {e}") | |
| raise | |
| # Function to count lines in a JSONL file | |
| def count_lines_in_jsonl(file_path): | |
| with open(file_path, 'r', encoding='utf-8') as file: | |
| line_count = sum(1 for _ in file) | |
| return line_count | |
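| # Parses a 1-based range specification such as "1-100,250,300-" into 0-based line indices. | |
| # Open-ended ranges run to the start/end of the file, and out-of-bounds entries are skipped with an error log. | |
| # Example: parse_range_specification("1-3,7", file_length=10) -> [0, 1, 2, 6] | |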
| def parse_range_specification(range_specification, file_length): | |
| line_indices = [] | |
| ranges = range_specification.split(',') | |
| for r in ranges: | |
| if '-' in r: | |
| parts = r.split('-') | |
| start = int(parts[0]) - 1 if parts[0] else 0 | |
| end = int(parts[1]) - 1 if parts[1] else file_length - 1 | |
| if start < 0 or end >= file_length: | |
| logging.error(f"Range {r} is out of bounds.") | |
| continue # Skip ranges that are out of bounds | |
| line_indices.extend(range(start, end + 1)) | |
| else: | |
| single_line = int(r) - 1 | |
| if single_line < 0 or single_line >= file_length: | |
| logging.error(f"Line number {r} is out of bounds.") | |
| continue # Skip line numbers that are out of bounds | |
| line_indices.append(single_line) | |
| return line_indices | |
| def translate_text(text, translator, tokenizer, target_language): | |
| """ | |
| Translates the given text from English to the target language using CTranslate2 and the WMT21 model, | |
| with special handling for newlines and segmenting text longer than 500 characters. | |
| Ensures sequences of newlines (\n\n, \n\n\n, etc.) are accurately reproduced. | |
| """ | |
| try: | |
| segments = [] | |
| newline_sequences = [] # To store sequences of newlines | |
| segment = "" | |
| i = 0 | |
| while i < len(text): | |
| # Collect sequences of newlines | |
| if text[i] == '\n': | |
| newline_sequence = '\n' | |
| while i + 1 < len(text) and text[i + 1] == '\n': | |
| newline_sequence += '\n' | |
| i += 1 | |
| if segment: | |
| segments.append(segment) # Add the preceding text segment | |
| segment = "" | |
| newline_sequences.append(newline_sequence) # Store the newline sequence | |
| else: | |
| segment += text[i] | |
| # If segment exceeds 500 characters, or if we reach the end of the text, process it | |
| if len(segment) >= 500 or i == len(text) - 1: | |
| end_index = max(segment.rfind('.', 0, 500), segment.rfind('?', 0, 500), segment.rfind('!', 0, 500)) | |
| if end_index != -1 and len(segment) > 500: | |
| # Split at the last punctuation within the first 500 characters | |
| segments.append(segment[:end_index+1]) | |
| segment = segment[end_index+1:].lstrip() | |
| else: | |
| # No suitable punctuation or end of text, add the whole segment | |
| segments.append(segment) | |
| segment = "" | |
| i += 1 | |
| # Translate the collected text segments | |
| translated_segments = [] | |
| for segment in segments: | |
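| # Encode the segment, force decoding to start with the target-language token via target_prefix, | |
| # then strip that leading token from the hypothesis before decoding back to text. | |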
| source = tokenizer.convert_ids_to_tokens(tokenizer.encode(segment)) | |
| target_prefix = [tokenizer.lang_code_to_token[target_language]] | |
| results = translator.translate_batch([source], target_prefix=[target_prefix]) | |
| target = results[0].hypotheses[0][1:] | |
| translated_segment = tokenizer.decode(tokenizer.convert_tokens_to_ids(target)) | |
| translated_segments.append(translated_segment) | |
| # Reassemble the translated text with original newline sequences | |
| translated_text = "" | |
| for i, segment in enumerate(translated_segments): | |
| translated_text += segment | |
| if i < len(newline_sequences): | |
| translated_text += newline_sequences[i] # Insert the newline sequence | |
| return translated_text.strip() | |
| except Exception as e: | |
| logging.error(f"An error occurred during translation: {e}") | |
| return None | |
| def translate_item_ufb(item, raw_file_path, translator, tokenizer, target_language): | |
| try: | |
| # Translate the prompt directly since it's a string | |
| translated_prompt = translate_text(item['prompt'], translator, tokenizer, target_language) | |
| # Translate the chosen and rejected contents | |
| translated_chosen = [] | |
| for choice in item['chosen']: | |
| translated_content = translate_text(choice['content'], translator, tokenizer, target_language) | |
| translated_chosen.append({'content': translated_content, 'role': choice['role']}) | |
| translated_rejected = [] | |
| for choice in item['rejected']: | |
| translated_content = translate_text(choice['content'], translator, tokenizer, target_language) | |
| translated_rejected.append({'content': translated_content, 'role': choice['role']}) | |
| # Write the raw response to a backup file | |
| with open(raw_file_path, 'a', encoding='utf-8') as raw_file: | |
| raw_file.write(f"Prompt: {translated_prompt}\n") | |
| raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n") | |
| raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n") | |
| logging.info("Translation request successful.") | |
| # Update the original item with the translated fields | |
| item['prompt'] = translated_prompt | |
| item['chosen'] = translated_chosen | |
| item['rejected'] = translated_rejected | |
| return item | |
| except Exception as e: | |
| logging.error(f"An error occurred during translation: {e}") | |
| return None | |
| def validate_item_ufb(item): | |
| # Check basic required fields including 'prompt' as a simple string | |
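| # Expected shape (illustrative): {"source": "...", "prompt": "plain string", | |
| # "chosen": [{"content": "...", "role": "..."}], "rejected": [{"content": "...", "role": "..."}]} | |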
| required_fields = ['source', 'prompt', 'chosen', 'rejected'] | |
| for field in required_fields: | |
| if field not in item: | |
| logging.warning(f"Missing required field: {field}") | |
| return False | |
| if field == 'prompt' and not isinstance(item['prompt'], str): | |
| logging.warning("Prompt must be a string.") | |
| return False | |
| # Check 'chosen' and 'rejected' which should be lists of dictionaries | |
| for field in ['chosen', 'rejected']: | |
| if not isinstance(item[field], list) or not item[field]: | |
| logging.warning(f"No entries or incorrect type for section: {field}") | |
| return False | |
| for idx, message in enumerate(item[field]): | |
| if 'content' not in message or 'role' not in message: | |
| logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}") | |
| return False | |
| if not isinstance(message['content'], str) or not isinstance(message['role'], str): | |
| logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}") | |
| return False | |
| return True | |
| def translate_item_mix(item, raw_file_path, translator, tokenizer, target_language): | |
| """ | |
| Translates the relevant fields in the given item from English to the target language using CTranslate2 and the WMT21 model, | |
| and saves the raw response to a backup file. | |
| """ | |
| #print ("translating:", item) | |
| try: | |
| # Translate each part of the prompt separately and preserve the order | |
| translated_prompts = [] | |
| for message in item['prompt']: | |
| translated_content = translate_text(message['content'], translator, tokenizer, target_language) | |
| translated_prompts.append({'content': translated_content, 'role': message['role']}) | |
| # Translate the chosen and rejected contents | |
| translated_chosen_content = translate_text(item['chosen'][0]['content'], translator, tokenizer, target_language) | |
| translated_rejected_content = translate_text(item['rejected'][0]['content'], translator, tokenizer, target_language) | |
| # Write the raw response to a backup file | |
| with open(raw_file_path, 'a', encoding='utf-8') as raw_file: | |
| raw_file.write("Prompt content:\n") | |
| for translated_prompt in translated_prompts: | |
| raw_file.write(f"{translated_prompt['role']}: {translated_prompt['content']}\n") | |
| raw_file.write(f"Chosen content: {translated_chosen_content}\n") | |
| raw_file.write(f"Rejected content: {translated_rejected_content}\n\n") | |
| logging.info("Translation request successful.") | |
| except Exception as e: | |
| logging.error(f"An error occurred during translation: {e}") | |
| return None | |
| # Update the original item with the translated fields | |
| item['prompt'] = translated_prompts | |
| item['chosen'][0]['content'] = translated_chosen_content | |
| item['rejected'][0]['content'] = translated_rejected_content | |
| logging.info("Translation processing successful.") | |
| return item | |
| def validate_item_mix(item): | |
| """ | |
| Validates the structure, presence, and content of required fields in the given item, | |
| allowing for multiple elements in the 'prompt' field for multi-turn conversations. | |
| """ | |
| required_fields = ['dataset', 'prompt', 'chosen', 'rejected'] | |
| for field in required_fields: | |
| if field not in item: | |
| logging.warning(f"Missing required field: {field}") | |
| return False | |
| # Check for at least one element in 'prompt' and exactly one element in 'chosen' and 'rejected' | |
| if len(item['prompt']) < 1 or len(item['chosen']) != 1 or len(item['rejected']) != 1: | |
| logging.warning("Invalid number of elements in 'prompt', 'chosen', or 'rejected' field.") | |
| return False | |
| # Validate 'content' and 'role' fields in all messages of 'prompt', and single elements of 'chosen' and 'rejected' | |
| for choice in item['prompt'] + item['chosen'] + item['rejected']: | |
| if 'content' not in choice or 'role' not in choice: | |
| logging.warning("Missing 'content' or 'role' field in choice.") | |
| return False | |
| if not isinstance(choice['content'], str) or not isinstance(choice['role'], str): | |
| logging.warning("Invalid type for 'content' or 'role' field in choice.") | |
| return False | |
| return True | |
| def translate_item_ufb_cached(item, raw_file_path, translator, tokenizer, target_language): | |
| try: | |
| translated_texts = {} # Cache to store translated texts | |
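| # Note: this cache only lives for the current item, so it deduplicates repeated strings within a single record. | |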
| # Translate the prompt if necessary (which is a user input and can appear again) | |
| if item['prompt'] not in translated_texts: | |
| translated_prompt = translate_text(item['prompt'], translator, tokenizer, target_language) | |
| translated_texts[item['prompt']] = translated_prompt | |
| else: | |
| translated_prompt = translated_texts[item['prompt']] | |
| # Helper function to handle content translation with caching | |
| def get_translated_content(content): | |
| if content not in translated_texts: | |
| translated_texts[content] = translate_text(content, translator, tokenizer, target_language) | |
| return translated_texts[content] | |
| # Process translations for chosen and rejected sections | |
| def translate_interactions(interactions): | |
| translated_interactions = [] | |
| for interaction in interactions: | |
| translated_content = get_translated_content(interaction['content']) | |
| translated_interactions.append({'content': translated_content, 'role': interaction['role']}) | |
| return translated_interactions | |
| translated_chosen = translate_interactions(item['chosen']) | |
| translated_rejected = translate_interactions(item['rejected']) | |
| # Write the raw response to a backup file | |
| with open(raw_file_path, 'a', encoding='utf-8') as raw_file: | |
| raw_file.write(f"Prompt: {translated_prompt}\n") | |
| raw_file.write(f"Chosen: {json.dumps(translated_chosen, ensure_ascii=False)}\n") | |
| raw_file.write(f"Rejected: {json.dumps(translated_rejected, ensure_ascii=False)}\n\n") | |
| logging.info("Translation request successful.") | |
| # Update the original item with the translated fields | |
| item['prompt'] = translated_prompt | |
| item['chosen'] = translated_chosen | |
| item['rejected'] = translated_rejected | |
| return item | |
| except Exception as e: | |
| logging.error(f"An error occurred during translation: {e}") | |
| return None | |
| def validate_item_ufb_cached(item): | |
| # Check basic required fields | |
| required_fields = ['source', 'prompt', 'chosen', 'rejected'] | |
| for field in required_fields: | |
| if field not in item: | |
| logging.warning(f"Missing required field: {field}") | |
| return False | |
| # Ensure 'prompt' is a string | |
| if not isinstance(item['prompt'], str): | |
| logging.warning("Prompt must be a string.") | |
| return False | |
| # Check 'chosen' and 'rejected' which should be lists of dictionaries | |
| for field in ['chosen', 'rejected']: | |
| if not isinstance(item[field], list) or not item[field]: | |
| logging.warning(f"No entries or incorrect type for section: {field}") | |
| return False | |
| for idx, message in enumerate(item[field]): | |
| if 'content' not in message or 'role' not in message: | |
| logging.warning(f"Missing 'content' or 'role' field in {field} at index {idx}") | |
| return False | |
| if not isinstance(message['content'], str) or not isinstance(message['role'], str): | |
| logging.warning(f"Invalid type for 'content' or 'role' field in {field} at index {idx}") | |
| return False | |
| return True | |
| def process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type, target_language): | |
| try: | |
| # Assigning validation and translation functions based on model_type | |
| if model_type == "mix": | |
| print ("translating a mix-style model...") | |
| validate_item = validate_item_mix | |
| translate_item = translate_item_mix | |
| elif model_type == "ufb_cached": | |
| print ("translating an ufb_cached-style model...") | |
| validate_item = validate_item_ufb_cached | |
| translate_item = translate_item_ufb_cached # def translate_item_ufb(item, raw_file_path, translator, tokenizer): | |
| elif model_type == "ufb": | |
| print ("translating an ultrafeedback-style model...") | |
| validate_item = validate_item_ufb | |
| translate_item = translate_item_ufb # def translate_item_ufb(item, raw_file_path, translator, tokenizer): | |
| else: | |
| raise ValueError(f"Unsupported model_type: {model_type}") | |
| with open(input_file_path, 'r', encoding='utf-8') as file: | |
| data_points = [json.loads(line) for line in file] | |
| failed_items = [] | |
| failed_items_indices = [] | |
| for index in tqdm(line_indices, desc="Processing lines", unit="item"): | |
| item = data_points[index] | |
| # Validate the item structure | |
| if not validate_item(item): | |
| logging.warning("Skipping item due to invalid structure.") | |
| failed_items.append(item) | |
| continue | |
| # Translate the relevant fields in the item | |
| translated_item = None | |
| retry_count = 0 | |
| while translated_item is None and retry_count < 3: | |
| print ("going to translate the item...") | |
| translated_item = translate_item(item, raw_file_path, translator, tokenizer, target_language) | |
| retry_count += 1 | |
| if translated_item is None: | |
| logging.warning(f"Translation failed for item. Retry attempt: {retry_count}") | |
| time.sleep(1) | |
| if translated_item is not None: | |
| translated_item['index'] = index | |
| with open(output_file_path, 'a', encoding='utf-8') as file: | |
| file.write(json.dumps(translated_item, ensure_ascii=False) + "\n") | |
| else: | |
| failed_items_indices.append(index) | |
| failed_items.append(item) | |
| logging.error("Translation failed after multiple attempts. Skipping item.") | |
| # Validate the translated item structure | |
| if translated_item is not None and not validate_item(translated_item): | |
| logging.warning("Skipping translated item due to invalid structure.") | |
| failed_items.append(item) | |
| continue | |
| with open('failed_items.jsonl', 'w', encoding='utf-8') as file: | |
| for item in failed_items: | |
| file.write(json.dumps(item, ensure_ascii=False) + "\n") | |
| failed_items_str = generate_failed_items_str(failed_items_indices) | |
| with open('failed_items_index.txt', 'w', encoding='utf-8') as f: | |
| f.write(failed_items_str) | |
| logging.info("Translation completed successfully.") | |
| except Exception as e: | |
| logging.error(f"An error occurred: {e}") | |
| def generate_failed_items_str(indices): | |
| """ | |
| Converts a list of failed item indices into a compact, comma-separated range string. | |
| """ | |
| if not indices: | |
| return "" | |
| # Sort the list of indices and initialize the first range | |
| indices.sort() | |
| range_start = indices[0] | |
| current = range_start | |
| ranges = [] | |
| for i in indices[1:]: | |
| if i == current + 1: | |
| current = i | |
| else: | |
| if range_start == current: | |
| ranges.append(f"{range_start}") | |
| else: | |
| ranges.append(f"{range_start}-{current}") | |
| range_start = current = i | |
| # Add the last range | |
| if range_start == current: | |
| ranges.append(f"{range_start}") | |
| else: | |
| ranges.append(f"{range_start}-{current}") | |
| return ",".join(ranges) | |
| # Function to upload the output file to Hugging Face | |
| def upload_output_to_huggingface(output_file_path, repo_name, token): | |
| api = HfApi() | |
| # Check if the repository exists | |
| try: | |
| print ("checking repo:", repo_name) | |
| api.repo_info(repo_id=repo_name, repo_type="dataset", token=token) | |
| except Exception as e: | |
| if "404" in str(e): | |
| # Create the repository if it doesn't exist | |
| print ("creating it...") | |
| create_repo(repo_id=repo_name, repo_type="dataset", token=token) | |
| print(f"Created repository: {repo_name}") | |
| else: | |
| print(f"Failed to check repository existence: {e}") | |
| return | |
| # Upload the file to the repository | |
| try: | |
| print ("starting dataset upload from:", output_file_path) | |
| upload_file( | |
| path_or_fileobj=output_file_path, | |
| path_in_repo=output_file_path, | |
| repo_id=repo_name, | |
| repo_type="dataset", | |
| token=token | |
| ) | |
| print(f"Uploaded {output_file_path} to Hugging Face repository: {repo_name}") | |
| except Exception as e: | |
| print(f"Failed to upload {output_file_path} to Hugging Face: {e}") | |
| raise | |
| def translate_dataset(train_url, local_parquet_path, input_file_path, output_file_path, raw_file_path, range_specification, model_type, output_dir, output_repo_name, token, translator, tokenizer, target_language): | |
| try: | |
| # Download the Parquet file | |
| download_parquet(train_url, local_parquet_path) | |
| except Exception as e: | |
| logging.error(f"Failed to download the Parquet file from {train_url}: {e}") | |
| return | |
| try: | |
| # Convert the downloaded Parquet file to JSONL | |
| convert_parquet_to_jsonl(local_parquet_path, output_dir) | |
| except Exception as e: | |
| logging.error(f"Failed to convert Parquet to JSONL: {e}") | |
| return | |
| try: | |
| # Rename the JSONL file using subprocess to ensure correct handling | |
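| # (Assumes a POSIX "mv" binary is available; shutil.move would be a portable alternative.) | |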
| subprocess.run(["mv", f"{output_dir}/train.jsonl", input_file_path], check=True) | |
| except subprocess.CalledProcessError as e: | |
| logging.error(f"Failed to rename the file from 'train.jsonl' to {input_file_path}: {e}") | |
| return | |
| try: | |
| # Count lines in the JSONL file to validate contents | |
| line_count = count_lines_in_jsonl(input_file_path) | |
| logging.info(f"Number of lines in the file: {line_count}") | |
| except Exception as e: | |
| logging.error(f"Failed to count lines in {input_file_path}: {e}") | |
| return | |
| try: | |
| # Parse the range specification for processing specific lines | |
| line_indices = parse_range_specification(range_specification, file_length=line_count) | |
| if not line_indices: | |
| logging.error("No valid line indices to process. Please check the range specifications.") | |
| return | |
| except Exception as e: | |
| logging.error(f"Error parsing range specification '{range_specification}': {e}") | |
| return | |
| try: | |
| # Process the file with specified model type and line indices | |
| process_file(input_file_path, output_file_path, raw_file_path, line_indices, translator, tokenizer, model_type, target_language) | |
| except Exception as e: | |
| logging.error(f"Failed to process the file {input_file_path}: {e}") | |
| return | |
| try: | |
| # Upload the output file to Hugging Face repository | |
| upload_output_to_huggingface(output_file_path, output_repo_name, token) | |
| except Exception as e: | |
| logging.error(f"Failed to upload {output_file_path} to Hugging Face: {e}") | |
| # Setup logging configuration | |
| log_stream = io.StringIO() | |
| logging.basicConfig(level=logging.INFO, | |
| format='%(asctime)s - %(levelname)s - %(message)s', | |
| handlers=[ | |
| logging.FileHandler("translation.log", mode='a'), | |
| logging.StreamHandler(log_stream) | |
| ]) | |
| logger = logging.getLogger(__name__) | |
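| # log_stream captures log output in memory so main() can return it alongside the result in the Gradio UI. | |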
| # Main function to handle the translation workflow | |
| def main(dataset_url, model_type, output_dataset_name, range_specification, target_language, token: gr.OAuthToken | None, profile: gr.OAuthProfile | None): | |
| try: | |
| # Login to Hugging Face | |
| if token is None or profile is None or token.token is None or profile.username is None: | |
| return "### You must be logged in to use this service." | |
| if token: | |
| logger.info("Logged in to Hugging Face") | |
| # Configuration and paths | |
| tokenizer_name = "facebook/wmt21-dense-24-wide-en-x" | |
| model_repo_name = "cstr/wmt21ct2_int8" # Repository to download the model from | |
| # Download the model snapshot from Hugging Face | |
| model_path = snapshot_download(repo_id=model_repo_name, token=token.token) | |
| logger.info(f"Model downloaded to: {model_path}") | |
| # Load the CTranslate2 model | |
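| # device="auto" selects a GPU when one is available and falls back to CPU otherwise (this Space currently runs on CPU). | |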
| translator = ctranslate2.Translator(model_path, device="auto") | |
| logger.info("CTranslate2 model loaded successfully.") | |
| # Load the tokenizer | |
| tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) | |
| tokenizer.src_lang = "en" | |
| tokenizer.tgt_lang = target_language # Set target language | |
| logger.info("Tokenizer loaded successfully.") | |
| # Define the task based on user input | |
| task = { | |
| "url": dataset_url, | |
| "local_path": "train.parquet", | |
| "input_file": f"{model_type}_en.jsonl", | |
| "output_file": f"{model_type}_{target_language}.jsonl", # Include target language in the filename | |
| "raw_file": f"{model_type}_{target_language}_raw.jsonl", | |
| "range_spec": range_specification, | |
| "model_type": model_type, | |
| "target_language": target_language # Include target language in the task | |
| } | |
| # Call the translate_dataset function with the provided parameters | |
| translate_dataset( | |
| train_url=task["url"], | |
| local_parquet_path=task["local_path"], | |
| input_file_path=task["input_file"], | |
| output_file_path=task["output_file"], | |
| output_dir=".", | |
| output_repo_name=output_dataset_name, | |
| raw_file_path=task["raw_file"], | |
| token=token.token, | |
| range_specification=task["range_spec"], | |
| model_type=task["model_type"], | |
| translator=translator, | |
| tokenizer=tokenizer, | |
| target_language=task["target_language"] # Pass the target language | |
| ) | |
| logger.info("Dataset translation completed!") | |
| return "Dataset translation completed!\n\n### Logs:\n" + log_stream.getvalue() | |
| else: | |
| return "Login failed. Please try again." | |
| except Exception as e: | |
| logger.error(f"An error occurred in the main function: {e}") | |
| return f"An error occurred: {e}\n\n### Logs:\n{log_stream.getvalue()}" | |
| # Gradio interface setup | |
| gradio_title = "WMT21 Dataset Translation" | |
| gradio_desc = """This tool translates English datasets using the WMT21 translation model. | |
| ## What Does This Tool Do: | |
| - Translates datasets with structures based on the selected model type. | |
| - The translation model (facebook/wmt21-dense-24-wide-en-x) supports the following target languages: Hausa (ha), Icelandic (is), Japanese (ja), Czech (cs), Russian (ru), Chinese (zh), German (de). | |
| - Uploads the translated dataset to Hugging Face. | |
| - At the moment, this runs on CPU only and is therefore very slow (over a minute per item, depending on string lengths).""" | |
| datasets_desc = """## π Dataset Types: | |
| - **mix**: | |
| - `prompt`: List of dictionaries with 'content' and 'role' fields (multi-turn conversation). | |
| - `chosen`: List containing a single dictionary with 'content' and 'role' fields. | |
| - `rejected`: List containing a single dictionary with 'content' and 'role' fields. | |
| - **ufb_cached**: | |
| - `prompt`: String (user input). | |
| - `chosen`: List of dictionaries with 'content' and 'role' fields. | |
| - `rejected`: List of dictionaries with 'content' and 'role' fields. | |
| - **ufb**: | |
| - Like ufb_cached, but without the check for already-translated strings. | |
| ## Backend: | |
| Translation runs locally in this Space with CTranslate2; the Hugging Face Hub API is used to download the model and upload the translated dataset.""" | |
| # Define the theme | |
| theme = gr.themes.Soft(text_size="lg", spacing_size="lg") | |
| with gr.Blocks(theme=theme) as demo: | |
| gr.HTML(f"""<h1 align="center" id="space-title">{gradio_title}</h1>""") | |
| gr.Markdown(gradio_desc) | |
| with gr.Row(variant="panel"): | |
| gr.Markdown(value="## π Login to Hugging Face"), | |
| gr.LoginButton(min_width=380) | |
| gr.Markdown(value="π¨ **This is needed to upload the resulting dataset.**") | |
| with gr.Row(equal_height=False): | |
| with gr.Column(): | |
| dataset_url = gr.Textbox(label="Input Dataset URL", lines=2, placeholder="https://huggingface.co/datasets/alvarobartt/dpo-mix-7k-simplified/resolve/main/data/train-00000-of-00001.parquet?download=true") | |
| model_type = gr.Dropdown(choices=["mix", "ufb_cached", "ufb"], label="Dataset Type") | |
| output_dataset_name = gr.Textbox(label="Output Dataset Name", lines=1, placeholder="cstr/translated_datasets") | |
| range_specification = gr.Textbox(label="Range Specification", lines=1, placeholder="e.g., 1-100") | |
| target_language = gr.Dropdown(choices=["ha", "is", "ja", "cs", "ru", "zh", "de"], label="Target Language") # New dropdown for target language | |
| with gr.Column(): | |
| output = gr.Markdown(label="Output") | |
| submit_btn = gr.Button("Translate Dataset", variant="primary") | |
| submit_btn.click(main, inputs=[dataset_url, model_type, output_dataset_name, range_specification, target_language], outputs=output) | |
| gr.Markdown(datasets_desc) | |
| demo.queue(max_size=10).launch(share=True, show_api=True) |