Jensin committed on
Commit d154ff2 · 1 Parent(s): 61cba9d

"modified app.py and added new file"

Files changed (2)
  1. app.py +175 -5
  2. live_preview_helpers.py +37 -0
app.py CHANGED
@@ -1,7 +1,177 @@
- import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ from datasets import load_dataset
+ import gradio as gr, json, os, random, torch, spaces
+ from diffusers import FluxPipeline, AutoencoderKL
+ from gradio_client import Client
+ from live_preview_helpers import (
+     flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
+ )
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ pipe = FluxPipeline.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
+ ).to(device)
+ good_vae = AutoencoderKL.from_pretrained(
+     "black-forest-labs/FLUX.1-dev", subfolder="vae", torch_dtype=torch.bfloat16
+ ).to(device)
+ pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
+
+ LLM_SPACES = [
+     "https://huggingfaceh4-zephyr-chat.hf.space",
+     "huggingface-projects/gemma-2-9b-it",
+ ]
+
+ def first_live_space(space_ids: list[str]) -> Client:
+     for sid in space_ids:
+         try:
+             print(f"[info] probing {sid}")
+             c = Client(sid, hf_token=os.getenv("HF_TOKEN"))
+             _ = c.predict("ping", 8, api_name="/chat")
+             print(f"[info] using {sid}")
+             return c
+         except Exception as e:
+             print(f"[warn] {sid} unusable → {e}")
+     raise RuntimeError("No live chat Space found!")
+
+ llm_client = first_live_space(LLM_SPACES)
+ CHAT_API = "/chat"
+
+ def call_llm(prompt: str, max_tokens: int = 256, temperature: float = 0.6, top_p: float = 0.9) -> str:
+     try:
+         return llm_client.predict(
+             prompt, max_tokens, temperature, top_p, api_name=CHAT_API
+         ).strip()
+     except Exception as exc:
+         print(f"[error] LLM failure → {exc}")
+         return "…"
+
+ # Datasets and prompt templates
+ ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
+ def random_persona() -> str:
+     return ds[random.randint(0, len(ds) - 1)]["persona"]
+
+ WORLD_PROMPT = (
+     "Invent a short, unique and vivid world description. Respond with the description only."
+ )
+ def random_world() -> str:
+     return call_llm(WORLD_PROMPT, max_tokens=120)
+
+ # Standard single character prompt (optional)
+ PROMPT_TEMPLATE = """Generate a character with this persona description:
+ {persona_description}
+ In a world with this description:
+ {world_description}
+ Write the character in JSON with keys:
+ name, background, appearance, personality, skills_and_abilities, goals, conflicts, backstory, current_situation, spoken_lines (list of strings).
+ Respond with JSON only (no markdown)."""
+
+ def generate_character(world_desc: str, persona_desc: str, progress=gr.Progress(track_tqdm=True)):
+     raw = call_llm(
+         PROMPT_TEMPLATE.format(
+             persona_description=persona_desc,
+             world_description=world_desc,
+         ),
+         max_tokens=1024,
+     )
+     try:
+         return json.loads(raw)
+     except json.JSONDecodeError:
+         raw = call_llm(
+             PROMPT_TEMPLATE.format(
+                 persona_description=persona_desc,
+                 world_description=world_desc,
+             ),
+             max_tokens=1024,
+         )
+         return json.loads(raw)
+
+ # Chaining (connected characters)
+ CHAIN_PROMPT_TEMPLATE = """
+ You are crafting an interconnected character ensemble for a shared world.
+
+ [WORLD]:
+ {world_description}
+
+ [PRIMARY PERSONA]:
+ {primary_persona}
+
+ Generate 3 interconnected character JSON profiles:
+ 1. PROTAGONIST: A compelling lead based on the given persona
+ 2. ALLY or FOIL: A character closely linked to the protagonist, either as support or contrast
+ 3. NEMESIS: A rival or antagonist with clashing philosophy, history, or goals
+
+ Each character must include:
+ - name
+ - role (protagonist, ally, nemesis)
+ - appearance
+ - background
+ - personality
+ - shared_history (relation to other character)
+ - goals
+ - conflicts
+ - current_situation
+ - spoken_lines (3 lines of dialogue)
+ Respond with pure JSON array.
+ """
+
+ def generate_connected_characters(world_desc: str, persona_desc: str, progress=gr.Progress(track_tqdm=True)):
+     raw = call_llm(
+         CHAIN_PROMPT_TEMPLATE.format(
+             world_description=world_desc,
+             primary_persona=persona_desc
+         ),
+         max_tokens=2048
+     )
+     try:
+         return json.loads(raw)
+     except json.JSONDecodeError:
+         raw = call_llm(
+             CHAIN_PROMPT_TEMPLATE.format(
+                 world_description=world_desc,
+                 primary_persona=persona_desc
+             ),
+             max_tokens=2048
+         )
+         return json.loads(raw)
+
+ # Gradio UI
+ DESCRIPTION = """
+ * Generates a trio of connected character sheets for a world + persona.
+ * Images via **FLUX-dev**; story text via Zephyr-chat or Gemma fallback.
+ * Personas sampled from **FinePersonas-Lite**.
+ Tip → Shuffle the world then persona for rapid inspiration.
+ """
+
+ with gr.Blocks(title="Connected Character Chain Generator", theme="Nymbo/Nymbo_Theme") as demo:
+     gr.Markdown("<h1 style='text-align:center'>🧬 Connected Character Chain Generator</h1>")
+     gr.Markdown(DESCRIPTION.strip())
+
+     with gr.Row():
+         world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
+         persona_tb = gr.Textbox(
+             label="Persona Description", value=random_persona(), lines=10, scale=1
+         )
+
+     with gr.Row():
+         btn_world = gr.Button("🔄 Random World", variant="secondary")
+         btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
+         btn_persona = gr.Button("🔄 Random Persona", variant="secondary")
+         btn_chain = gr.Button("🧬 Chain Characters", variant="secondary")
+
+     with gr.Row():
+         img_out = gr.Image(label="Character Image")
+         json_out = gr.JSON(label="Character Description")
+         chained_out = gr.JSON(label="Connected Characters (Protagonist, Ally, Nemesis)")
+
+     btn_generate.click(
+         generate_character, [world_tb, persona_tb], [json_out]
+     ).then(
+         lambda character: infer_flux(character), [json_out], [img_out]
+     )
+     btn_chain.click(
+         generate_connected_characters, [world_tb, persona_tb], [chained_out]
+     )
+
+     btn_world.click(random_world, outputs=[world_tb])
+     btn_persona.click(random_persona, outputs=[persona_tb])
+
+ demo.queue().launch(share=False)
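
Note: `infer_flux`, called in the `btn_generate` click chain above, is not defined anywhere in this commit, so the image step will fail with a `NameError` at runtime. Below is a minimal sketch of what such a helper might look like, assuming it lives in app.py, builds a prompt from the generated character JSON, and streams frames from the `flux_pipe_call_that_returns_an_iterable_of_images` helper bound to `pipe` above. The function name comes from the click handler; the prompt construction and seeding are assumptions, not part of the commit.

def infer_flux(character: dict):
    # Hypothetical sketch, not part of this commit. Relies on `pipe`, `good_vae`,
    # `random`, and `torch` already defined/imported at the top of app.py.
    prompt = f"Portrait of {character.get('name', 'a character')}, {character.get('appearance', '')}"
    generator = torch.Generator(device="cpu").manual_seed(random.randint(0, 2**32 - 1))
    # Yielding successive frames lets the gr.Image output update as a live preview.
    for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
        prompt=prompt,
        guidance_scale=3.5,
        num_inference_steps=28,
        width=1024,
        height=1024,
        generator=generator,
        output_type="pil",
        good_vae=good_vae,
    ):
        yield image
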
live_preview_helpers.py ADDED
@@ -0,0 +1,37 @@
+ # live_preview_helpers.py
+
+ import torch
+ from typing import Iterator
+
+ def flux_pipe_call_that_returns_an_iterable_of_images(
+     self,
+     prompt: str,
+     guidance_scale: float = 3.5,
+     num_inference_steps: int = 28,
+     width: int = 1024,
+     height: int = 1024,
+     generator=None,
+     output_type: str = "pil",
+     good_vae=None,
+ ) -> Iterator:
+     """
+     Streams an iterable of images as generated by the FLUX pipeline, for use with Gradio's live preview.
+     """
+     # You may use your actual FLUX pipeline API if different.
+     pipe = self  # usually the FLUX pipeline instance
+
+     if generator is None:
+         generator = torch.Generator(device="cpu").manual_seed(0)
+     images = pipe(
+         prompt=prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+         output_type=output_type,
+         vae=good_vae,
+     ).images
+
+     for img in images:
+         yield img
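
The helper takes `self` as its first parameter even though it is a module-level function: app.py attaches it to the pipeline with `flux_iter.__get__(pipe)`, which uses the descriptor protocol to produce a bound method, so `self` resolves to the pipeline instance. A minimal, self-contained illustration of that binding trick (the class and names here are illustrative only, not from the commit):

class Pipeline:
    name = "flux"

def helper(self, suffix: str) -> str:
    # `self` is whatever instance the function was bound to via __get__.
    return f"{self.name}-{suffix}"

pipe = Pipeline()
pipe.run = helper.__get__(pipe)  # bind the free function as a method of this instance
print(pipe.run("dev"))           # prints "flux-dev"
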