Spaces:
Runtime error
Runtime error
GPU inference
Browse files- app.py +1 -1
- requirements.txt +2 -1
- utils/generation.py +6 -1
app.py
CHANGED
|
@@ -51,7 +51,7 @@ outro_text ="""
|
|
| 51 |
- [] support FIM task for better model context
|
| 52 |
- [x] include some context for prompt (title, comments before a function) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between argument list and body)
|
| 53 |
- [] gradio examples
|
| 54 |
-
- [] use GPU if available, respect memory restrictions.
|
| 55 |
- [x] stream model generation (maybe in a new window?) - janky solution and only sometimes hangs up
|
| 56 |
- [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG:background is white, so colors are wrong. Shadertoy uses black background (or we ignore alpha).
|
| 57 |
- [] (optional) filtering the dataset by license?
|
|
|
|
| 51 |
- [] support FIM task for better model context
|
| 52 |
- [x] include some context for prompt (title, comments before a function) - now takes all comments directly before a function as well as all comments at the beginning inside a function. (misses comments between argument list and body)
|
| 53 |
- [] gradio examples
|
| 54 |
+
- [x] use GPU if available, respect memory restrictions (implemented via accelerate.Accelerator.device in utils/generation.py), tested with A750 successfully!
|
| 55 |
- [x] stream model generation (maybe in a new window?) - janky solution and only sometimes hangs up
|
| 56 |
- [] 2nd iFrame needs a lot of fixing (I am not a web developer, need help) BUG:background is white, so colors are wrong. Shadertoy uses black background (or we ignore alpha).
|
| 57 |
- [] (optional) filtering the dataset by license?
|
requirements.txt
CHANGED
|
@@ -5,4 +5,5 @@ torch
|
|
| 5 |
pillow
|
| 6 |
gradio
|
| 7 |
jupylet
|
| 8 |
-
tree-sitter
|
|
|
|
|
|
| 5 |
pillow
|
| 6 |
gradio
|
| 7 |
jupylet
|
| 8 |
+
tree-sitter
|
| 9 |
+
accelerate
|
utils/generation.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
from transformers import TextIteratorStreamer
|
| 2 |
from threading import Thread
|
| 3 |
from .tree_utils import full_func_head, grab_before_comments
|
|
@@ -15,17 +16,21 @@ def combine_generation_kwargs(temperature=2.0, max_new_tokens=512, top_p=0.95, r
|
|
| 15 |
|
| 16 |
|
| 17 |
def stream_generation(prompt:str, pipe, gen_kwargs:dict):
|
|
|
|
|
|
|
| 18 |
"""
|
| 19 |
Text generation function
|
| 20 |
Args:
|
| 21 |
prompt (str): The context to start generation from.
|
| 22 |
-
pipe (Pipeline): The pipeline to use for generation
|
| 23 |
gen_kwargs (dict): The generation kwargs.
|
| 24 |
Returns:
|
| 25 |
str: The generated text. (it iterates over time)
|
| 26 |
"""
|
| 27 |
# Tokenize the model_context
|
| 28 |
model_inputs = pipe.tokenizer(prompt, return_tensors="pt")
|
|
|
|
|
|
|
| 29 |
|
| 30 |
# Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
|
| 31 |
# in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
|
|
|
|
| 1 |
+
from accelerate import Accelerator
|
| 2 |
from transformers import TextIteratorStreamer
|
| 3 |
from threading import Thread
|
| 4 |
from .tree_utils import full_func_head, grab_before_comments
|
|
|
|
| 16 |
|
| 17 |
|
| 18 |
def stream_generation(prompt:str, pipe, gen_kwargs:dict):
|
| 19 |
+
accelerator = Accelerator()
|
| 20 |
+
device = accelerator.device
|
| 21 |
"""
|
| 22 |
Text generation function
|
| 23 |
Args:
|
| 24 |
prompt (str): The context to start generation from.
|
| 25 |
+
pipe (Pipeline): The pipeline to use for generation (we take the model and tokenizer from it)
|
| 26 |
gen_kwargs (dict): The generation kwargs.
|
| 27 |
Returns:
|
| 28 |
str: The generated text. (it iterates over time)
|
| 29 |
"""
|
| 30 |
# Tokenize the model_context
|
| 31 |
model_inputs = pipe.tokenizer(prompt, return_tensors="pt")
|
| 32 |
+
model_inputs.to(device)
|
| 33 |
+
model = pipe.model.to(device) #is this also required?
|
| 34 |
|
| 35 |
# Start generation on a separate thread, so that we don't block the UI. The text is pulled from the streamer
|
| 36 |
# in the main thread. Adds timeout to the streamer to handle exceptions in the generation thread.
|