Spaces:
Runtime error
Runtime error
File size: 6,689 Bytes
d154ff2 2fa993f d154ff2 61cba9d 32e5e92 2fa993f 32e5e92 d154ff2 32e5e92 d154ff2 61cba9d 2fa993f d154ff2 32e5e92 2fa993f d154ff2 32e5e92 d154ff2 2fa993f 32e5e92 2fa993f 32e5e92 d154ff2 2fa993f d154ff2 32e5e92 d154ff2 2fa993f d154ff2 2fa993f d154ff2 32e5e92 d154ff2 2fa993f d154ff2 32e5e92 d154ff2 2fa993f d154ff2 32e5e92 d154ff2 32e5e92 d154ff2 0e41fe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
from datasets import load_dataset
import gradio as gr, json, os, random, torch, spaces
from diffusers import StableDiffusionPipeline, AutoencoderKL
from gradio_client import Client
from live_preview_helpers import (
flux_pipe_call_that_returns_an_iterable_of_images as flux_iter,
)
# Device and dtype selection: half precision on CUDA, full precision otherwise.
USE_CUDA = torch.cuda.is_available()
DTYPE = torch.float16 if USE_CUDA else torch.float32
device = torch.device("cuda" if USE_CUDA else "cpu")
# PUBLIC Stable Diffusion pipeline setup.
# The original "runwayml/stable-diffusion-v1-5" Hub repository has been removed,
# so loading from it fails at startup; this is the maintained mirror of the same
# SD 1.5 weights. One constant keeps the pipeline and its VAE on the same repo.
SD_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(
    SD_MODEL_ID, torch_dtype=DTYPE
).to(device)
# Standalone copy of the SD 1.5 VAE; nothing in this file passes it anywhere yet.
good_vae = AutoencoderKL.from_pretrained(
    SD_MODEL_ID, subfolder="vae", torch_dtype=DTYPE
).to(device)
# NOTE(review): this binds a helper that is Flux-specific by its name
# (live_preview_helpers is not visible here) onto a StableDiffusionPipeline.
# It presumably relies on Flux-only pipeline components, so image generation
# should call `pipe(...)` directly. The binding is kept only so existing
# callers of the attribute keep resolving.
pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_iter.__get__(pipe)
# Chat backends, tried in priority order by first_live_space(): the Zephyr chat
# Space (given as a direct *.hf.space URL) first, then the Gemma Space (given as
# a Space id) as the fallback.
_ZEPHYR_CHAT_SPACE = "https://huggingfaceh4-zephyr-chat.hf.space"
_GEMMA_CHAT_SPACE = "huggingface-projects/gemma-2-9b-it"
LLM_SPACES = [_ZEPHYR_CHAT_SPACE, _GEMMA_CHAT_SPACE]
def first_live_space(space_ids, api_name="/chat", probe_args=("ping", 8)):
    """Return a gradio ``Client`` for the first Space that answers a probe call.

    Candidates are tried in order. For each one a ``Client`` is built, passing
    the ``HF_TOKEN`` environment variable (which may be unset), and a probe
    request with ``probe_args`` is sent to ``api_name``. The first candidate
    whose construction and probe both succeed is returned. Any exception marks
    that candidate unusable, and the next one is tried.

    Args:
        space_ids: iterable of Space ids or Space URLs, in priority order.
        api_name: endpoint to probe. Defaults to "/chat", the value that was
            previously hard-coded here.
        probe_args: positional arguments of the probe request. Defaults to the
            previously hard-coded ``("ping", 8)``.

    Returns:
        The first working ``Client``, or ``None`` when no candidate responded
        (including when ``space_ids`` is empty).
    """
    # NOTE(review): call_llm() sends four positional arguments (prompt,
    # max_tokens, temperature, top_p), while the default probe sends two. A
    # successful probe therefore does not prove call_llm()'s request shape is
    # accepted; confirm against each Space's API page.
    hf_token = os.getenv("HF_TOKEN")  # same for every candidate, so read once
    for sid in space_ids:
        try:
            print(f"[info] probing {sid}")
            client = Client(sid, hf_token=hf_token)
            client.predict(*probe_args, api_name=api_name)
        except Exception as e:  # best-effort probe: any failure just skips this Space
            print(f"[warn] {sid} unusable β {e}")
            continue
        print(f"[info] using {sid}")
        return client
    print("[warn] No live chat Space found; falling back to local responses.")
    return None
# Endpoint name that call_llm() targets on whichever Space was selected.
CHAT_API = "/chat"
# Candidates are probed once, at import time; None means no candidate answered.
llm_client = first_live_space(LLM_SPACES)
def call_llm(prompt, max_tokens=256, temperature=0.6, top_p=0.9):
    """Send ``prompt`` to the selected chat Space and return the stripped reply.

    The request goes to ``CHAT_API`` with ``max_tokens``, ``temperature`` and
    ``top_p`` passed positionally after the prompt. If no client was selected
    at import time, or if the request (or stripping its result) raises, the
    failure is only logged and a fixed placeholder message is returned.
    """
    client = llm_client
    if client is not None:
        try:
            reply = client.predict(
                prompt,
                max_tokens,
                temperature,
                top_p,
                api_name=CHAT_API,
            )
            return reply.strip()
        except Exception as exc:
            # Deliberately broad: any remote failure degrades to the placeholder.
            print(f"[error] LLM failure β {exc}")
    # No client, or the request failed: log and hand back the placeholder text.
    print("[warn] Returning local fallback response.")
    return "No LLM API available. Please enter your own text."
# Datasets and prompt templates
# Persona corpus (train split) sampled by random_persona(); loaded once, when the
# module is imported.
ds = load_dataset(path="MohamedRashad/FinePersonas-Lite", split="train")
def random_persona():
    """Return the ``"persona"`` field of one uniformly chosen row of ``ds``."""
    row_index = random.randrange(len(ds))
    return ds[row_index]["persona"]
# Instruction that random_world() sends to the chat LLM to invent a fresh setting.
WORLD_PROMPT = "Invent a short, unique and vivid world description. Respond with the description only."
def random_world():
    """Ask the chat LLM for a world description, using WORLD_PROMPT and max_tokens=120."""
    return call_llm(prompt=WORLD_PROMPT, max_tokens=120)
# Single-character prompt. It is filled via str.format() with
# ``persona_description`` and ``world_description`` and asks for one JSON
# object with a fixed key list and no Markdown.
PROMPT_TEMPLATE = (
    "Generate a character with this persona description:\n"
    "{persona_description}\n"
    "In a world with this description:\n"
    "{world_description}\n"
    "Write the character in JSON with keys:\n"
    "name, background, appearance, personality, skills_and_abilities, goals, conflicts, backstory, current_situation, spoken_lines (list of strings).\n"
    "Respond with JSON only (no markdown)."
)
def _parse_llm_json_object(raw):
    """Return the JSON object embedded in an LLM reply, or None if there is none.

    The whole stripped reply is tried first. If that is not a JSON object, the
    span from the first "{" to the last "}" is tried, which tolerates Markdown
    code fences or a line of preamble around the object. Non-string input
    yields None.
    """
    if not isinstance(raw, str):
        return None
    candidates = [raw.strip()]
    start, end = raw.find("{"), raw.rfind("}")
    if start != -1 and end > start:
        candidates.append(raw[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def generate_character(world_desc, persona_desc, progress=gr.Progress(track_tqdm=True)):
    """Ask the LLM for one character sheet and return it as a dict.

    Args:
        world_desc: world description text inserted into PROMPT_TEMPLATE.
        persona_desc: persona description text inserted into PROMPT_TEMPLATE.
        progress: Gradio progress tracker (not used in the body).

    Returns:
        The JSON object parsed from the reply. When the reply contains no JSON
        object (including call_llm()'s plain-text fallback message), a minimal
        placeholder dict with "name" and "appearance" keys is returned. The
        result is therefore always a dict, which the image step reads with
        ``.get()``.
    """
    raw = call_llm(
        PROMPT_TEMPLATE.format(
            persona_description=persona_desc,
            world_description=world_desc,
        ),
        max_tokens=1024,
    )
    character = _parse_llm_json_object(raw)
    if character is None:
        # Fallback for user input/manual override.
        return {"name": "Unnamed", "appearance": "Manual entry required."}
    return character
# Three-character ensemble prompt. It is filled via str.format() with
# ``world_description`` and ``primary_persona`` and asks for a JSON array of
# protagonist / ally-or-foil / nemesis profiles. The text starts and ends with
# a newline.
CHAIN_PROMPT_TEMPLATE = (
    "\n"
    "You are crafting an interconnected character ensemble for a shared world.\n"
    "[WORLD]:\n"
    "{world_description}\n"
    "[PRIMARY PERSONA]:\n"
    "{primary_persona}\n"
    "Generate 3 interconnected character JSON profiles:\n"
    "1. PROTAGONIST: A compelling lead based on the given persona\n"
    "2. ALLY or FOIL: A character closely linked to the protagonist, either as support or contrast\n"
    "3. NEMESIS: A rival or antagonist with clashing philosophy, history, or goals\n"
    "Each character must include:\n"
    "- name\n"
    "- role (protagonist, ally, nemesis)\n"
    "- appearance\n"
    "- background\n"
    "- personality\n"
    "- shared_history (relation to other character)\n"
    "- goals\n"
    "- conflicts\n"
    "- current_situation\n"
    "- spoken_lines (3 lines of dialogue)\n"
    "Respond with pure JSON array.\n"
)
def _parse_llm_json_array(raw):
    """Return the list of JSON objects embedded in an LLM reply, or None.

    The whole stripped reply is tried first, then the span from the first "["
    to the last "]". The second attempt tolerates Markdown code fences,
    preamble, or an object wrapping the array. A candidate is accepted only if
    it is a non-empty list whose items are all JSON objects. This rejects,
    e.g., a lone ``spoken_lines`` string list picked out of a single object.
    Non-string input yields None.
    """
    if not isinstance(raw, str):
        return None
    candidates = [raw.strip()]
    start, end = raw.find("["), raw.rfind("]")
    if start != -1 and end > start:
        candidates.append(raw[start:end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, list) and parsed and all(isinstance(item, dict) for item in parsed):
            return parsed
    return None


def generate_connected_characters(world_desc, persona_desc, progress=gr.Progress(track_tqdm=True)):
    """Ask the LLM for a protagonist / ally / nemesis trio and return it as a list.

    Args:
        world_desc: world description text inserted into CHAIN_PROMPT_TEMPLATE.
        persona_desc: primary persona text inserted into CHAIN_PROMPT_TEMPLATE.
        progress: Gradio progress tracker (not used in the body).

    Returns:
        The list of character dicts parsed from the reply. When no usable
        array is found (including call_llm()'s plain-text fallback message),
        a placeholder trio with one dict per role is returned.
    """
    raw = call_llm(
        CHAIN_PROMPT_TEMPLATE.format(
            world_description=world_desc,
            primary_persona=persona_desc,
        ),
        max_tokens=2048,
    )
    characters = _parse_llm_json_array(raw)
    if characters is None:
        # Fallback for user input/manual override.
        return [
            {"name": "Unnamed Protagonist", "role": "protagonist"},
            {"name": "Unnamed Ally", "role": "ally"},
            {"name": "Unnamed Nemesis", "role": "nemesis"},
        ]
    return characters
# Gradio UI
# Markdown shown under the page title (it is .strip()-ed before rendering).
# The blank line before the tip keeps Markdown from folding it into the last
# bullet as a lazy continuation line. The dash is a real em dash; it replaces
# mis-decoded bytes that rendered as "β".
DESCRIPTION = """
* Generates a trio of connected character sheets for a world + persona.
* Images via **Stable Diffusion**; story text via Zephyr-chat or Gemma fallback.
* Personas sampled from **FinePersonas-Lite**.

Tip \u2014 Shuffle the world then persona for rapid inspiration.
"""
with gr.Blocks(title="Connected Character Chain Generator", theme="Nymbo/Nymbo_Theme") as demo:
    # Page header and usage notes.
    gr.Markdown("<h1 style='text-align:center'>𧬠Connected Character Chain Generator</h1>")
    gr.Markdown(DESCRIPTION.strip())
    # Inputs: free-text world and persona; the persona is pre-filled once with a random sample.
    with gr.Row():
        world_tb = gr.Textbox(label="World Description", lines=10, scale=4)
        persona_tb = gr.Textbox(
            label="Persona Description", value=random_persona(), lines=10, scale=1
        )
    # Actions.
    with gr.Row():
        btn_world = gr.Button("π Random World", variant="secondary")
        btn_generate = gr.Button("β¨ Generate Character", variant="primary", scale=5)
        btn_persona = gr.Button("π Random Persona", variant="secondary")
        btn_chain = gr.Button("𧬠Chain Characters", variant="secondary")
    # Outputs.
    with gr.Row():
        img_out = gr.Image(label="Character Image")
        json_out = gr.JSON(label="Character Description")
        chained_out = gr.JSON(label="Connected Characters (Protagonist, Ally, Nemesis)")

    # `spaces` is already imported. The decorator requests a GPU for the
    # duration of the call on ZeroGPU hardware, and Hugging Face documents it
    # as having no effect elsewhere.
    @spaces.GPU
    def sd_image_from_character(character):
        """Render a 512x512 portrait for ``character`` with the SD pipeline.

        ``character`` is the value of ``json_out``, presumably the dict
        returned by generate_character. Its "appearance" entry becomes the
        prompt. A missing or empty appearance, or a non-dict value (e.g. None
        or a list), falls back to a generic portrait prompt. Returns the first
        PIL image produced by ``pipe``.
        """
        default_prompt = "A unique portrait, digital art, fantasy character, 4k"
        appearance = character.get("appearance") if isinstance(character, dict) else None
        # The LLM may emit appearance as a nested object; the pipeline needs text.
        prompt = str(appearance) if appearance else default_prompt
        # `pipe` is a StableDiffusionPipeline. The bound
        # flux_pipe_call_that_returns_an_iterable_of_images helper is Flux-specific
        # by name, so call the pipeline's own __call__ instead.
        result = pipe(
            prompt=prompt,
            guidance_scale=7.5,
            num_inference_steps=25,
            width=512,
            height=512,
            output_type="pil",
        )
        return result.images[0]

    # Single character: the JSON sheet first, then a portrait rendered from it.
    btn_generate.click(
        generate_character, [world_tb, persona_tb], [json_out]
    ).then(
        sd_image_from_character, [json_out], [img_out]
    )
    # Connected trio (text only).
    btn_chain.click(
        generate_connected_characters, [world_tb, persona_tb], [chained_out]
    )
    btn_world.click(random_world, outputs=[world_tb])
    btn_persona.click(random_persona, outputs=[persona_tb])

demo.queue().launch(share=True)
|