Commit 44efa8c · Daniel Fried committed
Parent(s): 8a85023

fix query encoding and add new examples

Files changed:
- modules/app.py (+27 -1)
- static/index.html (+23 -7)
modules/app.py CHANGED

@@ -2,6 +2,7 @@ import sys
 from typing import List
 import traceback
 import os
+import base64
 # needs to be imported *before* transformers
 if os.path.exists('use_normal_tokenizers'):
     import tokenizers
@@ -11,8 +12,10 @@ else:
     import tokenizers_patch
 BIG_MODEL = True
 CUDA = True
+import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import json
+import pprint
 
 # from flask import Flask, request, render_template
 # from flask_cors import CORS
@@ -32,8 +35,14 @@ TRUNCATION_MESSAGE = f'warning: This demo is limited to {MAX_LENGTH} tokens in t
 
 if BIG_MODEL:
     model_name = "facebook/incoder-6B"
+    kwargs = dict(
+        revision="float16",
+        torch_dtype=torch.float16,
+        low_cpu_mem_usage=True,
+    )
 else:
     model_name = "facebook/incoder-1B"
+    kwargs = dict()
 
 from fastapi import FastAPI, Request
 from fastapi.staticfiles import StaticFiles
@@ -43,7 +52,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 
 
 print("loading model")
-model = AutoModelForCausalLM.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
 print("loading tokenizer")
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 print("loading complete")
@@ -154,9 +163,18 @@ def index() -> FileResponse:
     return FileResponse(path="static/index.html", media_type="text/html")
 
 @app.get('/generate')
+# async def generate_maybe(request: Request):
 async def generate_maybe(info: str):
     # form = await info.json()
+    # form = await request.json()
+    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
+    # fix padding, following https://stackoverflow.com/a/9956217/1319683
+    print(info)
+    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
+    print(info)
     form = json.loads(info)
+    pprint.pprint(form)
+    # print(form)
     prompt = form['prompt']
     length_limit = int(form['length'])
     temperature = float(form['temperature'])
@@ -174,9 +192,17 @@ async def generate_maybe(info: str):
         return {'result': 'error', 'type': 'generate', 'prompt': prompt, 'message': f'Error: {e}.'}
 
 @app.get('/infill')
+# async def infill_maybe(request: Request):
 async def infill_maybe(info: str):
     # form = await info.json()
+    # form = await request.json()
+    # info is a base64-encoded, url-escaped json string (since GET doesn't support a body, and POST leads to CORS issues)
+    # fix padding, following https://stackoverflow.com/a/9956217/1319683
+    print(info)
+    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
+    print(info)
     form = json.loads(info)
+    pprint.pprint(form)
     length_limit = int(form['length'])
     temperature = float(form['temperature'])
     max_retries = 1
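A note on the new loading path: for the 6B model, the added kwargs pull the repository's float16 weight branch and keep it in half precision, which roughly halves the download size and the memory footprint. A minimal sketch of the same call pattern with explanatory comments; the model names are the ones in the diff, and the kwargs are standard transformers API:

# Sketch of the fp16 loading path added in this commit.
import torch
from transformers import AutoModelForCausalLM

BIG_MODEL = True

if BIG_MODEL:
    model_name = "facebook/incoder-6B"
    kwargs = dict(
        revision="float16",         # fetch the half-precision weight branch of the repo
        torch_dtype=torch.float16,  # keep weights in fp16 rather than upcasting to fp32
        low_cpu_mem_usage=True,     # load the state dict incrementally to cut peak host RAM
    )
else:
    model_name = "facebook/incoder-1B"
    kwargs = dict()  # the 1B model is small enough to load with the defaults

model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
print(model.dtype)  # expect torch.float16 when BIG_MODEL is True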
static/index.html CHANGED

@@ -134,6 +134,7 @@ label {
 <span class="softspan">Infill Examples:</span>
 <br>
 <span class="softspan"><a href='javascript:select_example("type-pred");'>Type prediction</a></span>
+<span class="softspan"><a href='javascript:select_example("multi-region");'>Multi-region</a></span>
 <span class="softspan"><a href='javascript:select_example("docstring");'>Function to docstring</a></span>
 <span class="softspan"><a href='javascript:select_example("python-infill2");'>Docstring to function</a></span>
 <span class="softspan"><a href='javascript:select_example("class");'>Class generation</a></span>
@@ -252,12 +253,20 @@ def <infill>
         "temperature": 0.2,
         "mode": "python"
     },
-
+    "multi-region": {
+        "prompt":
+`<| file ext=.py |>
+<infill>
+""" Load the given gzip jsonl file. """
+<infill>
+`,
+        "length": 64,
+        "temperature": 0.2,
+        "mode": "python"
+    },
     "type-pred": {
         "prompt":
-
-
-def count_words(filename: str) -> <infill>
+`def count_words(filename: str) -> <infill>
     """Count the number of occurrences of each word in the file."""
     with open(filename, 'r') as f:
         word_counts = {}
@@ -310,7 +319,7 @@ def count_words(filename):
         "mode": "python"
     },
     "javascript": {
-        "prompt": "
+        "prompt": "// fetch from the given URL and load the response contents into a new div",
         "length": 64,
         "temperature": 0.6,
         "mode": "javascript"
@@ -529,6 +538,7 @@ function make_generate_listener(url) {
            console.log("Response:");
            console.log(receive_data);
            if (receive_data["result"] == "success") {
+               console.log("success");
                // $("#prompt").text(data["prompt"]);
                // $("#response").text(data["text"]);
                set_text(receive_data["text"]);
@@ -540,6 +550,7 @@ function make_generate_listener(url) {
                    $("#warning").text("");
                }
            } else {
+               console.log("error");
                set_text(receive_data["text"])
                $("#error").text(receive_data["message"]);
            }
@@ -552,13 +563,18 @@ function make_generate_listener(url) {
            $("#error").text(err);
        }
 
-       encoded_data = JSON.stringify(send_data)
+       encoded_data = encodeURIComponent(btoa(JSON.stringify(send_data)))
 
        try {
            const response = await fetch(`${url}?info=${encoded_data}`);
+           // const response = await fetch(`${url}` {
+           // method: 'GET',
+           // body: encoded_data,
+           // });
            if (response.status >= 400) {
                error(response.statusText);
-
+               console.log("here");
+               console.log(response.status);
            } else {
                response.json().then(success).catch(error).finally(complete);
            }
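The heart of the "fix query encoding" change spans both files: the page now sends info as URL-escaped base64 (encodeURIComponent(btoa(JSON.stringify(send_data)))) because GET has no body and POST ran into CORS issues, and app.py reverses it with a padding repair before json.loads. A minimal Python round-trip sketch; pack and unpack are illustrative names, not code from the repo, with pack standing in for the browser-side btoa/encodeURIComponent:

# Round-trip sketch of the query encoding introduced by this commit.
# pack() imitates the page's encodeURIComponent(btoa(JSON.stringify(...)));
# unpack() mirrors the decode in app.py, including the '=' padding repair.
import base64
import json
import urllib.parse

def pack(send_data: dict) -> str:
    b64 = base64.b64encode(json.dumps(send_data).encode('utf-8')).decode('ascii')
    return urllib.parse.quote(b64, safe='')

def unpack(info: str) -> dict:
    # pad to a multiple of 4 chars, as app.py does
    # (following https://stackoverflow.com/a/9956217/1319683)
    info = base64.urlsafe_b64decode(info + '=' * (4 - len(info) % 4)).decode('utf-8')
    return json.loads(info)

send_data = {"prompt": "def f():", "length": 64, "temperature": 0.2}
encoded = pack(send_data)
# FastAPI url-decodes query parameters before the handler sees them:
assert unpack(urllib.parse.unquote(encoded)) == send_data

The padding repair makes the decode tolerant of clients that strip trailing '=' from the base64 string; when the input is already a multiple of four characters, Python's lenient base64 decoder simply ignores the excess padding.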
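Since both endpoints now take everything through the single info query parameter, they can be exercised from any HTTP client, not just the bundled page. A hypothetical call to /generate against a locally running instance; the base URL is an assumption, and the payload keys and response shape are the ones the handlers and page code above use:

# Hypothetical client call to /generate; the base URL is assumed, and the
# encoding matches the pack() sketch above.
import base64
import json
import urllib.parse
import urllib.request

BASE_URL = "http://localhost:7860"  # assumption: wherever the demo server is running

payload = {"prompt": "def hello_world():", "length": 32, "temperature": 0.6}
info = urllib.parse.quote(
    base64.b64encode(json.dumps(payload).encode('utf-8')).decode('ascii'),
    safe='',
)

with urllib.request.urlopen(f"{BASE_URL}/generate?info={info}") as resp:
    print(json.loads(resp.read().decode('utf-8')))  # {'result': ..., 'text': ...} on success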