lamhieu commited on
Commit
58a5b8d
·
1 Parent(s): 9604bdd

chore: update something

Browse files
Files changed (1) hide show
  1. lightweight_embeddings/router.py +18 -2
lightweight_embeddings/router.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  from datetime import datetime
6
  from typing import Dict, List, Union
7
 
8
- from fastapi import APIRouter, BackgroundTasks, HTTPException
9
  from pydantic import BaseModel, Field
10
 
11
  from .analytics import Analytics
@@ -117,11 +117,27 @@ analytics = Analytics(
117
 
118
  @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
119
  async def create_embeddings(
120
- request: EmbeddingRequest, background_tasks: BackgroundTasks
 
 
121
  ):
122
  """
123
  Generate embeddings for the given text or image inputs.
124
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  try:
126
  modality = detect_model_kind(request.model)
127
  embeddings = await embeddings_service.generate_embeddings(
 
5
  from datetime import datetime
6
  from typing import Dict, List, Union
7
 
8
+ from fastapi import APIRouter, BackgroundTasks, HTTPException, Header
9
  from pydantic import BaseModel, Field
10
 
11
  from .analytics import Analytics
 
117
 
118
  @router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
119
  async def create_embeddings(
120
+ request: EmbeddingRequest,
121
+ background_tasks: BackgroundTasks,
122
+ authorization: str = Header(None)
123
  ):
124
  """
125
  Generate embeddings for the given text or image inputs.
126
  """
127
+ # Check authorization
128
+ expected_token = os.environ.get("ACCESS_TOKEN")
129
+ if expected_token:
130
+ if not authorization:
131
+ raise HTTPException(status_code=401, detail="Authorization header required")
132
+
133
+ # Support both "Bearer <token>" and plain token formats
134
+ token = authorization
135
+ if authorization.startswith("Bearer "):
136
+ token = authorization[7:] # Remove "Bearer " prefix
137
+
138
+ if token != expected_token:
139
+ raise HTTPException(status_code=401, detail="Invalid authorization token")
140
+
141
  try:
142
  modality = detect_model_kind(request.model)
143
  embeddings = await embeddings_service.generate_embeddings(