File size: 8,250 Bytes
a4b70d9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
from __future__ import annotations
import json
import re
import os
import requests
import base64
import uuid
from typing import AsyncIterator
try:
from curl_cffi.requests import Session
from curl_cffi import CurlMime
has_curl_cffi = True
except ImportError:
has_curl_cffi = False
from ...base_provider import ProviderModelMixin, AsyncAuthedProvider, AuthResult
from ...helper import format_prompt, format_media_prompt, get_last_user_message
from ....typing import AsyncResult, Messages, Cookies, MediaListType
from ....errors import MissingRequirementsError, MissingAuthError, ResponseError
from ....image import to_bytes
from ....requests import get_args_from_nodriver, DEFAULT_HEADERS
from ....requests.raise_for_status import raise_for_status
from ....providers.response import JsonConversation, ImageResponse, Sources, TitleGeneration, Reasoning, RequestLogin, FinishReason
from ....cookies import get_cookies
from ....tools.media import merge_media
from .models import default_model, default_vision_model, fallback_models, image_models, model_aliases, image_model_aliases
from .... import debug
class Conversation(JsonConversation):
    """Conversation state holding one entry per model.

    ``models`` is a dict keyed by model name; the provider in this file stores
    a ``{"conversationId": ..., "messageId": ...}`` dict under each key.
    """

    def __init__(self, models: dict):
        # The mapping is stored as-is (not copied), so the caller's dict and
        # this attribute refer to the same object.
        self.models = models
class HuggingChat(AsyncAuthedProvider, ProviderModelMixin):
    """Provider for the HuggingChat web endpoints rooted at ``url``.

    Authentication is cookie based: credentials are considered usable as soon
    as an ``hf-chat`` cookie is present. Server-side conversations are created
    per model and remembered on a :class:`Conversation` so that follow-up calls
    for the same model reuse the existing conversation.
    """

    domain = "huggingface.co"
    origin = f"https://{domain}"
    url = f"{origin}/chat"

    working = False
    use_nodriver = True
    supports_stream = True
    needs_auth = True

    default_model = default_model
    default_vision_model = default_vision_model
    model_aliases = {**model_aliases, **image_model_aliases}
    image_models = image_models
    text_models = fallback_models

    @classmethod
    def get_models(cls):
        """Return the known model ids, fetching the live list on first use.

        While ``cls.models`` is empty, the model listing is requested from
        ``{url}/api/v2/models``. On success ``text_models``, ``vision_models``
        and ``models`` (text models followed by ``image_models``) are updated
        together. On any failure the error is logged and ``models`` is set to a
        copy of ``fallback_models``; the other attributes are left untouched.

        Returns:
            The (possibly just populated) ``cls.models`` list.
        """
        if not cls.models:
            try:
                listing = requests.get(f"{cls.url}/api/v2/models").json().get("json")
                # Build everything into locals first so a malformed entry can
                # never leave the class attributes half-updated.
                text_models = [entry["id"] for entry in listing]
                # A missing "multimodal" flag is treated as "not a vision
                # model" rather than discarding the whole fetched listing.
                vision_models = [entry["id"] for entry in listing if entry.get("multimodal")]
                models = text_models + cls.image_models
            except Exception as e:
                debug.error(f"{cls.__name__}: Error reading models: {type(e).__name__}: {e}")
                cls.models = [*fallback_models]
            else:
                cls.text_models = text_models
                cls.vision_models = vision_models
                cls.models = models
        return cls.models

    @classmethod
    async def on_auth_async(cls, cookies: Cookies = None, proxy: str = None, **kwargs) -> AsyncIterator:
        """Yield the credentials to use for subsequent requests.

        Args:
            cookies: Cookies to authenticate with. When ``None`` they are
                looked up with ``get_cookies`` for ``cls.domain``.
            proxy: Proxy forwarded to the nodriver based login flow.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            * An ``AuthResult`` built from the given cookies when they already
              contain ``hf-chat``.
            * Otherwise, if ``needs_auth`` is set: a ``RequestLogin`` (carrying
              the ``G4F_LOGIN_URL`` environment variable or an empty string)
              followed by an ``AuthResult`` built from
              ``get_args_from_nodriver``, which waits for the logout form.
            * Otherwise: an ``AuthResult`` with a freshly generated ``hf-chat``
              cookie.
        """
        if cookies is None:
            cookies = get_cookies(cls.domain, single_browser=True)
        if "hf-chat" in cookies:
            yield AuthResult(
                cookies=cookies,
                headers=DEFAULT_HEADERS,
                impersonate="chrome"
            )
            return
        if cls.needs_auth:
            yield RequestLogin(cls.__name__, os.environ.get("G4F_LOGIN_URL") or "")
            yield AuthResult(
                **await get_args_from_nodriver(
                    cls.url,
                    proxy=proxy,
                    wait_for='form[action$="/logout"]'
                )
            )
        else:
            yield AuthResult(
                cookies={
                    "hf-chat": str(uuid.uuid4())  # Generate a session ID
                },
                headers=DEFAULT_HEADERS,
                impersonate="chrome"
            )

    @classmethod
    async def create_authed(
        cls,
        model: str,
        messages: Messages,
        auth_result: AuthResult,
        prompt: str = None,
        media: MediaListType = None,
        return_conversation: bool = True,
        conversation: Conversation = None,
        web_search: bool = False,
        **kwargs
    ) -> AsyncResult:
        """Send the messages to HuggingChat and stream back the reply.

        Args:
            model: Requested model name. When empty and ``media`` is given,
                ``default_vision_model`` is used. The name is then resolved
                through ``cls.get_model``.
            messages: Chat history. The whole history is formatted into the
                prompt for a new conversation; only the last user message is
                sent when an existing conversation is reused.
            auth_result: Credentials; its ``get_dict()`` is expanded into the
                ``Session`` constructor.
            prompt: Passed to ``format_media_prompt`` to label image results.
            media: Extra media, merged with media found in ``messages`` and
                uploaded as base64 encoded ``files`` parts.
            return_conversation: Yield the ``Conversation`` right after a new
                server-side conversation has been created.
            conversation: Previously yielded conversation to continue. Replaced
                by an empty one when ``None`` or lacking a ``models`` attribute.
            web_search: Forwarded as the ``web_search`` request setting.
            **kwargs: Accepted for interface compatibility; not used here.

        Yields:
            The conversation (optionally), text tokens, ``Reasoning``,
            ``TitleGeneration`` and ``ImageResponse`` items as they arrive, and
            finally any collected ``Sources`` followed by
            ``FinishReason("stop")``.

        Raises:
            MissingRequirementsError: ``curl_cffi`` is not installed.
            RuntimeError: A streamed JSON event has no ``type`` field.
        """
        if not has_curl_cffi:
            raise MissingRequirementsError('Install "curl_cffi" package | pip install -U curl_cffi')
        if not model and media is not None:
            model = cls.default_vision_model
        model = cls.get_model(model)

        session = Session(**auth_result.get_dict())
        response = None
        try:
            if conversation is None or not hasattr(conversation, "models"):
                conversation = Conversation({})

            if model not in conversation.models:
                conversation_id = cls.create_conversation(session, model)
                # Log only a short prefix of the id, not the id itself.
                debug.log(f"Conversation created: {json.dumps(conversation_id[:8] + '...')}")
                message_id = cls.fetch_message_id(session, conversation_id)
                conversation.models[model] = {"conversationId": conversation_id, "messageId": message_id}
                if return_conversation:
                    yield conversation
                inputs = format_prompt(messages)
            else:
                conversation_id = conversation.models[model]["conversationId"]
                # Refresh the id of the latest message before replying to it.
                conversation.models[model]["messageId"] = cls.fetch_message_id(session, conversation_id)
                inputs = get_last_user_message(messages)

            settings = {
                "inputs": inputs,
                "id": conversation.models[model]["messageId"],
                "is_retry": False,
                "is_continue": False,
                "web_search": web_search,
                # NOTE(review): presumably the id of the server-side image
                # generation tool - confirm against the HuggingChat API.
                "tools": ["000000000000000000000001"] if model in cls.image_models else [],
            }
            headers = {
                'accept': '*/*',
                'origin': cls.origin,
                'referer': f'{cls.url}/conversation/{conversation_id}',
            }

            data = CurlMime()
            data.addpart('data', data=json.dumps(settings, separators=(',', ':')))
            for image, filename in merge_media(media, messages):
                data.addpart(
                    "files",
                    filename=f"base64;{filename}",
                    data=base64.b64encode(to_bytes(image))
                )

            response = session.post(
                f'{cls.url}/conversation/{conversation_id}',
                headers=headers,
                multipart=data,
                stream=True
            )
            raise_for_status(response)

            # Web-search sources are held back and emitted just before the
            # finish marker.
            sources = None
            for line in response.iter_lines():
                if not line:
                    continue
                try:
                    line = json.loads(line)
                except json.JSONDecodeError as e:
                    debug.error(f"Failed to decode JSON: {line}, error: {e}")
                    continue
                if "type" not in line:
                    raise RuntimeError(f"Response: {line}")
                elif line["type"] == "stream":
                    yield line["token"].replace('\u0000', '')
                elif line["type"] == "finalAnswer":
                    if sources is not None:
                        yield sources
                    yield FinishReason("stop")
                    break
                elif line["type"] == "file":
                    url = f"{cls.url}/conversation/{conversation_id}/output/{line['sha']}"
                    yield ImageResponse(url, format_media_prompt(messages, prompt), options={"cookies": auth_result.cookies})
                elif line["type"] == "webSearch" and "sources" in line:
                    sources = Sources(line["sources"])
                elif line["type"] == "title":
                    yield TitleGeneration(line["title"])
                elif line["type"] == "reasoning":
                    yield Reasoning(line.get("token"), status=line.get("status"))
        finally:
            # Release the streaming response before the session that owns its
            # connection; this also runs when the consumer abandons the
            # generator early or an error is raised mid-stream.
            try:
                if response is not None:
                    response.close()
            finally:
                session.close()

    @classmethod
    def create_conversation(cls, session: Session, model: str):
        """Create a server-side conversation and return its id.

        Image models are not used directly: the conversation is created with
        ``default_model`` instead.

        Args:
            session: Authenticated HTTP session.
            model: Resolved model name.

        Returns:
            The ``conversationId`` reported by the server.

        Raises:
            MissingAuthError: The server answered with status 401.
            ResponseError: The server answered with status 400, or the
                response carried no ``conversationId``.
        """
        if model in cls.image_models:
            model = cls.default_model
        json_data = {
            'model': model,
        }
        response = session.post(f'{cls.url}/conversation', json=json_data)
        if response.status_code == 401:
            raise MissingAuthError(response.text)
        if response.status_code == 400:
            raise ResponseError(f"{response.text}: Model: {model}")
        raise_for_status(response)
        conversation_id = response.json().get('conversationId')
        if not conversation_id:
            # Fail here with a clear message instead of letting a None id
            # surface later as an unrelated TypeError or a bogus URL.
            raise ResponseError(f"No conversationId in response: {response.text}")
        return conversation_id

    @classmethod
    def fetch_message_id(cls, session: Session, conversation_id: str):
        """Return the id of the last message of a conversation.

        Args:
            session: Authenticated HTTP session.
            conversation_id: Id of the conversation to inspect.

        Returns:
            The ``id`` of the last entry of the conversation's ``messages``
            list, or ``None`` when the body cannot be decoded, does not have
            the expected ``json`` object, or lists no messages.
        """
        response = session.get(
            f"{cls.url}/api/v2/conversations/{conversation_id}"
        )
        raise_for_status(response)
        try:
            payload = response.json()
        except json.JSONDecodeError as e:
            debug.error(f"Failed to decode JSON: {e}")
            return None
        data = payload.get("json") if isinstance(payload, dict) else None
        if not isinstance(data, dict):
            # Same outcome as an undecodable body: report it and signal
            # "no message id" instead of raising a bare KeyError.
            debug.error(f"{cls.__name__}: Unexpected conversation payload, missing 'json' object")
            return None
        messages_data_list = data.get("messages") or []
        return messages_data_list[-1]['id'] if messages_data_list else None