Commit 6fe9315 (verified) by Jensin · 1 Parent(s): 04f4a68

Update app.py

Files changed (1):
  1. app.py +141 -77

app.py CHANGED
@@ -1,73 +1,66 @@
- import gradio as gr
  from huggingface_hub import InferenceClient
- from typing import List, Tuple, Iterator
- from huggingface_hub.utils import HfHubHTTPError

- # Custom HuggingFace Inference endpoint
- MODEL_ID = Dushyant4342/ft-llama3-8b-credit-analyst
- # Context window constraints (Llama-3-8B supports ~4096 tokens)
- CONTEXT_WINDOW = 4096
- RESERVED_TOKENS = 512  # reserve space for response
- MAX_HISTORY_ENTRIES = 10  # cap history length to prevent context overflow

- # Default system prompt
  DEFAULT_SYSTEM = (
      "You are an expert credit analyst. Your role is to analyze a customer's "
      "credit data and generate a concise summary of the most important "
      "positive and negative changes."
  )

- def respond(
-     user_message: str,
-     history: List[Tuple[str, str]],
-     system_message: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
- ) -> Iterator[str]:
-     """
-     Builds a chat history payload and streams back the assistant response.
-     Caps history length, instantiates a fresh InferenceClient per request, and
-     ensures max_tokens is clipped to context window limits.
-     Yields streamed token deltas or an error message if the call fails.
-     """
-     # Initialize a new client for this request (avoids lock contention)
-     client = InferenceClient(base_url=ENDPOINT)
-
-     # Strip system_message once
-     sys_content = system_message.strip()
-
-     # Trim history to the most recent entries
-     trimmed = history[-MAX_HISTORY_ENTRIES:]
-
-     # Build messages list, starting with system prompt
-     messages = [{
-         "role": "system",
-         "content": sys_content if sys_content else DEFAULT_SYSTEM
-     }]
-
-     # Append trimmed, non-empty history entries
-     for usr, bot in trimmed:
-         usr_text = usr.strip()
-         bot_text = bot.strip()
-         if usr_text:
-             messages.append({"role": "user", "content": usr_text})
-         if bot_text:
-             messages.append({"role": "assistant", "content": bot_text})
-
-     # Append current user message
-     um_text = user_message.strip()
-     if um_text:
-         messages.append({"role": "user", "content": um_text})
-
-     # Clip max_tokens to fit within context
-     allowed = max(0, CONTEXT_WINDOW - RESERVED_TOKENS)
      max_tok = min(max_tokens, allowed)

-     # Stream generation without shared locks
      try:
          for chunk in client.chat_completion(
              messages=messages,
              max_tokens=max_tok,
              temperature=temperature,
@@ -77,30 +70,101 @@ def respond(
              delta = chunk.choices[0].delta.get("content", "")
              if delta:
                  yield delta
-     except HfHubHTTPError as e:
          yield f"[Error] Inference request failed: {e}"


- # Gradio Chat UI setup
-
- demo = gr.ChatInterface(
-     fn=respond,
-     title="Credit Analyst Bot",
-     description="Ask about customer credit profile changes.",
-     additional_inputs=[
-         gr.Textbox(value=DEFAULT_SYSTEM, label="System message"),
-         gr.Slider(minimum=1, maximum=4096, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.9,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
-     type='messages'
  )

  if __name__ == "__main__":
-     demo.launch()

+ # File: app.py
+ import os
+
+ # Attempt Gradio import; disable UI if ssl is unavailable
+ try:
+     import gradio as gr
+     import ssl  # noqa: F401
+     USE_GRADIO = True
+ except ModuleNotFoundError as e:
+     if 'ssl' in str(e):
+         USE_GRADIO = False
+         print("Warning: ssl module unavailable; Gradio UI disabled.")
+     else:
+         raise
+
  from huggingface_hub import InferenceClient
+ import requests  # for HTTP error handling

+ MODEL_ID = "Dushyant4342/ft-llama3-8b-credit-analyst"
+ ENDPOINT = f"https://api-inference.huggingface.co/models/{MODEL_ID}"
+ CONTEXT_WINDOW = 4096        # model context size
+ RESERVED_TOKENS = 512        # space reserved for generation
+ MAX_HISTORY_ENTRIES = 10     # context truncation length

  DEFAULT_SYSTEM = (
      "You are an expert credit analyst. Your role is to analyze a customer's "
      "credit data and generate a concise summary of the most important "
      "positive and negative changes."
  )

+ _client = None
+
+ def get_client():
+     """Singleton InferenceClient to reduce instantiation overhead."""
+     global _client
+     if _client is None:
+         _client = InferenceClient(base_url=ENDPOINT)
+     return _client
+
+
+ def respond(user_message, history, system_message, max_tokens, temperature, top_p):
+     client = get_client()

+     # Build system + history + user messages
+     sys_content = system_message.strip() or DEFAULT_SYSTEM
+     messages = [{"role": "system", "content": sys_content}]
+     for usr, bot in history[-MAX_HISTORY_ENTRIES:]:
+         if usr.strip(): messages.append({"role": "user", "content": usr.strip()})
+         if bot.strip(): messages.append({"role": "assistant", "content": bot.strip()})
+     if user_message.strip():
+         messages.append({"role": "user", "content": user_message.strip()})
+
+     # Token budget guard
+     allowed = CONTEXT_WINDOW - RESERVED_TOKENS
      max_tok = min(max_tokens, allowed)
+     if max_tok <= 0:
+         yield "[Error] Token budget exhausted."
+         return

+     # Stream response, catch network errors
      try:
          for chunk in client.chat_completion(
+             model=MODEL_ID,
              messages=messages,
              max_tokens=max_tok,
              temperature=temperature,

              delta = chunk.choices[0].delta.get("content", "")
              if delta:
                  yield delta
+     except requests.exceptions.RequestException as e:
          yield f"[Error] Inference request failed: {e}"

+ if USE_GRADIO:
+     demo = gr.ChatInterface(
+         fn=respond,
+         title="Credit Analyst Bot",
+         description="Ask about customer credit profile changes.",
+         additional_inputs=[
+             gr.Textbox(value=DEFAULT_SYSTEM, label="System message"),
+             gr.Slider(1, CONTEXT_WINDOW, value=512, step=1, label="Max new tokens"),
+             gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
+             gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
+         ],
+         type="messages"
+     )
+
+     if __name__ == "__main__":
+         demo.launch()
+ else:
+     if __name__ == "__main__":
+         print("Gradio UI disabled. Use local_inference.py for direct calls.")
+
+
+ # File: local_inference.py
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ MODEL_ID = "Dushyant4342/ft-llama3-8b-credit-analyst"

+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(
+     MODEL_ID,
+     torch_dtype=torch.bfloat16,
+     device_map="auto"
  )
+ model.eval()
+
+ def summarize_credit(customer_data: str, user_command: str,
+                      max_new_tokens=128, temperature=0.6, top_p=0.9):
+     """Return a concise credit summary given structured data and a user command."""
+     system_prompt = DEFAULT_SYSTEM
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": f"{user_command}\n\n--- DATA ---\n{customer_data}"}
+     ]
+     prompt = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(
+             **inputs,
+             max_new_tokens=max_new_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_p=top_p,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+     # Decode only the generated part
+     gen = outputs[0][inputs["input_ids"].shape[-1]:]
+     return tokenizer.decode(gen, skip_special_tokens=True)
+
+
+ # File: tests/test_credit_analyst.py
+ import unittest
+ from local_inference import summarize_credit
+
+ class TestCreditAnalystSummarization(unittest.TestCase):
+     def test_basic_output_type(self):
+         data = (
+             "--- Credit Profile Report ---\n"
+             "Risk Score: 600 (was 650)"
+         )
+         cmd = "Summarize changes in one sentence."
+         output = summarize_credit(data, cmd, max_new_tokens=32, temperature=0.0, top_p=1.0)
+         self.assertIsInstance(output, str)
+         self.assertTrue(len(output) > 0)
+
+     def test_empty_data(self):
+         data = ""
+         cmd = "Summarize changes."
+         output = summarize_credit(data, cmd, max_new_tokens=16, temperature=0.0, top_p=1.0)
+         self.assertIsInstance(output, str)
+
+     def test_token_budget_exhaustion(self):
+         # Simulate a scenario where max_tokens <= RESERVED_TOKENS
+         # This uses the respond() logic; here we simply ensure summarize_credit doesn't error
+         data = "--- Credit Profile Report ---"
+         cmd = "Summarize."
+         # Pass a very low max_new_tokens to test generate with zero budget
+         output = summarize_credit(data, cmd, max_new_tokens=0, temperature=0.0, top_p=1.0)
+         self.assertIsInstance(output, str)

  if __name__ == "__main__":
+     unittest.main()
+
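
For reference, a minimal sketch of driving the new streaming respond() handler outside the Gradio UI. It assumes the "# File:" sections above are split into the separate modules the markers suggest (so that "from app import respond" does not also pull in the local-inference and test code) and that the Inference endpoint is reachable; the question text and sampling values are illustrative only.

# Illustrative usage sketch (assumption: app.py contains only the app section above).
from app import respond

history = []  # no prior turns
question = "Risk Score fell from 650 to 600. Summarize the key changes."

# respond() is a generator: it yields token deltas as they stream from the endpoint,
# so the full reply is assembled piece by piece.
reply = ""
for delta in respond(
    user_message=question,
    history=history,
    system_message="",   # empty string falls back to DEFAULT_SYSTEM
    max_tokens=256,
    temperature=0.7,
    top_p=0.9,
):
    reply += delta
    print(delta, end="", flush=True)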