# character-chain / app.py
# Last commit by Jensin (d154ff2): "modified app.py and added new file" (6.12 kB)
from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import FluxPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)
# Run on the GPU when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# FLUX.1-dev text-to-image pipeline, loaded in bfloat16 and moved to `device`.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(device)

# The VAE from the same repository's "vae" subfolder, loaded on its own.
# NOTE(review): `good_vae` is not referenced by any code visible in this file;
# presumably the image-generation step is meant to use it — confirm.
good_vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
).to(device)

# Attach the live-preview helper to this pipeline instance. Calling
# `__get__(pipe)` binds the function to `pipe` (assuming `flux_iter` is a plain
# function), so `pipe` is passed as its first argument when it is invoked.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)

# Chat Spaces handed, in this order, to first_live_space(); the first one whose
# probe succeeds serves all text generation. Each entry is passed unchanged to
# gradio_client.Client: the first is a full Space URL, the second a Space id.
LLM_SPACES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",  # Zephyr chat
    "huggingface-projects/gemma-2-9b-it",  # Gemma 2 9B instruct (fallback)
]
def first_live_space(space_ids: list[str], api_name: str = "/chat") -> Client:
    """Return a gradio Client for the first Space in *space_ids* that answers a probe.

    The candidates are tried in order. For each one, a ``Client`` is created with the
    ``HF_TOKEN`` environment variable as its token (``None`` when unset), and a
    short ``"ping"`` request is sent to *api_name*. The first candidate whose probe
    does not raise is returned.

    Args:
        space_ids: Space ids or Space URLs, in priority order.
        api_name: Endpoint used for the probe. It defaults to ``"/chat"``, the same
            endpoint the rest of the app calls.

    Returns:
        The ``Client`` for the first responsive Space.

    Raises:
        RuntimeError: If no candidate answers (including an empty list). The last
            probe failure, if any, is chained as ``__cause__``.
    """
    last_error = None
    for sid in space_ids:
        try:
            print(f"[info] probing {sid}")
            client = Client(sid, hf_token=os.getenv("HF_TOKEN"))
            # NOTE(review): the probe sends two positional values ("ping", 8), while
            # call_llm later sends four (prompt, max_tokens, temperature, top_p).
            # Confirm both match the Space's /chat signature.
            client.predict("ping", 8, api_name=api_name)
        except Exception as exc:  # any failure just means "try the next candidate"
            print(f"[warn] {sid} unusable → {exc}")
            last_error = exc
            continue
        print(f"[info] using {sid}")
        return client
    raise RuntimeError("No live chat Space found!") from last_error
# Endpoint name that call_llm() uses for every chat request.
CHAT_API = "/chat"
# Choose the chat backend once, at import time. first_live_space() raises
# RuntimeError when none of LLM_SPACES responds, which aborts app start-up.
llm_client = first_live_space(LLM_SPACES)
def call_llm(prompt: str, max_tokens: int = 256, temperature: float = 0.6, top_p: float = 0.9) -> str:
    """Send *prompt* to the selected chat Space and return its reply, stripped.

    The four values are passed positionally, in the order prompt, max_tokens,
    temperature, top_p, to ``llm_client.predict`` at the ``CHAT_API`` endpoint.

    Args:
        prompt: Text sent to the model.
        max_tokens: Forwarded as the second positional argument.
        temperature: Forwarded as the third positional argument.
        top_p: Forwarded as the fourth positional argument.

    Returns:
        The reply with surrounding whitespace removed. If anything fails (the
        request itself, or calling ``.strip()`` on the result), the error is
        printed and the placeholder ``"…"`` is returned instead.
    """
    try:
        reply = llm_client.predict(
            prompt, max_tokens, temperature, top_p, api_name=CHAT_API
        )
        return reply.strip()
    except Exception as exc:
        # Best effort: callers receive a placeholder rather than an exception.
        # The JSON-generating callers then fail to parse it and retry.
        print(f"[error] LLM failure → {exc}")
        return "…"
# Datasets and prompt templates
# Persona pool sampled by random_persona(): the "train" split of FinePersonas-Lite.
ds = load_dataset(
    "MohamedRashad/FinePersonas-Lite",
    split="train",
)
def random_persona() -> str:
    """Return the ``"persona"`` field of one uniformly chosen row of ``ds``."""
    row_index = random.randrange(0, len(ds))
    row = ds[row_index]
    return row["persona"]
# Instruction sent by random_world() to get a free-text world description.
WORLD_PROMPT = (
    "Invent a short, unique and vivid world description. "
    "Respond with the description only."
)
def random_world() -> str:
    """Ask the chat model for a new world description.

    Sends ``WORLD_PROMPT`` with ``max_tokens=120``. On failure it yields
    whatever ``call_llm`` returns in that case.
    """
    token_budget = 120
    description = call_llm(WORLD_PROMPT, max_tokens=token_budget)
    return description
# Prompt template for a single standalone character (used by the "Generate Character" button)
# str.format template for one character sheet; fill in {persona_description}
# and {world_description}. Lines are joined with "\n" and there is no trailing newline.
PROMPT_TEMPLATE = "\n".join(
    (
        "Generate a character with this persona description:",
        "{persona_description}",
        "In a world with this description:",
        "{world_description}",
        "Write the character in JSON with keys:",
        "name, background, appearance, personality, skills_and_abilities, goals, conflicts, backstory, current_situation, spoken_lines (list of strings).",
        "Respond with JSON only (no markdown).",
    )
)
def generate_character(world_desc: str, persona_desc: str, progress=gr.Progress(track_tqdm=True)):
    """Generate one character sheet for *persona_desc* set in *world_desc*.

    The chat model is prompted with ``PROMPT_TEMPLATE`` and its reply is decoded
    as JSON. If the reply does not decode, the same request is sent once more.

    Args:
        world_desc: Free-text world description.
        persona_desc: Free-text persona description.
        progress: Gradio progress tracker (Gradio supplies it for event
            handlers). It is not used directly here.

    Returns:
        The decoded JSON value. It is expected to be an object with the keys listed
        in ``PROMPT_TEMPLATE``, but this is not validated.

    Raises:
        json.JSONDecodeError: If neither reply decodes as JSON.
    """
    prompt = PROMPT_TEMPLATE.format(
        persona_description=persona_desc,
        world_description=world_desc,
    )
    max_attempts = 2  # the first request plus one retry
    for attempt in range(1, max_attempts + 1):
        text = call_llm(prompt, max_tokens=1024).strip()
        # Models often wrap JSON in a ```json ... ``` fence despite being asked
        # not to. Unwrap it so the payload itself is parsed.
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
            text = text.rsplit("```", 1)[0].strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            if attempt == max_attempts:
                raise
# Chaining (connected characters)
# str.format template for the connected trio; fill in {world_description} and
# {primary_persona}. Literal braces would need doubling ({{ }}) because of .format.
CHAIN_PROMPT_TEMPLATE = """
You are crafting an interconnected character ensemble for a shared world.
[WORLD]:
{world_description}
[PRIMARY PERSONA]:
{primary_persona}
Generate 3 interconnected character JSON profiles:
1. PROTAGONIST: A compelling lead based on the given persona
2. ALLY or FOIL: A character closely linked to the protagonist, either as support or contrast
3. NEMESIS: A rival or antagonist with clashing philosophy, history, or goals
Each character must include:
- name
- role (protagonist, ally, nemesis)
- appearance
- background
- personality
- shared_history (how this character is connected to the other two characters)
- goals
- conflicts
- current_situation
- spoken_lines (list of 3 strings of dialogue)
Respond with a pure JSON array of the 3 profiles only (no markdown).
"""
def generate_connected_characters(world_desc: str, persona_desc: str, progress=gr.Progress(track_tqdm=True)):
    """Generate a connected protagonist / ally-or-foil / nemesis trio for one world.

    The chat model is prompted with ``CHAIN_PROMPT_TEMPLATE`` and its reply is
    decoded as JSON. If the reply does not decode, the same request is sent once
    more.

    Args:
        world_desc: Free-text world description.
        persona_desc: Persona the protagonist is based on.
        progress: Gradio progress tracker (Gradio supplies it for event
            handlers). It is not used directly here.

    Returns:
        The decoded JSON value. The prompt asks for an array of three character
        objects, but the shape is not validated.

    Raises:
        json.JSONDecodeError: If neither reply decodes as JSON.
    """
    prompt = CHAIN_PROMPT_TEMPLATE.format(
        world_description=world_desc,
        primary_persona=persona_desc,
    )
    max_attempts = 2  # the first request plus one retry
    for attempt in range(1, max_attempts + 1):
        text = call_llm(prompt, max_tokens=2048).strip()
        # Models often wrap JSON in a ```json ... ``` fence despite being asked
        # not to. Unwrap it so the payload itself is parsed.
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
            text = text.rsplit("```", 1)[0].strip()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            if attempt == max_attempts:
                raise
# Gradio UI
# Markdown shown under the page title (rendered with gr.Markdown after .strip()).
# The blank line before "Tip" keeps it out of the last bullet, which it would
# otherwise join as a lazy continuation line.
DESCRIPTION = """
* Generates a single character sheet plus an image, or a trio of connected character sheets (protagonist, ally, nemesis), for a world + persona.
* Images via **FLUX-dev**; story text via Zephyr-chat or Gemma fallback.
* Personas sampled from **FinePersonas-Lite**.

Tip → Shuffle the world, then the persona, for rapid inspiration.
"""
with gr.Blocks(title="Connected Character Chain Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧬 Connected Character Chain Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())

    # Inputs. The world text starts empty; the persona is pre-filled with one
    # random dataset sample, drawn once when the UI is built.
    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona(), lines=10, scale=1
        )

    with gr.Row():
        btn_world = gr.Button("🔄 Random World", variant="secondary")
        btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
        btn_persona = gr.Button("🔄 Random Persona", variant="secondary")
        btn_chain = gr.Button("🧬 Chain Characters", variant="secondary")

    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")
        chained_out = gr.JSON(label="Connected Characters (Protagonist, Ally, Nemesis)")

    # Single character: generate the JSON sheet first, then render an image from it.
    # NOTE(review): `infer_flux` is not defined anywhere in the visible source. The
    # lambda only postpones the lookup, so this step raises NameError at click time.
    # If `infer_flux` turns out to be a generator that streams preview images, pass
    # it directly instead of through a lambda. Gradio presumably decides whether to
    # stream by inspecting the handler function itself — confirm.
    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).then(
        lambda character: infer_flux(character), [json_out], [img_out]
    )
    btn_chain.click(
        generate_connected_characters, [world_tb, persona_tb], [chained_out]
    )
    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=False)