import json
from typing import TYPE_CHECKING, Literal

from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration

if TYPE_CHECKING:
    from distilabel.steps.typing import StepColumns, StepOutput
CHOSEN_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of poor quality. Your task is to regenerate one of high quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
High quality response:
""".rstrip()

CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a high quality response when another assistant has created a poor quality one."
REJECT_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of high quality. Your task is to regenerate one of poor quality.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
Poor quality response:
""".rstrip()

REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when another assistant has created a high quality one."
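# For illustration only: TextGeneration renders the Jinja2 template against the
# "conversation" column, so a short rejected conversation would produce a prompt
# roughly like the following (the turns are made up to show the rendered shape,
# and Jinja2 whitespace handling may differ slightly):
#
#   You are provided with a conversation between a human and an AI assistant.
#   The final message is of poor quality. Your task is to regenerate one of high quality.
#   user: What is the capital of France?
#   assistant: idk, some big city
#   High quality response: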
class FilterConversationRatings(Step):
    """Filters conversations based on the rating of the last message."""

    target_column: Literal["chosen", "rejected"]
    batch_size: int = 5

    def process(self, dataset: StepInput) -> "StepOutput":
        # Map the target column to the rating that marks the turn we want:
        # +1 for high-quality ("chosen") turns, -1 for poor-quality ("rejected") ones.
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]
        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            for row in batch:
                _conversation = row["conversation"]
                conversation = None
                # Truncate the conversation at the first message carrying the
                # target rating (enumerate starts at 1 so the slice includes it).
                for idx, message in enumerate(_conversation, 1):
                    if not isinstance(message["rating"], int):
                        continue
                    if message["rating"] == target_rating:
                        conversation = _conversation[:idx]
                        break
                if conversation:
                    filtered_batch.append({"conversation": conversation})
            yield filtered_batch

    def outputs(self) -> "StepColumns":
        return ["conversation"]
class AppendToConversationStep(Step):
    """Appends a generated message to a conversation."""

    def inputs(self) -> "StepColumns":
        return ["generation", "conversation"]

    def outputs(self) -> "StepColumns":
        return ["generated_conversation", "conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":
        for input in inputs:
            if not input["generation"]:
                continue
            if not input["conversation"]:
                continue
            # Replace the last (rated) message with the newly generated one...
            input["generated_conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in input["conversation"][:-1]
            ] + [{"role": "assistant", "content": input["generation"]}]
            # ...and keep the original conversation, stripped of its ratings.
            input["conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in input["conversation"]
            ]
        yield inputs
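# Sketch of what the step above emits for a single row (names as defined by its
# outputs): given a non-empty "generation" and a conversation ending in a rated
# assistant turn, it produces
#
#   generated_conversation = conversation[:-1] + the new assistant message
#   conversation           = the original turns with the "rating" keys dropped
#
# so the two columns form a preference pair over the same conversation prefix.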
with Pipeline(
    name="conversation_rejection",
    description="Generate a chosen response to a rejected conversation.",
) as rejection_pipeline:
    rejected_dataset = FilterConversationRatings(target_column="rejected")
    chosen_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=CHOSEN_SYSTEM_PROMPT,
        template=CHOSEN_TEMPLATE,
        columns=["conversation"],
    )
    append_chosen = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "chosen",
            "conversation": "rejected",
        },
    )
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )
    rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns
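# In distilabel, `>>` wires steps into the pipeline's DAG: each step's output
# batches are routed as the input batches of the step to its right.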
with Pipeline(
    name="conversation_chosen",
    description="Generate a rejected response to a chosen conversation.",
) as chosen_pipeline:
    chosen_dataset = FilterConversationRatings(target_column="chosen")
    rejected_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=REJECT_SYSTEM_PROMPT,
        template=REJECT_TEMPLATE,
        columns=["conversation"],
    )
    append_rejected = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "rejected",
            "conversation": "chosen",
        },
    )
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )
    chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns
if __name__ == "__main__":
    dataset_path = "example_data.json"
    with open(dataset_path) as f:
        data = json.load(f)
    dataset = Dataset.from_list(data)
    rejected_dataset = rejection_pipeline.run(dataset=dataset, use_cache=False)
    chosen_dataset = chosen_pipeline.run(dataset=dataset, use_cache=False)
    # Each run returns a Distiset; pull the train split out of each and merge them.
    dataset = concatenate_datasets(
        dsets=[rejected_dataset["default"]["train"], chosen_dataset["default"]["train"]]
    )
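# A minimal sketch of the input this script expects: judging from the code above,
# `example_data.json` should hold a list of rows with a "conversation" column
# whose messages each carry a "rating" (1 marks a high-quality assistant turn,
# -1 a poor-quality one, non-integers are skipped). The contents below are
# hypothetical:
#
# [
#   {
#     "conversation": [
#       {"role": "user", "content": "What is 2 + 2?", "rating": None},
#       {"role": "assistant", "content": "4.", "rating": 1}
#     ]
#   }
# ]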