sparkleman committed · Commit 271e92e · Parent(s): 6706d54

FIX: cpu fallback

app.py CHANGED
@@ -6,11 +6,6 @@ from snowflake import SnowflakeGenerator
 
 CompletionIdGenerator = SnowflakeGenerator(42, timestamp=1741101491595)
 
-from pynvml import *
-
-nvmlInit()
-gpu_h = nvmlDeviceGetHandleByIndex(0)
-
 from typing import List, Optional, Union
 from pydantic import BaseModel, Field
 from pydantic_settings import BaseSettings
@@ -40,6 +35,17 @@ import numpy as np
 import torch
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+if device == "cpu" and CONFIG.STRATEGY != "cpu":
+    logger.info(f"Cuda not found, fall back to cpu")
+    CONFIG.STRATEGY = "cpu"
+
+if "cuda" in CONFIG.STRATEGY:
+    from pynvml import *
+
+    nvmlInit()
+    gpu_h = nvmlDeviceGetHandleByIndex(0)
+
 torch.backends.cudnn.benchmark = True
 torch.backends.cudnn.allow_tf32 = True
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -520,19 +526,17 @@ async def chatResponseStream(
     yield "data: [DONE]\n\n"
 
 
-
-
-
 @app.post("/api/v1/chat/completions")
 async def chat_completions(request: ChatCompletionRequest):
     completionId = str(next(CompletionIdGenerator))
     logger.info(f"[REQ] {completionId} - {request.model_dump()}")
 
-    def chatResponseStreamDisconnect():
-        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
-        logger.info(
-            f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
-        )
+    def chatResponseStreamDisconnect():
+        if "cuda" in CONFIG.STRATEGY:
+            gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
+            logger.info(
+                f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
+            )
 
     model_state = None
 
@@ -545,7 +549,6 @@ async def chat_completions(request: ChatCompletionRequest):
     else:
         r = await chatResponse(request, model_state, completionId)
 
-
     return r
 
 
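The core of the fix is visible in the second hunk: NVML setup moves from unconditional module import time to behind a guard that first coerces CONFIG.STRATEGY to "cpu" when CUDA is unavailable, so pynvml is only imported and initialized on a GPU host. A minimal self-contained sketch of the same pattern (the Config class, logger wiring, and STRATEGY value below are stand-ins for this app's own settings object, not its actual code):

import logging

import torch

logger = logging.getLogger(__name__)


class Config:
    # Stand-in for the app's pydantic-settings CONFIG object.
    STRATEGY: str = "cuda fp16"


CONFIG = Config()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# device.type is a plain string ("cpu" or "cuda"), so this check does not
# depend on how torch.device equality treats strings.
if device.type == "cpu" and CONFIG.STRATEGY != "cpu":
    logger.info("CUDA not found, falling back to cpu")
    CONFIG.STRATEGY = "cpu"

gpu_h = None
if "cuda" in CONFIG.STRATEGY:
    # Import and initialize NVML only when a GPU strategy is active.
    from pynvml import nvmlDeviceGetHandleByIndex, nvmlInit

    nvmlInit()
    gpu_h = nvmlDeviceGetHandleByIndex(0)

One detail worth noting: the committed code compares the torch.device object directly against the string "cpu"; checking device.type, as in the sketch, sidesteps any subtlety in how that equality behaves across torch versions.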
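The second half of the fix applies the same guard inside the request handler, so the disconnect callback never touches gpu_h on the cpu path (where it was never created). A hedged sketch of that shape, reusing the stand-in CONFIG, logger, and gpu_h from the sketch above; note that nvmlDeviceGetMemoryInfo reports total/used/free in bytes:

def chatResponseStreamDisconnect():
    # On the cpu fallback path gpu_h does not exist, so the NVML query
    # must stay behind the same strategy check as the initialization.
    if "cuda" in CONFIG.STRATEGY:
        from pynvml import nvmlDeviceGetMemoryInfo

        gpu_info = nvmlDeviceGetMemoryInfo(gpu_h)
        logger.info(
            f"[STATUS] vram {gpu_info.total} used {gpu_info.used} free {gpu_info.free}"
        )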