{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e7319ea5",
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Optional\n",
"from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
"from pydantic import Field\n",
"from gradio_client import Client, handle_file\n",
"\n",
"# Use local BaseChatModel if available, else fallback to langchain_core\n",
"try:\n",
" from BaseChatModel import BaseChatModel\n",
"except ImportError:\n",
" from langchain_core.language_models.chat_models import BaseChatModel\n",
"\n",
"try:\n",
" from langchain_core.messages.base import BaseMessage\n",
"except ImportError:\n",
" from langchain.schema import BaseMessage\n",
"\n",
"try:\n",
" from langchain_core.messages import AIMessage\n",
"except ImportError:\n",
" from langchain.schema import AIMessage\n",
"\n",
"try:\n",
" from langchain_core.outputs import ChatResult\n",
"except ImportError:\n",
" from langchain.schema import ChatResult\n",
"\n",
"try:\n",
" from langchain_core.outputs import ChatGeneration\n",
"except ImportError:\n",
" from langchain.schema import ChatGeneration\n",
"\n",
"\n",
"class GradioChatModel(BaseChatModel):\n",
"    \"\"\"LangChain chat model that proxies generation to a Gradio Space API.\"\"\"\n",
"\n",
"    # Gradio client used for the remote /run_example call.\n",
"    client: Client = Field(default=None, description=\"Gradio client for API communication\")\n",
"\n",
"    def __init__(self, client: Client = None, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        if client is None:\n",
"            client = Client(\"apjanco/fantastic-futures\")\n",
"        # Bypass pydantic validation: the gradio Client is not a pydantic-aware type.\n",
"        object.__setattr__(self, 'client', client)\n",
"\n",
"    @property\n",
"    def _llm_type(self) -> str:\n",
"        # Identifier reported to LangChain callbacks/serialization.\n",
"        return \"gradio_chat_model\"\n",
"\n",
"    def _generate(\n",
"        self,\n",
"        messages: List[BaseMessage],\n",
"        stop: Optional[List[str]] = None,\n",
"        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
"        **kwargs,\n",
"    ) -> ChatResult:\n",
"        \"\"\"Send the first message content (and optional image) to the Space.\n",
"\n",
"        The first non-empty ``content`` becomes the prompt; an ``image``\n",
"        attribute on any message supplies the image URL (last one wins).\n",
"        ``stop`` sequences, if given, truncate the reply at the first match.\n",
"        \"\"\"\n",
"        prompt = None\n",
"        image_url = None\n",
"        for msg in messages:\n",
"            if prompt is None and getattr(msg, \"content\", None):\n",
"                prompt = msg.content\n",
"            # FIX: image extraction was nested under the content check, so a\n",
"            # message carrying only an image (empty content) was ignored.\n",
"            if getattr(msg, \"image\", None):\n",
"                image_url = msg.image\n",
"        if prompt is None:\n",
"            prompt = \"Hello!!\"\n",
"        if image_url is None:\n",
"            # fallback image\n",
"            image_url = 'https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'\n",
"\n",
"        image_file = handle_file(image_url)\n",
"        response = self.client.predict(\n",
"            image=image_file,\n",
"            model_id='nanonets/Nanonets-OCR-s',\n",
"            prompt=prompt,\n",
"            api_name=\"/run_example\"\n",
"        )\n",
"        # The response may be a string or dict; wrap as AIMessage\n",
"        if isinstance(response, dict) and \"message\" in response:\n",
"            content = response[\"message\"]\n",
"        else:\n",
"            content = str(response)\n",
"        # FIX: the original accepted ``stop`` but ignored it; honor it by\n",
"        # truncating the reply at the earliest stop sequence found.\n",
"        if stop:\n",
"            for seq in stop:\n",
"                cut = content.find(seq)\n",
"                if cut != -1:\n",
"                    content = content[:cut]\n",
"        message = AIMessage(content=content)\n",
"        # Wrap the AIMessage in a ChatGeneration object\n",
"        chat_generation = ChatGeneration(message=message)\n",
"        return ChatResult(generations=[chat_generation])"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "e9a50bd3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded as API: https://apjanco-fantastic-futures.hf.space ✔\n",
"content='This is an icon of a bus.'\n"
]
}
],
"source": [
"from typing import Optional\n",
"\n",
"from langchain.schema import HumanMessage\n",
"\n",
"\n",
"# A HumanMessage that also carries an image URL for GradioChatModel.\n",
"class HumanMessageWithImage(HumanMessage):\n",
"    \"\"\"HumanMessage with an optional image URL attached.\"\"\"\n",
"\n",
"    # FIX: declare ``image`` as a model field. The original assigned\n",
"    # ``self.image`` inside __init__, which relies on the pydantic base\n",
"    # model permitting undeclared extra attributes and raises otherwise.\n",
"    image: Optional[str] = None\n",
"\n",
"\n",
"custom_llm = GradioChatModel()\n",
"image_url = \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n",
"prompt = \"what is this?\"\n",
"msg = HumanMessageWithImage(content=prompt, image=image_url)\n",
"\n",
"# Call invoke with a list of messages\n",
"result = custom_llm.invoke([msg])\n",
"print(result)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "920ba4af",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}