Update app.py
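Move model loading out of module scope and into the @spaces.GPU-decorated generate_text, so a ZeroGPU Space only attaches a GPU while generation runs; drop the separate gradio_generate wrapper.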
app.py CHANGED
@@ -129,13 +129,15 @@ def load_model(model_path):
     model.to(device)
     return model
 
-#
-model = load_model('gpt_model.pth')
+# Don't load the model here
+# model = load_model('gpt_model.pth')
 enc = tiktoken.get_encoding('gpt2')
 
 # Update the generate_text function
-@spaces.GPU(duration=60)
+@spaces.GPU(duration=60)
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
+    # Load the model inside the GPU-decorated function
+    model = load_model('gpt_model.pth')
     device = next(model.parameters()).device
     input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
     generated = []
@@ -159,18 +161,11 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
         if next_token.item() == enc.encode('\n')[0] and len(generated) > 100:
             break
 
-        await asyncio.sleep(0.02)
+        await asyncio.sleep(0.02)
 
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
 
-# Update the gradio_generate function
-@spaces.GPU(duration=60)  # Adjust duration as needed
-async def gradio_generate(prompt, max_length, temperature, top_k):
-    output = ""
-    async for token in generate_text(prompt, max_length, temperature, top_k):
-        output += token
-        yield output
 
 # # Your existing imports and model code here...
 
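For context, here is a minimal sketch of the pattern this commit adopts: on ZeroGPU Spaces the GPU is attached to the process only while a @spaces.GPU-decorated function runs, so the checkpoint is loaded inside the decorated generator instead of at module scope. The sampling loop, the gradio_stream accumulator, and the Gradio wiring below are illustrative assumptions, not the file's actual code (the real sampling body and UI setup live in the unchanged parts of app.py); load_model and gpt_model.pth are the ones defined earlier in the file.

import asyncio

import gradio as gr
import spaces
import tiktoken
import torch

enc = tiktoken.get_encoding('gpt2')

@spaces.GPU(duration=60)
async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
    # Load inside the decorated function: on ZeroGPU the GPU exists only
    # while this call runs, so a module-level load_model() that moves
    # weights to CUDA would fail at import time.
    model = load_model('gpt_model.pth')  # load_model is defined earlier in app.py
    device = next(model.parameters()).device
    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).to(device)
    for _ in range(max_length):
        with torch.no_grad():
            # Assumes forward() returns logits of shape (batch, seq, vocab).
            logits = model(input_ids)[:, -1, :] / temperature
        # Top-k filtering: sample only among the k most likely tokens.
        top_vals, top_idx = torch.topk(logits, top_k)
        probs = torch.softmax(top_vals, dim=-1)
        next_token = top_idx.gather(-1, torch.multinomial(probs, 1))
        input_ids = torch.cat([input_ids, next_token], dim=1)
        yield enc.decode([next_token.item()])
        await asyncio.sleep(0.02)  # yield control so partial output can flush

async def gradio_stream(prompt, max_length, temperature, top_k):
    # Hypothetical accumulator (the commit removed the old gradio_generate):
    # each yield must be the full text so far for a streaming Textbox.
    output = ""
    async for token in generate_text(prompt, int(max_length), temperature, int(top_k)):
        output += token
        yield output

demo = gr.Interface(
    fn=gradio_stream,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Slider(1, 432, value=432, step=1, label="Max length"),
        gr.Slider(0.1, 2.0, value=0.8, label="Temperature"),
        gr.Slider(1, 200, value=40, step=1, label="Top-k"),
    ],
    outputs=gr.Textbox(label="Generated text"),
)

if __name__ == '__main__':
    demo.launch()

Reloading the checkpoint on every call adds latency per request; a common compromise is to build the model once on CPU at import time and only move it to CUDA inside the decorated function.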