Update app.py
app.py CHANGED

@@ -121,14 +121,12 @@ class GPT(nn.Module):
 
         return logits, loss
 
-
+@spaces.GPU
 def load_model(model_path):
     config = GPTConfig()
     model = GPT(config)
 
-    checkpoint = torch.load(model_path, map_location=torch.device('
-
-    print("Checkpoint keys:", checkpoint.keys())  # Debug print
+    checkpoint = torch.load(model_path, map_location=torch.device('cuda'))
 
     if 'model_state_dict' in checkpoint:
         model.load_state_dict(checkpoint['model_state_dict'])
@@ -136,24 +134,17 @@ def load_model(model_path):
         model.load_state_dict(checkpoint)
 
     model.eval()
+    model.to('cuda')
     return model
 
 # Load the model
 model = load_model('gpt_model.pth')  # Replace with the actual path to your .pt file
 enc = tiktoken.get_encoding('gpt2')
 
-#
-
-import torch.nn as nn
-from torch.nn import functional as F
-import tiktoken
-import gradio as gr
-
-# [Your existing model code remains unchanged]
-
-# Modify the generate_text function to be asynchronous
+# Update the generate_text function
+@spaces.GPU(duration=60)  # Adjust duration as needed
 async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
-    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0)
+    input_ids = torch.tensor(enc.encode(prompt)).unsqueeze(0).cuda()
     generated = []
 
     with torch.no_grad():
@@ -179,7 +170,9 @@ async def generate_text(prompt, max_length=432, temperature=0.8, top_k=40):
 
     if len(generated) == max_length:
         yield "... (output truncated due to length)"
-
+
+# Update the gradio_generate function
+@spaces.GPU(duration=60)  # Adjust duration as needed
 async def gradio_generate(prompt, max_length, temperature, top_k):
    output = ""
    async for token in generate_text(prompt, max_length, temperature, top_k):
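For reference, the decorators this commit introduces come from Hugging Face's `spaces` package, which backs ZeroGPU Spaces: a function wrapped in `@spaces.GPU` has a CUDA device attached only while the call runs, and `duration=` bounds how long each call may hold it. A minimal sketch of that pattern, assuming a ZeroGPU Space with `spaces` and `torch` installed (the function and tensor below are illustrative, not this app's code):

    import spaces  # Hugging Face ZeroGPU helper package
    import torch

    @spaces.GPU(duration=60)  # request a GPU for up to ~60 s per call
    def double(x: torch.Tensor) -> torch.Tensor:
        # On ZeroGPU hardware, CUDA is only guaranteed to be available
        # inside the decorated call, so move tensors to the device here.
        return (x.to('cuda') * 2).cpu()

The bare `@spaces.GPU` form (as on `load_model`) uses the package's default duration; outside a ZeroGPU environment the decorator is effectively a no-op, so the same code should still run on a dedicated-GPU Space.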