Spaces:
Runtime error
Runtime error
| import sys | |
| from typing import List | |
| import traceback | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import json | |
| # from flask import Flask, request, render_template | |
| # from flask_cors import CORS | |
| # app = Flask(__name__, static_folder='static') | |
| # app.config['TEMPLATES_AUTO_RELOAD'] = True | |
| # CORS(app, resources= { | |
| # r"/generate": {"origins": origins}, | |
| # r"/infill": {"origins": origins}, | |
| # }) | |
| # origins=[f"http://localhost:{PORT}", "https://huggingface.co", "https://hf.space"] | |
# Runtime configuration flags.
CUDA: bool = True      # run the model in half precision on the GPU and move inputs there
PORT: int = 7860       # port passed to the ``__main__`` entry point
VERBOSE: bool = False  # print prompts, retries and intermediate results to stdout
| from fastapi import FastAPI, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse, StreamingResponse | |
# The interactive API documentation pages are switched off; the front-end
# assets are served from the local ``static`` directory under ``/static``.
app = FastAPI(redoc_url=None, docs_url=None)
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Load the InCoder-6B model and tokenizer at import time.
print("loading model")
if CUDA:
    import torch
    # Materialize the weights directly in float16. Loading the 6B-parameter
    # checkpoint in the default float32 and calling .half() afterwards needs
    # roughly twice the host memory for the parameters (4 vs 2 bytes each).
    model = AutoModelForCausalLM.from_pretrained("facebook/incoder-6B", torch_dtype=torch.float16)
else:
    model = AutoModelForCausalLM.from_pretrained("facebook/incoder-6B")
print("loading tokenizer")
tokenizer = AutoTokenizer.from_pretrained("facebook/incoder-6B")
print("loading complete")
if CUDA:
    # .half() is a no-op after the float16 load but keeps the intent explicit.
    model = model.half().cuda()
# Special-token strings used to build and parse infilling prompts.
BOS = "<|endoftext|>"  # stripped from the front of decoded generations
EOM = "<|endofmask|>"  # end-of-mask marker that terminates each generated span


def make_sentinel(i):
    """Return the text of mask sentinel number *i*, e.g. ``<|mask:0|>``."""
    return "<|mask:{}|>".format(i)


# Sentinels 0..255 followed by the end-of-mask marker.
SPECIAL_TOKENS = [*map(make_sentinel, range(256)), EOM]
def generate(input, length_limit=None, temperature=None):
    """Sample a continuation of *input* and return the decoded text.

    Uses nucleus sampling (``top_p=0.95``). The returned string is the decoded
    model output (prompt followed by the generated tokens) with a leading BOS
    stripped. ``infill`` relies on the result starting with the exact prompt
    when it slices off ``len(prompt)`` characters.

    Args:
        input: Prompt text; may contain mask sentinels and EOM markers.
        length_limit: Forwarded as ``max_length`` to ``model.generate``
            (in transformers this limit counts the prompt tokens too).
        temperature: Sampling temperature forwarded to ``model.generate``.

    Returns:
        The decoded text without a leading BOS token.
    """
    input_ids = tokenizer(input, return_tensors="pt").input_ids
    if CUDA:
        input_ids = input_ids.cuda()
    output = model.generate(input_ids=input_ids, do_sample=True, top_p=0.95, temperature=temperature, max_length=length_limit)
    # Disable tokenization clean-up. By default decode() deletes spaces before
    # punctuation (e.g. "from . import" -> "from. import"), which corrupts
    # generated code and means the decoded text no longer starts with the
    # exact prompt string.
    detok_hypo_str = tokenizer.decode(output.flatten(), clean_up_tokenization_spaces=False)
    if detok_hypo_str.startswith(BOS):
        detok_hypo_str = detok_hypo_str[len(BOS):]
    return detok_hypo_str
def _build_infill_prompt(parts: List[str], extra_sentinel: bool) -> str:
    """Interleave *parts* with mask sentinels to form the initial infill prompt.

    Part ``i`` is followed by ``make_sentinel(i)``. The sentinel after the last
    part is emitted only when *extra_sentinel* is true. A single part is
    returned unchanged, with no sentinel at all.
    """
    if len(parts) == 1:
        return parts[0]
    last_ix = len(parts) - 1
    pieces = []
    for sentinel_ix, part in enumerate(parts):
        pieces.append(part)
        if extra_sentinel or sentinel_ix < last_ix:
            pieces.append(make_sentinel(sentinel_ix))
    return "".join(pieces)


def _cut_at_eom(completion: str):
    """Split generated text at its first EOM marker.

    Returns ``(infilled, consumed, found)``:
      - ``infilled``: the span text without the marker.
      - ``consumed``: the text to append to the running prompt (always ends
        with EOM).
      - ``found``: whether the model actually emitted EOM.
    """
    eom_ix = completion.find(EOM)
    if eom_ix < 0:
        # The model never closed the span. Keep everything it produced and
        # close it ourselves so the running prompt stays well-formed.
        return completion, completion + EOM, False
    return completion[:eom_ix], completion[:eom_ix + len(EOM)], True


def infill(parts: List[str], length_limit=None, temperature=None, extra_sentinel=False, max_retries=1):
    """Fill the gaps between consecutive *parts* using the model, left to right.

    The gap between ``parts[i]`` and ``parts[i + 1]`` is generated by appending
    ``make_sentinel(i)`` to the prompt and keeping the text up to the first
    EOM. If any span ends without EOM, the whole attempt is repeated, up to
    *max_retries* attempts in total. The last attempt is returned regardless
    of whether it succeeded.

    Args:
        parts: Text fragments surrounding the gaps; a non-empty list.
        length_limit: Forwarded to ``generate`` for every span.
        temperature: Sampling temperature forwarded to ``generate``.
        extra_sentinel: Also place a sentinel after the final part in the
            initial prompt.
        max_retries: Maximum number of attempts; must be at least 1.

    Returns:
        A dict with these keys:
          - ``text``: parts and infills re-stitched together.
          - ``parts``: the input list.
          - ``infills``: one string per gap.
          - ``retries_attempted``: the number of attempts made.

    Raises:
        TypeError: If *parts* is not a list.
        ValueError: If *parts* is empty or *max_retries* is less than 1.
    """
    # Explicit checks instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(parts, list):
        raise TypeError(f"parts must be a list, got {type(parts).__name__}")
    if not parts:
        raise ValueError("parts must contain at least one element")
    if max_retries < 1:
        # With zero attempts no text would ever be produced.
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    retries_attempted = 0
    done = False
    while (not done) and (retries_attempted < max_retries):
        retries_attempted += 1
        if VERBOSE:
            print(f"retry {retries_attempted}")
        prompt = _build_infill_prompt(parts, extra_sentinel)
        infills = []
        complete = []
        done = True
        for sentinel_ix, part in enumerate(parts[:-1]):
            complete.append(part)
            # Re-emit sentinel i to ask the model for the content of gap i.
            prompt += make_sentinel(sentinel_ix)
            # generate() returns the prompt followed by new text; keep the new part.
            completion = generate(prompt, length_limit, temperature)[len(prompt):]
            infilled, consumed, found = _cut_at_eom(completion)
            if not found:
                if VERBOSE:
                    print(f"warning: {EOM} not found")
                done = False
                if retries_attempted < max_retries:
                    # This attempt will be discarded by the retry anyway, so
                    # skip generating its remaining spans.
                    break
            infills.append(infilled)
            complete.append(infilled)
            prompt += consumed
        complete.append(parts[-1])

    text = ''.join(complete)
    if VERBOSE:
        print("generated text:")
        print(prompt)
        print()
        print("parts:")
        print(parts)
        print()
        print("infills:")
        print(infills)
        print()
        print("restitched text:")
        print(text)
        print()
    return {
        'text': text,
        'parts': parts,
        'infills': infills,
        'retries_attempted': retries_attempted,
    }
def index() -> FileResponse:
    """Serve the front-end page ``static/index.html`` with an HTML media type."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed source; confirm it is registered with ``app`` (e.g. for "/").
    page = "static/index.html"
    return FileResponse(path=page, media_type="text/html")
async def generate_maybe(info: str):
    """Handle a left-to-right generation request.

    *info* is a JSON object string with ``prompt``, ``length`` (int-like) and
    ``temperature`` (float-like) fields.

    Always returns a dict with ``result``, ``type``, ``prompt`` and ``text``:
      - On success, ``result`` is ``'success'`` and ``text`` holds the
        generated text.
      - On any failure, including malformed or incomplete request JSON,
        ``result`` is ``'error'`` and ``text`` holds the error message.

    NOTE(review): no route decorator is visible on this function in the
    reviewed source; confirm it is registered with ``app``.
    """
    prompt = None  # echoed in the error payload; unknown until parsing succeeds
    try:
        # Parse inside the try so a bad request produces the error payload
        # instead of an exception escaping the handler.
        form = json.loads(info)
        prompt = form['prompt']
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
        if VERBOSE:
            print(prompt)
        # NOTE: generate() is a plain synchronous call, so it blocks the event
        # loop for the whole generation.
        generation = generate(prompt, length_limit, temperature)
    except Exception as e:
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'text': f'There was an error: {e}. Tell Daniel.'}
    return {'result': 'success', 'type': 'generate', 'prompt': prompt, 'text': generation}
async def infill_maybe(info: str):
    """Handle an infilling request.

    *info* is a JSON object string with ``parts`` (list of strings),
    ``length`` (int-like) and ``temperature`` (float-like) fields.

    Returns:
      - On success: the dict produced by ``infill`` with ``result='success'``
        and ``type='infill'`` added.
      - On any failure, including malformed or incomplete request JSON:
        ``{'result': 'error', 'type': 'infill', 'text': <message>}``.

    NOTE(review): no route decorator is visible on this function in the
    reviewed source; confirm it is registered with ``app``.
    """
    max_retries = 1
    extra_sentinel = True
    try:
        # Parse inside the try so a bad request produces the error payload
        # instead of an exception escaping the handler.
        form = json.loads(info)
        length_limit = int(form['length'])
        temperature = float(form['temperature'])
        generation = infill(form['parts'], length_limit, temperature, extra_sentinel=extra_sentinel, max_retries=max_retries)
    except Exception as e:
        # The traceback already includes the exception message.
        traceback.print_exception(*sys.exc_info())
        return {'result': 'error', 'type': 'infill', 'text': f'There was an error: {e}.'}
    generation['result'] = 'success'
    generation['type'] = 'infill'
    return generation
if __name__ == "__main__":
    # NOTE(review): FastAPI has no ``run`` method. ``app.run(..., threaded=...)``
    # is the Flask API, left over from the commented-out Flask setup earlier
    # in this file, so running the module directly fails with AttributeError.
    # Serve the app with an ASGI server instead, e.g.
    # ``uvicorn <module>:app --host 0.0.0.0 --port 7860`` — TODO fix.
    app.run(host='0.0.0.0', port=PORT, threaded=False)