# NOTE(review): the "Spaces: Sleeping" lines here were HuggingFace Spaces UI
# residue from a scrape, not part of the program; removed.
import os
import subprocess
import sys
import time
from typing import List, Optional

from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import TextGenerationToArgilla
from distilabel.steps.expand import ExpandColumns
from distilabel.steps.generators.data import LoadDataFromDicts
from distilabel.steps.keep import KeepColumns
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.steps.tasks.self_instruct import SelfInstruct
from dotenv import load_dotenv

from domain import (
    DomainExpert,
    CleanNumberedList,
    create_topics,
    create_examples_template,
    APPLICATION_DESCRIPTION,
)
| load_dotenv() | |
def define_pipeline(
    argilla_api_key: str,
    argilla_api_url: str,
    argilla_dataset_name: str,
    topics: List[str],
    perspectives: List[str],
    domain_expert_prompt: str,
    examples: List[dict],
    hub_token: str,
    endpoint_base_url: str,
):
    """Build the distilabel pipeline for the specific domain.

    Each topic/perspective term is fanned through self-instruct question
    generation, the questions are evolved for complexity, answered by a
    domain-expert task, and the results are pushed to an Argilla dataset.
    Returns the (not yet run) ``Pipeline`` object.
    """
    seed_terms = create_topics(topics, perspectives)
    examples_template = create_examples_template(examples)

    with Pipeline("farming") as pipeline:
        # A single inference-endpoint client shared by every LLM-backed step.
        shared_llm = InferenceEndpointsLLM(
            base_url=endpoint_base_url,
            api_key=hub_token,
        )

        data_loader = LoadDataFromDicts(
            name="load_data",
            data=[{"input": term} for term in seed_terms],
            batch_size=64,
        )
        question_generator = SelfInstruct(
            name="self-instruct",
            application_description=APPLICATION_DESCRIPTION,
            num_instructions=5,
            input_batch_size=8,
            llm=shared_llm,
        )
        question_expander = ExpandColumns(
            name="expand_columns",
            columns={"instructions": "question"},
        )
        list_cleaner = CleanNumberedList(name="clean_numbered_list")
        complexity_evolver = EvolInstruct(
            name="evol_instruction_complexity",
            llm=shared_llm,
            num_evolutions=2,
            store_evolutions=True,
            input_batch_size=8,
            include_original_instruction=True,
            input_mappings={"instruction": "question"},
        )
        evolution_expander = ExpandColumns(
            name="expand_columns_evolved",
            columns={"evolved_instructions": "evolved_questions"},
        )
        expert = DomainExpert(
            name="domain_expert",
            llm=shared_llm,
            input_batch_size=8,
            input_mappings={"instruction": "evolved_questions"},
            output_mappings={"generation": "domain_expert_answer"},
        )
        # Prompt and few-shot template are configured at runtime rather than
        # through the constructor (private attributes of the task).
        expert._system_prompt = domain_expert_prompt
        expert._template = examples_template

        column_filter = KeepColumns(
            name="keep_columns",
            columns=["model_name", "evolved_questions", "domain_expert_answer"],
        )
        argilla_exporter = TextGenerationToArgilla(
            name="text_generation_to_argilla",
            dataset_name=argilla_dataset_name,
            dataset_workspace="admin",
            api_url=argilla_api_url,
            api_key=argilla_api_key,
            input_mappings={
                "instruction": "evolved_questions",
                "generation": "domain_expert_answer",
            },
        )

        # Wire the steps into one linear chain.
        chain = [
            data_loader,
            question_generator,
            question_expander,
            list_cleaner,
            complexity_evolver,
            evolution_expander,
            expert,
            column_filter,
            argilla_exporter,
        ]
        for upstream, downstream in zip(chain, chain[1:]):
            upstream.connect(downstream)

    return pipeline
def serialize_pipeline(
    argilla_api_key: str,
    argilla_api_url: str,
    argilla_dataset_name: str,
    topics: List[str],
    perspectives: List[str],
    domain_expert_prompt: str,
    hub_token: str,
    endpoint_base_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    examples: Optional[List[dict]] = None,
):
    """Serialize the pipeline to a yaml file.

    Builds the pipeline via ``define_pipeline`` and writes it to
    ``pipeline_config_path`` (overwriting any existing file).
    """
    # Mutable-default fix: ``examples`` previously defaulted to a shared
    # ``[]`` literal; use ``None`` as the sentinel and create a fresh list.
    pipeline = define_pipeline(
        argilla_api_key=argilla_api_key,
        argilla_api_url=argilla_api_url,
        argilla_dataset_name=argilla_dataset_name,
        topics=topics,
        perspectives=perspectives,
        domain_expert_prompt=domain_expert_prompt,
        hub_token=hub_token,
        endpoint_base_url=endpoint_base_url,
        examples=[] if examples is None else examples,
    )
    pipeline.save(path=pipeline_config_path, overwrite=True, format="yaml")
def create_pipelines_run_command(
    hub_token: str,
    argilla_api_key: str,
    argilla_api_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    argilla_dataset_name: str = "domain_specific_datasets",
):
    """Assemble the CLI argv that runs the serialized pipeline.

    Returns a list suitable for ``subprocess`` (current interpreter running
    ``python -m distilabel pipeline run``), with runtime secrets and the
    Argilla target injected via ``--param`` overrides.
    """
    # Step-scoped runtime overrides; insertion order mirrors the CLI order.
    param_overrides = {
        "text_generation_to_argilla.dataset_name": argilla_dataset_name,
        "text_generation_to_argilla.api_key": argilla_api_key,
        "text_generation_to_argilla.api_url": argilla_api_url,
        "self-instruct.llm.api_key": hub_token,
        "evol_instruction_complexity.llm.api_key": hub_token,
        "domain_expert.llm.api_key": hub_token,
    }
    command_to_run = [
        sys.executable,
        "-m",
        "distilabel",
        "pipeline",
        "run",
        "--config",
        pipeline_config_path,
    ]
    for key, value in param_overrides.items():
        command_to_run.extend(["--param", f"{key}={value}"])
    command_to_run.append("--ignore-cache")
    return command_to_run
def run_pipeline(
    hub_token: str,
    argilla_api_key: str,
    argilla_api_url: str,
    pipeline_config_path: str = "pipeline.yaml",
    argilla_dataset_name: str = "domain_specific_datasets",
):
    """Run the pipeline in a subprocess and yield its output as decoded log lines.

    Generator of ``str`` lines (stdout and stderr interleaved). The child
    process is reaped when the stream ends or the generator is closed.
    """
    command_to_run = create_pipelines_run_command(
        hub_token=hub_token,
        pipeline_config_path=pipeline_config_path,
        argilla_dataset_name=argilla_dataset_name,
        argilla_api_key=argilla_api_key,
        argilla_api_url=argilla_api_url,
    )
    # Fix: inherit the parent environment instead of replacing it. Passing
    # only HF_TOKEN dropped PATH/HOME/etc. and could break the child
    # interpreter and its library lookups.
    child_env = {**os.environ, "HF_TOKEN": hub_token}
    process = subprocess.Popen(
        args=command_to_run,
        stdout=subprocess.PIPE,
        # Fix: merge stderr into stdout. A separate, never-read stderr PIPE
        # can deadlock the child once the OS pipe buffer fills.
        stderr=subprocess.STDOUT,
        env=child_env,
    )
    try:
        # readline() blocks until data or EOF, so no sleep is needed; EOF
        # (empty bytes) means the child closed its output.
        while process.stdout is not None:
            line = process.stdout.readline()
            if not line:
                break
            yield line.decode("utf-8")
    finally:
        # Reap the child to avoid a zombie, even if the caller abandons
        # the generator early.
        process.wait()