# character-chain / app.py
# Jensin's picture
# Updated app.py
# 0e41fe6
from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import StableDiffusionPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)
# Device and dtype selection: half precision is only chosen when CUDA is present.
USE_CUDA = torch.cuda.is_available()
if USE_CUDA:
    DTYPE = torch.float16
    device = torch.device("cuda")
else:
    DTYPE = torch.float32
    device = torch.device("cpu")

# PUBLIC Stable Diffusion pipeline setup
_SD_REPO_ID = "runwayml/stable-diffusion-v1-5"

pipe = StableDiffusionPipeline.from_pretrained(_SD_REPO_ID, torch_dtype=DTYPE)
pipe = pipe.to(device)

# Stand-alone VAE loaded from the same checkpoint.
# NOTE(review): good_vae is not referenced again in the visible part of this
# file — confirm it is still needed before paying for the extra load.
good_vae = AutoencoderKL.from_pretrained(
    _SD_REPO_ID, subfolder="vae", torch_dtype=DTYPE
)
good_vae = good_vae.to(device)

# Bind the imported helper to this pipeline instance so it can be called as a
# method of `pipe`.
# NOTE(review): the helper's name suggests it was written for a FLUX pipeline —
# confirm it is compatible with StableDiffusionPipeline.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
# LLM client config (Zephyr or Gemma fallback)
# Candidate chat Spaces, in the order they are probed at start-up. The first
# entry is a full Space URL, the second a Space id; both are handed unchanged
# to gradio_client.Client.
LLM_SPACES = [
    "https://huggingfaceh4-zephyr-chat.hf.space",
    "huggingface-projects/gemma-2-9b-it",
]
def first_live_space(space_ids):
    """Return a client for the first chat Space that answers a probe request.

    Every entry of *space_ids* is tried in order: a ``Client`` is created for
    it (authenticated with the ``HF_TOKEN`` environment variable, if set) and
    a small ``"ping"`` prediction is sent to the ``/chat`` endpoint. The first
    candidate whose probe does not raise wins.

    Args:
        space_ids: Iterable of Space ids or URLs accepted by ``Client``.

    Returns:
        The ``Client`` of the first responsive Space, or ``None`` when every
        candidate failed or *space_ids* is empty.
    """
    # The token does not change between candidates, so read it once.
    hf_token = os.getenv("HF_TOKEN")

    for sid in space_ids:
        print(f"[info] probing {sid}")
        try:
            client = Client(sid, hf_token=hf_token)
            # The reply itself is irrelevant; only "did it raise?" matters.
            client.predict("ping", 8, api_name="/chat")
        except Exception as exc:
            # Best-effort probe: any failure simply means "try the next Space".
            print(f"[warn] {sid} unusable → {exc}")
            continue
        print(f"[info] using {sid}")
        return client

    print("[warn] No live chat Space found; falling back to local responses.")
    return None
# Probe the candidate Spaces once, at import time. None means no Space answered.
llm_client = first_live_space(LLM_SPACES)
# Endpoint name passed as api_name on every call_llm() request.
CHAT_API = "/chat"
def call_llm(prompt, max_tokens=256, temperature=0.6, top_p=0.9):
    """Send *prompt* to the connected chat Space and return its stripped reply.

    Args:
        prompt: Text forwarded to the chat endpoint.
        max_tokens: Forwarded positionally as the second endpoint argument.
        temperature: Forwarded positionally as the third endpoint argument.
        top_p: Forwarded positionally as the fourth endpoint argument.

    Returns:
        The reply with surrounding whitespace removed. When no client is
        available, or the request (or the ``.strip()`` on its result) raises,
        a fixed placeholder sentence is returned instead; this function does
        not propagate LLM errors to the caller.
    """
    if llm_client is not None:
        try:
            reply = llm_client.predict(
                prompt, max_tokens, temperature, top_p, api_name=CHAT_API
            )
            return reply.strip()
        except Exception as exc:
            # Deliberately broad: a failing remote call must not break the UI,
            # so log it and fall through to the local placeholder.
            print(f"[error] LLM failure → {exc}")

    # Local fallback if no LLM API available
    print("[warn] Returning local fallback response.")
    return "No LLM API available. Please enter your own text."
# Datasets and prompt templates
# Persona corpus ("train" split), loaded once at import time; random_persona()
# reads the "persona" field of its rows.
ds = load_dataset("MohamedRashad/FinePersonas-Lite", split="train")
def random_persona():
    """Return the ``persona`` field of one randomly selected row of ``ds``."""
    last_index = len(ds) - 1
    row_index = random.randint(0, last_index)
    record = ds[row_index]
    return record["persona"]
# Instruction sent to the chat model each time a new world is requested.
WORLD_PROMPT = (
    "Invent a short, unique and vivid world description. "
    "Respond with the description only."
)


def random_world():
    """Ask the chat model for a world description and return its reply."""
    world_text = call_llm(WORLD_PROMPT, max_tokens=120)
    return world_text
# Prompt for a single character sheet. generate_character() fills the
# {persona_description} and {world_description} placeholders via str.format.
PROMPT_TEMPLATE = """Generate a character with this persona description:
{persona_description}
In a world with this description:
{world_description}
Write the character in JSON with keys:
name, background, appearance, personality, skills_and_abilities, goals, conflicts, backstory, current_situation, spoken_lines (list of strings).
Respond with JSON only (no markdown)."""
def _parse_character_object(raw):
    """Return the JSON object contained in *raw* as a dict, or ``None``."""
    if not isinstance(raw, str):
        return None
    text = raw.strip()
    candidates = [text]
    # Chat models often surround the payload with prose or code fences even
    # when told not to, so also try the outermost {...} span on its own.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            continue
        # Valid JSON that is not an object (list, number, string) is unusable
        # as a character sheet.
        if isinstance(parsed, dict):
            return parsed
    return None


def generate_character(world_desc, persona_desc, progress=gr.Progress(track_tqdm=True)):
    """Generate one character sheet for a persona placed in a world.

    Args:
        world_desc: World description inserted into ``PROMPT_TEMPLATE``.
        persona_desc: Persona description inserted into ``PROMPT_TEMPLATE``.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        The character as a dict parsed from the model's reply. When the reply
        contains no JSON object, a two-key placeholder dict is returned so the
        UI still has something to show and the user can fill it in manually.
    """
    raw = call_llm(
        PROMPT_TEMPLATE.format(
            persona_description=persona_desc,
            world_description=world_desc,
        ),
        max_tokens=1024,
    )
    character = _parse_character_object(raw)
    if character is None:
        # Fallback for user input/manual override
        return {"name": "Unnamed", "appearance": "Manual entry required."}
    return character
# Prompt for the three-character ensemble. generate_connected_characters()
# fills the {world_description} and {primary_persona} placeholders via
# str.format and expects a JSON array back.
CHAIN_PROMPT_TEMPLATE = """
You are crafting an interconnected character ensemble for a shared world.
[WORLD]:
{world_description}
[PRIMARY PERSONA]:
{primary_persona}
Generate 3 interconnected character JSON profiles:
1. PROTAGONIST: A compelling lead based on the given persona
2. ALLY or FOIL: A character closely linked to the protagonist, either as support or contrast
3. NEMESIS: A rival or antagonist with clashing philosophy, history, or goals
Each character must include:
- name
- role (protagonist, ally, nemesis)
- appearance
- background
- personality
- shared_history (relation to other character)
- goals
- conflicts
- current_situation
- spoken_lines (3 lines of dialogue)
Respond with pure JSON array.
"""
def _parse_character_array(raw):
    """Return the JSON array contained in *raw* as a list, or ``None``."""
    if not isinstance(raw, str):
        return None
    text = raw.strip()
    candidates = [text]
    # Chat models often surround the payload with prose, code fences or a
    # wrapper object, so also try the outermost [...] span on its own.
    start, end = text.find("["), text.rfind("]")
    if 0 <= start < end:
        candidates.append(text[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            continue
        # The prompt asks for an array; any other JSON type is unusable here.
        if isinstance(parsed, list):
            return parsed
    return None


def generate_connected_characters(world_desc, persona_desc, progress=gr.Progress(track_tqdm=True)):
    """Generate a protagonist / ally / nemesis trio for a persona in a world.

    Args:
        world_desc: World description inserted into ``CHAIN_PROMPT_TEMPLATE``.
        persona_desc: Persona description inserted as the primary persona.
        progress: Gradio progress tracker; not used directly in the body.

    Returns:
        The list parsed from the model's reply. When the reply contains no
        JSON array, three placeholder entries (one per role) are returned so
        the UI still has something to show.
    """
    raw = call_llm(
        CHAIN_PROMPT_TEMPLATE.format(
            world_description=world_desc,
            primary_persona=persona_desc,
        ),
        max_tokens=2048,
    )
    characters = _parse_character_array(raw)
    if characters is None:
        # Fallback for user input/manual override
        return [
            {"name": "Unnamed Protagonist", "role": "protagonist"},
            {"name": "Unnamed Ally", "role": "ally"},
            {"name": "Unnamed Nemesis", "role": "nemesis"},
        ]
    return characters
# Gradio UI
# Markdown blurb shown on the page; it is stripped before being rendered.
DESCRIPTION = """
* Generates a trio of connected character sheets for a world + persona.
* Images via **Stable Diffusion**; story text via Zephyr-chat or Gemma fallback.
* Personas sampled from **FinePersonas-Lite**.
Tip → Shuffle the world then persona for rapid inspiration.
"""
# NOTE(review): the nesting below was reconstructed from statement order
# (each run of components is placed in the gr.Row() that precedes it) —
# confirm against the intended layout.
with gr.Blocks(title="Connected Character Chain Generator", theme="Nymbo/Nymbo_Theme") as demo:
    gr.Markdown("<h1 style='text-align:center'>🧬 Connected Character Chain Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())

    # Inputs: world on the left (wide), persona on the right (narrow).
    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona(), lines=10, scale=1
        )

    # Actions.
    with gr.Row():
        btn_world = gr.Button("🔄 Random World", variant="secondary")
        btn_generate = gr.Button("✨ Generate Character", variant="primary", scale=5)
        btn_persona = gr.Button("🔄 Random Persona", variant="secondary")
        btn_chain = gr.Button("🧬 Chain Characters", variant="secondary")

    # Outputs.
    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")
        chained_out = gr.JSON(label="Connected Characters (Protagonist, Ally, Nemesis)")

    def sd_image_from_character(character):
        """Render a portrait for *character* from its ``appearance`` text.

        Args:
            character: Current value of the character JSON component. A dict
                is expected, but ``None`` or any other JSON type is tolerated.

        Returns:
            The first image yielded by the pipeline helper, or ``None`` when
            the helper yields nothing.
        """
        default_prompt = "A unique portrait, digital art, fantasy character, 4k"

        # Use appearance or fallback if needed. A missing, blank or non-text
        # appearance (or a non-dict character) falls back to the generic
        # prompt instead of raising inside the event handler.
        appearance = character.get("appearance") if isinstance(character, dict) else None
        if isinstance(appearance, str) and appearance.strip():
            prompt = appearance
        else:
            prompt = default_prompt

        # NOTE(review): only the first yielded image is used, and the helper's
        # name suggests a FLUX live-preview generator — confirm it works with
        # StableDiffusionPipeline and that the first item is the final render.
        frames = pipe.flux_pipe_call_that_returns_an_iterable_of_images(
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=25,
            width=512,
            height=512,
            output_type="pil",
        )
        return next(iter(frames), None)

    # Generate the sheet first, then render the image from the fresh JSON.
    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).then(
        sd_image_from_character, [json_out], [img_out]
    )
    btn_chain.click(
        generate_connected_characters, [world_tb, persona_tb], [chained_out]
    )
    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=True)