from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import StableDiffusionPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
    flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)

# Device and dtype selection
USE_CUDA = torch.cuda.is_available()
DTYPE = torch.float16 if USE_CUDA else torch.float32
device = torch.device("cuda" if USE_CUDA else "cpu")

# PUBLIC Stable Diffusion pipeline setup
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=DTYPE
).to(device)
# Separately loaded VAE, kept available for a higher-quality final decode.
good_vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae", torch_dtype=DTYPE
).to(device)
# Bind the live-preview helper as a method so `self` inside it is the pipeline.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
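# For reference, a plain (non-preview) generation via the standard diffusers API:
#   image = pipe(prompt, num_inference_steps=25, guidance_scale=7.5).images[0]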

# LLM client config (Zephyr or Gemma fallback)
LLM_SPACES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",
    "huggingface-projects/gemma-2-9b-it",
]
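# gradio_client.Client accepts either a full Space URL or a "user/space" id.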

def first_live_space(space_ids):
    for sid in space_ids:
        try:
            print(f"[info] probing {sid}")
            c = Client(sid, hf_token=os.getenv("HF_TOKEN"))
            _ = c.predict("ping", 8, api_name="/chat")
            print(f"[info] using {sid}")
            return c
        except Exception as e:
            print(f"[warn] {sid} unusable β†’ {e}")
    print("[warn] No live chat Space found; falling back to local responses.")
    return None

llm_client = first_live_space(LLM_SPACES)
CHAT_API = "/chat"

def call_llm(prompt, max_tokens=256, temperature=0.6, top_p=0.9):
    if llm_client is not None:
        try:
            return llm_client.predict(
                prompt, max_tokens, temperature, top_p, api_name=CHAT_API
            ).strip()
        except Exception as exc:
            print(f"[error] LLM failure β†’ {exc}")
    # Local fallback if no LLM API available
    print("[warn] Returning local fallback response.")
    return "No LLM API available. Please enter your own text."

# Datasets and prompt templates
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
def random_persona():
    return ds[random.randint(0, len(ds) - 1)]["persona"]

WORLD_PROMPT = (
    "Invent a short, unique and vivid world description. Respond with the description only."
)
def random_world():
    return call_llm(WORLD_PROMPT, max_tokens=120)

PROMPT_TEMPLATE = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON with keys:
name, background, appearance, personality, skills_and_abilities, goals, conflicts, backstory, current_situation, spoken_lines (list of strings).
Respond with JSON only (no markdown)."""

def generate_character(world_desc, persona_desc, progress=gr.Progress(track_tqdm=True)):
    raw = call_llm(
        PROMPT_TEMPLATE.format(
            persona_description=persona_desc,
            world_description=world_desc,
        ),
        max_tokens=1024,
    )
    try:
        return json.loads(raw)
    except Exception:
        # Fallback for user input/manual override
        return {"name": "Unnamed", "appearance": "Manual entry required."}

CHAIN_PROMPT_TEMPLATE = """
You are crafting an interconnected character ensemble for a shared world.

[WORLD]:
{world_description}

[PRIMARY PERSONA]:
{primary_persona}

Generate 3 interconnected character JSON profiles:
1. PROTAGONIST: A compelling lead based on the given persona
2. ALLY or FOIL: A character closely linked to the protagonist, either as support or contrast
3. NEMESIS: A rival or antagonist with clashing philosophy, history, or goals

Each character must include:
- name
- role (protagonist, ally, nemesis)
- appearance
- background
- personality
- shared_history (relationship to the other characters)
- goals
- conflicts
- current_situation
- spoken_lines (3 lines of dialogue)
Respond with a pure JSON array only (no markdown).
"""

def generate_connected_characters(world_desc, persona_desc, progress=gr.Progress(track_tqdm=True)):
    raw = call_llm(
        CHAIN_PROMPT_TEMPLATE.format(
            world_description=world_desc,
            primary_persona=persona_desc
        ),
        max_tokens=2048
    )
    try:
        return json.loads(raw)
    except Exception:
        # Fallback for user input/manual override
        return [{"name": "Unnamed Protagonist", "role": "protagonist"},
                {"name": "Unnamed Ally", "role": "ally"},
                {"name": "Unnamed Nemesis", "role": "nemesis"}]

# Gradio UI
DESCRIPTION = """
* Generates a single character sheet, or a connected trio (protagonist, ally, nemesis), for a world + persona.
* Images via **Stable Diffusion**; story text via Zephyr-chat with a Gemma fallback.
* Personas sampled from **FinePersonas-Lite**.
Tip → Shuffle the world, then the persona, for rapid inspiration.
"""

with gr.Blocks(title="Connected Character Chain Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧬 Connected Character Chain Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())

    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona(), lines=10, scale=1
        )

    with gr.Row():
        btn_world    = gr.Button("🔄 Random World", variant="secondary")
        btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
        btn_persona  = gr.Button("🔄 Random Persona", variant="secondary")
        btn_chain    = gr.Button("🧬 Chain Characters", variant="secondary")

    with gr.Row():
        img_out  = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")
    chained_out = gr.JSON(label="Connected Characters (Protagonist, Ally, Nemesis)")

    def sd_image_from_character(character):
        # Use the character's appearance as the prompt, with a generic fallback.
        prompt = character.get("appearance", "A unique portrait, digital art, fantasy character, 4k")
        # The helper yields intermediate previews and then the finished image,
        # so exhaust the iterator rather than taking only the first (noisy) frame.
        image = None
        for image in pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=25,
            width=512,
            height=512,
            output_type="pil",
        ):
            pass
        return image

    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).then(
        sd_image_from_character, [json_out], [img_out]
    )
    btn_chain.click(
        generate_connected_characters, [world_tb, persona_tb], [chained_out]
    )

    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=True)