apjanco committed on
Commit
14b8ec7
·
1 Parent(s): 0f90cd2

add code for langgraph llm using gradio client

Browse files
Files changed (1) hide show
  1. gradio_client.ipynb +159 -0
gradio_client.ipynb ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "e7319ea5",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "from typing import List, Optional\n",
11
+ "from langchain.callbacks.manager import CallbackManagerForLLMRun\n",
12
+ "from pydantic import Field\n",
13
+ "from gradio_client import Client, handle_file\n",
14
+ "\n",
15
+ "# Use local BaseChatModel if available, else fallback to langchain_core\n",
16
+ "try:\n",
17
+ " from BaseChatModel import BaseChatModel\n",
18
+ "except ImportError:\n",
19
+ " from langchain_core.language_models.chat_models import BaseChatModel\n",
20
+ "\n",
21
+ "try:\n",
22
+ " from langchain_core.messages.base import BaseMessage\n",
23
+ "except ImportError:\n",
24
+ " from langchain.schema import BaseMessage\n",
25
+ "\n",
26
+ "try:\n",
27
+ " from langchain_core.messages import AIMessage\n",
28
+ "except ImportError:\n",
29
+ " from langchain.schema import AIMessage\n",
30
+ "\n",
31
+ "try:\n",
32
+ " from langchain_core.outputs import ChatResult\n",
33
+ "except ImportError:\n",
34
+ " from langchain.schema import ChatResult\n",
35
+ "\n",
36
+ "try:\n",
37
+ " from langchain_core.outputs import ChatGeneration\n",
38
+ "except ImportError:\n",
39
+ " from langchain.schema import ChatGeneration\n",
40
+ "\n",
41
+ "\n",
42
+ "class GradioChatModel(BaseChatModel):\n",
43
+ " client: Client = Field(default=None, description=\"Gradio client for API communication\")\n",
44
+ "\n",
45
+ " def __init__(self, client: Client = None, **kwargs):\n",
46
+ " super().__init__(**kwargs)\n",
47
+ " if client is None:\n",
48
+ " client = Client(\"apjanco/fantastic-futures\")\n",
49
+ " object.__setattr__(self, 'client', client)\n",
50
+ "\n",
51
+ " @property\n",
52
+ " def _llm_type(self) -> str:\n",
53
+ " return \"gradio_chat_model\"\n",
54
+ "\n",
55
+ " def _generate(\n",
56
+ " self,\n",
57
+ " messages: List[BaseMessage],\n",
58
+ " stop: Optional[List[str]] = None,\n",
59
+ " run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
60
+ " ) -> ChatResult:\n",
61
+ " # Use the first message as prompt, and optionally extract image url if present\n",
62
+ " prompt = None\n",
63
+ " image_url = None\n",
64
+ " for msg in messages:\n",
65
+ " if hasattr(msg, \"content\") and msg.content:\n",
66
+ " if prompt is None:\n",
67
+ " prompt = msg.content\n",
68
+ " # Optionally, look for an image url in the message metadata or content\n",
69
+ " if hasattr(msg, \"image\") and msg.image:\n",
70
+ " image_url = msg.image\n",
71
+ " if prompt is None:\n",
72
+ " prompt = \"Hello!!\"\n",
73
+ " if image_url is None:\n",
74
+ " # fallback image\n",
75
+ " image_url = 'https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png'\n",
76
+ "\n",
77
+ " image_file = handle_file(image_url)\n",
78
+ " response = self.client.predict(\n",
79
+ " image=image_file,\n",
80
+ " model_id='nanonets/Nanonets-OCR-s',\n",
81
+ " prompt=prompt,\n",
82
+ " api_name=\"/run_example\"\n",
83
+ " )\n",
84
+ " # The response may be a string or dict; wrap as AIMessage\n",
85
+ " if isinstance(response, dict) and \"message\" in response:\n",
86
+ " content = response[\"message\"]\n",
87
+ " else:\n",
88
+ " content = str(response)\n",
89
+ " message = AIMessage(content=content)\n",
90
+ " # Wrap the AIMessage in a ChatGeneration object\n",
91
+ " chat_generation = ChatGeneration(message=message)\n",
92
+ " return ChatResult(generations=[chat_generation])"
93
+ ]
94
+ },
95
+ {
96
+ "cell_type": "code",
97
+ "execution_count": 2,
98
+ "id": "e9a50bd3",
99
+ "metadata": {},
100
+ "outputs": [
101
+ {
102
+ "name": "stdout",
103
+ "output_type": "stream",
104
+ "text": [
105
+ "Loaded as API: https://apjanco-fantastic-futures.hf.space ✔\n",
106
+ "content='This is an icon of a bus.'\n"
107
+ ]
108
+ }
109
+ ],
110
+ "source": [
111
+ "from langchain.schema import HumanMessage\n",
112
+ "\n",
113
+ "# Create a HumanMessage with content and image attribute\n",
114
+ "class HumanMessageWithImage(HumanMessage):\n",
115
+ " def __init__(self, content, image=None, **kwargs):\n",
116
+ " super().__init__(content=content, **kwargs)\n",
117
+ " self.image = image\n",
118
+ "\n",
119
+ "custom_llm = GradioChatModel()\n",
120
+ "image_url = \"https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png\"\n",
121
+ "prompt = \"what is this?\"\n",
122
+ "msg = HumanMessageWithImage(content=prompt, image=image_url)\n",
123
+ "\n",
124
+ "# Call invoke with a list of messages\n",
125
+ "result = custom_llm.invoke([msg])\n",
126
+ "print(result)"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": null,
132
+ "id": "920ba4af",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": []
136
+ }
137
+ ],
138
+ "metadata": {
139
+ "kernelspec": {
140
+ "display_name": "base",
141
+ "language": "python",
142
+ "name": "python3"
143
+ },
144
+ "language_info": {
145
+ "codemirror_mode": {
146
+ "name": "ipython",
147
+ "version": 3
148
+ },
149
+ "file_extension": ".py",
150
+ "mimetype": "text/x-python",
151
+ "name": "python",
152
+ "nbconvert_exporter": "python",
153
+ "pygments_lexer": "ipython3",
154
+ "version": "3.10.14"
155
+ }
156
+ },
157
+ "nbformat": 4,
158
+ "nbformat_minor": 5
159
+ }