Spaces: Running on Zero
| import spaces | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Hub id of the implicit-CoT Mistral-7B checkpoint; the weights and the
# tokenizer are both fetched from this repository.
implicit_cot_model_name = 'yuntian-deng/implicit-cot-math-mistral7b'

# Weights are kept in bfloat16. They are loaded here once at import time;
# predict_answer moves the model onto the active device per request.
implicit_cot_model = AutoModelForCausalLM.from_pretrained(
    implicit_cot_model_name,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(implicit_cot_model_name)

# Cap on newly generated tokens per request. The model emits the answer
# without intermediate reasoning steps, so a small budget is used.
MAX_RESULT_TOKENS = 10
@spaces.GPU
def predict_answer(question):
    """Answer a grade-school math question with the implicit-CoT model.

    Whitespace in the question is normalised and the tokenizer's EOS token is
    appended as the question/answer separator. The model then greedily
    generates at most ``MAX_RESULT_TOKENS`` new tokens.

    Args:
        question: Free-text math question entered by the user.

    Returns:
        ``[(answer_text, None)]``: the ``(token, category)`` list format that
        ``gr.HighlightedText`` consumes. A ``None`` category renders the
        span without a colour.
    """
    # On ZeroGPU hardware a GPU is only attached while a @spaces.GPU function
    # runs; elsewhere this falls back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    implicit_cot_model.to(device)

    # Collapse runs of whitespace, then terminate the question with EOS.
    input_text = ' '.join(question.split()) + ' ' + tokenizer.eos_token
    inputs = tokenizer(input_text, return_tensors='pt').to(device)
    input_ids = inputs['input_ids']

    outputs = implicit_cot_model.generate(
        input_ids=input_ids,
        attention_mask=inputs.get('attention_mask'),
        max_new_tokens=MAX_RESULT_TOKENS,
        do_sample=False,
    )
    # generate() returns prompt + continuation for a causal LM; decode only
    # the newly generated tokens so the question is not echoed back.
    new_tokens = outputs[0, input_ids.shape[-1]:]
    prediction = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    return [(prediction, None)]
# Category -> colour lookup handed to the HighlightedText output component.
color_map = {"correct": "green", "wrong": "red"}

# Pre-filled example question shown in the input box.
_EXAMPLE_QUESTION = "A set of 7 spoons costs $21. If each spoon would be sold separately, how much would 5 spoons cost?"

# Markdown rendered below the interface: papers, code, and announcement links.
_ARTICLE = """
- [Paper 1: Implicit Chain of Thought Reasoning via Knowledge Distillation](https://arxiv.org/pdf/2311.01460)
- [Paper 2: From Explicit CoT to Implicit CoT: Learning to Internalize CoT Step by Step](https://arxiv.org/pdf/2405.14838)
- [Code Repository](https://github.com/da03/Internalize_CoT_Step_by_Step)
- [Tweet Announcement](https://twitter.com/yuntiandeng/status/1795854740879774036)
"""

demo = gr.Interface(
    fn=predict_answer,
    inputs=[
        gr.Textbox(label="Question", value=_EXAMPLE_QUESTION),
    ],
    outputs=[
        gr.HighlightedText(
            label="Implicit CoT Prediction",
            color_map=color_map,
            combine_adjacent=False,
            show_legend=False,
            show_inline_category=False,
        ),
    ],
    # Page text.
    title="Solving Grade School Math Problems with Implicit CoT",
    description="This demo showcases Mistral-7B's ability to solve grade school math problems without producing intermediate steps, using our stepwise internalization method.",
    article=_ARTICLE,
    # Controls: no clear button, a custom submit label, manual submission only.
    clear_btn=None,
    submit_btn="Get Answer!",
    live=False,
    # Process one request at a time.
    concurrency_limit=1,
)

# Queue holds at most 5 pending requests; queue() returns the same Blocks.
demo.queue(max_size=5)
demo.launch()