# app.py  ── run with  DISABLE_INFERENCE=1  to skip model calls
import os

# ----------------──── 1. Optional Gradio UI  ──────────────────────────
try:
    import gradio as gr
    USE_GRADIO = True
except ImportError:
    USE_GRADIO = False
    print("Gradio not installed β†’ UI disabled.")

# ----------------──── 2. Inference toggle  ────────────────────────────
DISABLE_INFERENCE = os.getenv("DISABLE_INFERENCE", "0") == "1"

if not DISABLE_INFERENCE:
    try:
        from huggingface_hub import InferenceClient
    except ImportError as err:
        raise RuntimeError(
            "huggingface_hub missing -- install or set DISABLE_INFERENCE=1"
        ) from err

# ----------------──── 3. Config  ──────────────────────────────────────
MODEL_ID          = "Dushyant4342/ft-llama3-8b-credit-analyst"
CONTEXT_WINDOW    = 4096
RESERVED_TOKENS   = 512
MAX_HISTORY       = 10
DEFAULT_SYSTEM = (
    "You are an expert credit analyst. Summarise key positive and negative "
    "changes in a customer's credit profile."
)

_client: "InferenceClient|None" = None          # type hint for clarity


def _get_client():
    """Lazy-initialise and cache the module-level ``InferenceClient``.

    Returns the same client instance on every call after the first, so
    connection setup happens at most once.  Must only be called when
    ``DISABLE_INFERENCE`` is off — otherwise ``InferenceClient`` was
    never imported and this raises ``NameError``.
    """
    global _client
    if _client is None:
        # BUG FIX: InferenceClient's keyword is `model`, not `repo_id` —
        # the original call raised TypeError on the first live request.
        # The auth token is picked up from the HF_TOKEN / HF_HUB_TOKEN
        # environment variables automatically.
        _client = InferenceClient(model=MODEL_ID)
    return _client


# ----------------──── 4. Chat handler  ────────────────────────────────
def respond(user_msg, history, system_msg, max_tokens, temperature, top_p):
    """Gradio streaming callback — transparently stubs when inference is off.

    Parameters
    ----------
    user_msg : str
        The new user message.
    history : list
        Chat history.  With ``type="messages"`` (how the ChatInterface
        below is configured) this is a list of ``{"role", "content"}``
        dicts; legacy ``(user, assistant)`` tuples are also accepted.
    system_msg : str
        System prompt; falls back to ``DEFAULT_SYSTEM`` when blank.
    max_tokens, temperature, top_p
        Sampling controls forwarded to the model.

    Yields
    ------
    str
        Streamed response fragments, or a single ``[Error] ...`` string.
    """
    # ── Build the prompt ────────────────────────────────────────────────
    sys_prompt = system_msg.strip() or DEFAULT_SYSTEM
    msgs = [{"role": "system", "content": sys_prompt}]

    # BUG FIX: the ChatInterface is created with type="messages", so
    # `history` holds role/content dicts, not (user, assistant) pairs.
    # The original tuple-unpacking would have injected the literal
    # strings "role" / "content" into the prompt.  Accept both shapes.
    for item in history[-MAX_HISTORY:]:
        if isinstance(item, dict):
            content = (item.get("content") or "").strip()
            if content:
                msgs.append({"role": item.get("role", "user"), "content": content})
        else:
            # Legacy tuple format; assistant slot may be None mid-turn.
            u, a = item
            if u and u.strip():
                msgs.append({"role": "user", "content": u.strip()})
            if a and a.strip():
                msgs.append({"role": "assistant", "content": a.strip()})
    if user_msg.strip():
        msgs.append({"role": "user", "content": user_msg.strip()})

    # ── Token budget guard ──────────────────────────────────────────────
    budget = min(max_tokens, CONTEXT_WINDOW - RESERVED_TOKENS)
    if budget <= 0:
        yield "[Error] token budget exhausted."
        return

    # ── Stub path (no inference) ────────────────────────────────────────
    if DISABLE_INFERENCE:
        yield f"(stub) echo: {user_msg}"
        return

    # ── Live inference path ─────────────────────────────────────────────
    client = _get_client()
    try:
        for chunk in client.chat_completion(
            model=MODEL_ID,
            messages=msgs,
            max_tokens=budget,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            delta = chunk.choices[0].delta
            # Stream deltas are dicts in older hub versions and dataclass
            # objects in newer ones — handle both.
            if isinstance(delta, dict):
                text = delta.get("content", "")
            else:
                text = getattr(delta, "content", "") or ""
            if text:
                yield text
    except Exception as err:  # broad on purpose: surface connection + auth errors in the UI
        yield f"[Error] inference failed: {err}"

# ----------------──── 5. Optional Gradio UI  ──────────────────────────
if USE_GRADIO:
    # NOTE: `demo` is deliberately module-level — Hugging Face Spaces
    # discovers a Gradio app by looking for a top-level `demo` object,
    # so it must exist even when this module is merely imported.
    demo = gr.ChatInterface(
        fn=respond,
        title="Credit Analyst Bot (stub-ready)",
        description=(
            "Set <code>DISABLE_INFERENCE=1</code> to work offline. "
            "Otherwise the app will call the hosted model."
        ),
        # Extra widgets whose values are passed positionally to respond()
        # after (user_msg, history): system_msg, max_tokens, temperature, top_p.
        additional_inputs=[
            gr.Textbox(value=DEFAULT_SYSTEM, label="System message"),
            gr.Slider(1, CONTEXT_WINDOW, value=512, step=1, label="Max new tokens"),
            gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p"),
        ],
        # "messages" mode: history is a list of {"role", "content"} dicts.
        type="messages",
    )

    # Launch only when run as a script; Spaces serves `demo` itself.
    if __name__ == "__main__":
        demo.launch(show_error=True)
else:
    # Gradio missing entirely — nothing to serve, just say so when run directly.
    if __name__ == "__main__":
        print("Gradio UI disabled.")