Spaces:
Running
Running
update to add provider selection
Browse files
app.py
CHANGED
|
@@ -227,10 +227,11 @@ DEMO_LIST = [
|
|
| 227 |
|
| 228 |
# HF Inference Client
|
| 229 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
|
|
|
|
|
|
| 230 |
|
| 231 |
-
def get_inference_client(model_id):
|
| 232 |
-
"""Return an InferenceClient with provider based on model_id."""
|
| 233 |
-
provider = "groq" if model_id == "moonshotai/Kimi-K2-Instruct" else "auto"
|
| 234 |
return InferenceClient(
|
| 235 |
provider=provider,
|
| 236 |
api_key=HF_TOKEN,
|
|
@@ -940,20 +941,24 @@ The HTML code above contains the complete original website structure with all im
|
|
| 940 |
except Exception as e:
|
| 941 |
return f"Error extracting website content: {str(e)}"
|
| 942 |
|
| 943 |
-
def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[History], _current_model: Dict, enable_search: bool = False, language: str = "html"):
|
| 944 |
if query is None:
|
| 945 |
query = ''
|
| 946 |
if _history is None:
|
| 947 |
_history = []
|
| 948 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 949 |
# Check if there's existing HTML content in history to determine if this is a modification request
|
| 950 |
has_existing_html = False
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
last_assistant_msg = _history[-1][1]
|
| 954 |
if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
|
| 955 |
has_existing_html = True
|
| 956 |
-
|
| 957 |
# Choose system prompt based on context
|
| 958 |
if has_existing_html:
|
| 959 |
# Use follow-up prompt for modifying existing HTML
|
|
@@ -964,9 +969,9 @@ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optio
|
|
| 964 |
system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
|
| 965 |
else:
|
| 966 |
system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)
|
| 967 |
-
|
| 968 |
messages = history_to_messages(_history, system_prompt)
|
| 969 |
-
|
| 970 |
# Extract file text and append to query if file is present
|
| 971 |
file_text = ""
|
| 972 |
if file:
|
|
@@ -974,7 +979,7 @@ def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optio
|
|
| 974 |
if file_text:
|
| 975 |
file_text = file_text[:5000] # Limit to 5000 chars for prompt size
|
| 976 |
query = f"{query}\n\n[Reference file content below]\n{file_text}"
|
| 977 |
-
|
| 978 |
# Extract website content and append to query if website URL is present
|
| 979 |
website_text = ""
|
| 980 |
if website_url and website_url.strip():
|
|
@@ -994,12 +999,12 @@ Since I couldn't extract the website content, please provide additional details
|
|
| 994 |
|
| 995 |
This will help me create a better design for you."""
|
| 996 |
query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
|
| 997 |
-
|
| 998 |
# Enhance query with search if enabled
|
| 999 |
enhanced_query = enhance_query_with_search(query, enable_search)
|
| 1000 |
-
|
| 1001 |
# Use dynamic client based on selected model
|
| 1002 |
-
client = get_inference_client(_current_model["id"])
|
| 1003 |
|
| 1004 |
if image is not None:
|
| 1005 |
messages.append(create_multimodal_message(enhanced_query, image))
|
|
@@ -1014,7 +1019,8 @@ This will help me create a better design for you."""
|
|
| 1014 |
)
|
| 1015 |
content = ""
|
| 1016 |
for chunk in completion:
|
| 1017 |
-
if chunk.choices
|
|
|
|
| 1018 |
content += chunk.choices[0].delta.content
|
| 1019 |
clean_code = remove_code_block(content)
|
| 1020 |
search_status = " (with web search)" if enable_search and tavily_client else ""
|
|
@@ -1027,7 +1033,7 @@ This will help me create a better design for you."""
|
|
| 1027 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
| 1028 |
}
|
| 1029 |
else:
|
| 1030 |
-
last_html = _history[-1][1] if _history else ""
|
| 1031 |
modified_html = apply_search_replace_changes(last_html, clean_code)
|
| 1032 |
clean_html = remove_code_block(modified_html)
|
| 1033 |
yield {
|
|
@@ -1041,6 +1047,8 @@ This will help me create a better design for you."""
|
|
| 1041 |
history_output: history_to_chatbot_messages(_history),
|
| 1042 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
| 1043 |
}
|
|
|
|
|
|
|
| 1044 |
# Handle response based on whether this is a modification or new generation
|
| 1045 |
if has_existing_html:
|
| 1046 |
# Fallback: If the model returns a full HTML file, use it directly
|
|
@@ -1048,14 +1056,11 @@ This will help me create a better design for you."""
|
|
| 1048 |
if final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html"):
|
| 1049 |
clean_html = final_code
|
| 1050 |
else:
|
| 1051 |
-
last_html = _history[-1][1] if _history else ""
|
| 1052 |
modified_html = apply_search_replace_changes(last_html, final_code)
|
| 1053 |
clean_html = remove_code_block(modified_html)
|
| 1054 |
# Update history with the cleaned HTML
|
| 1055 |
-
_history
|
| 1056 |
-
'role': 'assistant',
|
| 1057 |
-
'content': clean_html
|
| 1058 |
-
}])
|
| 1059 |
yield {
|
| 1060 |
code_output: clean_html,
|
| 1061 |
history: _history,
|
|
@@ -1064,10 +1069,7 @@ This will help me create a better design for you."""
|
|
| 1064 |
}
|
| 1065 |
else:
|
| 1066 |
# Regular generation - use the content as is
|
| 1067 |
-
_history
|
| 1068 |
-
'role': 'assistant',
|
| 1069 |
-
'content': content
|
| 1070 |
-
}])
|
| 1071 |
yield {
|
| 1072 |
code_output: remove_code_block(content),
|
| 1073 |
history: _history,
|
|
@@ -1156,6 +1158,16 @@ with gr.Blocks(
|
|
| 1156 |
label="Model",
|
| 1157 |
visible=True # Always visible
|
| 1158 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1159 |
gr.Markdown("**Quick start**", visible=True)
|
| 1160 |
with gr.Column(visible=True) as quick_examples_col:
|
| 1161 |
for i, demo_item in enumerate(DEMO_LIST[:3]):
|
|
@@ -1251,7 +1263,7 @@ with gr.Blocks(
|
|
| 1251 |
|
| 1252 |
btn.click(
|
| 1253 |
generation_code,
|
| 1254 |
-
inputs=[input, image_input, file_input, website_url_input, setting, history, current_model, search_toggle, language_dropdown],
|
| 1255 |
outputs=[code_output, history, sandbox, history_output]
|
| 1256 |
)
|
| 1257 |
# Update preview when code or language changes
|
|
@@ -1259,5 +1271,14 @@ with gr.Blocks(
|
|
| 1259 |
language_dropdown.change(preview_logic, inputs=[code_output, language_dropdown], outputs=sandbox)
|
| 1260 |
clear_btn.click(clear_history, outputs=[history, history_output, file_input, website_url_input])
|
| 1261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1262 |
if __name__ == "__main__":
|
| 1263 |
demo.queue(api_open=False, default_concurrency_limit=20).launch(ssr_mode=True, mcp_server=False, show_api=False)
|
|
|
|
| 227 |
|
| 228 |
# HF Inference Client
|
| 229 |
HF_TOKEN = os.getenv('HF_TOKEN')
|
| 230 |
+
if not HF_TOKEN:
|
| 231 |
+
raise RuntimeError("HF_TOKEN environment variable is not set. Please set it to your Hugging Face API token.")
|
| 232 |
|
| 233 |
+
def get_inference_client(model_id, provider="auto"):
|
| 234 |
+
"""Return an InferenceClient with provider based on model_id and user selection."""
|
|
|
|
| 235 |
return InferenceClient(
|
| 236 |
provider=provider,
|
| 237 |
api_key=HF_TOKEN,
|
|
|
|
| 941 |
except Exception as e:
|
| 942 |
return f"Error extracting website content: {str(e)}"
|
| 943 |
|
| 944 |
+
def generation_code(query: Optional[str], image: Optional[gr.Image], file: Optional[str], website_url: Optional[str], _setting: Dict[str, str], _history: Optional[History], _current_model: Dict, enable_search: bool = False, language: str = "html", provider: str = "auto"):
|
| 945 |
if query is None:
|
| 946 |
query = ''
|
| 947 |
if _history is None:
|
| 948 |
_history = []
|
| 949 |
+
# Ensure _history is always a list of lists with at least 2 elements per item
|
| 950 |
+
if not isinstance(_history, list):
|
| 951 |
+
_history = []
|
| 952 |
+
_history = [h for h in _history if isinstance(h, list) and len(h) == 2]
|
| 953 |
+
|
| 954 |
# Check if there's existing HTML content in history to determine if this is a modification request
|
| 955 |
has_existing_html = False
|
| 956 |
+
last_assistant_msg = ""
|
| 957 |
+
if _history and len(_history[-1]) > 1:
|
| 958 |
+
last_assistant_msg = _history[-1][1]
|
| 959 |
if '<!DOCTYPE html>' in last_assistant_msg or '<html' in last_assistant_msg:
|
| 960 |
has_existing_html = True
|
| 961 |
+
|
| 962 |
# Choose system prompt based on context
|
| 963 |
if has_existing_html:
|
| 964 |
# Use follow-up prompt for modifying existing HTML
|
|
|
|
| 969 |
system_prompt = HTML_SYSTEM_PROMPT_WITH_SEARCH if enable_search else HTML_SYSTEM_PROMPT
|
| 970 |
else:
|
| 971 |
system_prompt = GENERIC_SYSTEM_PROMPT_WITH_SEARCH.format(language=language) if enable_search else GENERIC_SYSTEM_PROMPT.format(language=language)
|
| 972 |
+
|
| 973 |
messages = history_to_messages(_history, system_prompt)
|
| 974 |
+
|
| 975 |
# Extract file text and append to query if file is present
|
| 976 |
file_text = ""
|
| 977 |
if file:
|
|
|
|
| 979 |
if file_text:
|
| 980 |
file_text = file_text[:5000] # Limit to 5000 chars for prompt size
|
| 981 |
query = f"{query}\n\n[Reference file content below]\n{file_text}"
|
| 982 |
+
|
| 983 |
# Extract website content and append to query if website URL is present
|
| 984 |
website_text = ""
|
| 985 |
if website_url and website_url.strip():
|
|
|
|
| 999 |
|
| 1000 |
This will help me create a better design for you."""
|
| 1001 |
query = f"{query}\n\n[Error extracting website: {website_text}]{fallback_guidance}"
|
| 1002 |
+
|
| 1003 |
# Enhance query with search if enabled
|
| 1004 |
enhanced_query = enhance_query_with_search(query, enable_search)
|
| 1005 |
+
|
| 1006 |
# Use dynamic client based on selected model
|
| 1007 |
+
client = get_inference_client(_current_model["id"], provider)
|
| 1008 |
|
| 1009 |
if image is not None:
|
| 1010 |
messages.append(create_multimodal_message(enhanced_query, image))
|
|
|
|
| 1019 |
)
|
| 1020 |
content = ""
|
| 1021 |
for chunk in completion:
|
| 1022 |
+
# Only process if chunk.choices is non-empty
|
| 1023 |
+
if hasattr(chunk, "choices") and chunk.choices and hasattr(chunk.choices[0], "delta") and hasattr(chunk.choices[0].delta, "content"):
|
| 1024 |
content += chunk.choices[0].delta.content
|
| 1025 |
clean_code = remove_code_block(content)
|
| 1026 |
search_status = " (with web search)" if enable_search and tavily_client else ""
|
|
|
|
| 1033 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
| 1034 |
}
|
| 1035 |
else:
|
| 1036 |
+
last_html = _history[-1][1] if _history and len(_history[-1]) > 1 else ""
|
| 1037 |
modified_html = apply_search_replace_changes(last_html, clean_code)
|
| 1038 |
clean_html = remove_code_block(modified_html)
|
| 1039 |
yield {
|
|
|
|
| 1047 |
history_output: history_to_chatbot_messages(_history),
|
| 1048 |
sandbox: send_to_sandbox(clean_code) if language == "html" else "<div style='padding:1em;color:#888;text-align:center;'>Preview is only available for HTML. Please download your code using the download button above.</div>",
|
| 1049 |
}
|
| 1050 |
+
# Skip chunks with empty choices (end of stream)
|
| 1051 |
+
# Do not treat as error
|
| 1052 |
# Handle response based on whether this is a modification or new generation
|
| 1053 |
if has_existing_html:
|
| 1054 |
# Fallback: If the model returns a full HTML file, use it directly
|
|
|
|
| 1056 |
if final_code.strip().startswith("<!DOCTYPE html>") or final_code.strip().startswith("<html"):
|
| 1057 |
clean_html = final_code
|
| 1058 |
else:
|
| 1059 |
+
last_html = _history[-1][1] if _history and len(_history[-1]) > 1 else ""
|
| 1060 |
modified_html = apply_search_replace_changes(last_html, final_code)
|
| 1061 |
clean_html = remove_code_block(modified_html)
|
| 1062 |
# Update history with the cleaned HTML
|
| 1063 |
+
_history.append([query, clean_html])
|
|
|
|
|
|
|
|
|
|
| 1064 |
yield {
|
| 1065 |
code_output: clean_html,
|
| 1066 |
history: _history,
|
|
|
|
| 1069 |
}
|
| 1070 |
else:
|
| 1071 |
# Regular generation - use the content as is
|
| 1072 |
+
_history.append([query, content])
|
|
|
|
|
|
|
|
|
|
| 1073 |
yield {
|
| 1074 |
code_output: remove_code_block(content),
|
| 1075 |
history: _history,
|
|
|
|
| 1158 |
label="Model",
|
| 1159 |
visible=True # Always visible
|
| 1160 |
)
|
| 1161 |
+
provider_choices = [
|
| 1162 |
+
"auto", "black-forest-labs", "cerebras", "cohere", "fal-ai", "featherless-ai", "fireworks-ai", "groq", "hf-inference", "hyperbolic", "nebius", "novita", "nscale", "openai", "replicate", "sambanova", "together"
|
| 1163 |
+
]
|
| 1164 |
+
provider_dropdown = gr.Dropdown(
|
| 1165 |
+
choices=provider_choices,
|
| 1166 |
+
value="auto",
|
| 1167 |
+
label="Provider",
|
| 1168 |
+
visible=True
|
| 1169 |
+
)
|
| 1170 |
+
provider_state = gr.State("auto")
|
| 1171 |
gr.Markdown("**Quick start**", visible=True)
|
| 1172 |
with gr.Column(visible=True) as quick_examples_col:
|
| 1173 |
for i, demo_item in enumerate(DEMO_LIST[:3]):
|
|
|
|
| 1263 |
|
| 1264 |
btn.click(
|
| 1265 |
generation_code,
|
| 1266 |
+
inputs=[input, image_input, file_input, website_url_input, setting, history, current_model, search_toggle, language_dropdown, provider_state],
|
| 1267 |
outputs=[code_output, history, sandbox, history_output]
|
| 1268 |
)
|
| 1269 |
# Update preview when code or language changes
|
|
|
|
| 1271 |
language_dropdown.change(preview_logic, inputs=[code_output, language_dropdown], outputs=sandbox)
|
| 1272 |
clear_btn.click(clear_history, outputs=[history, history_output, file_input, website_url_input])
|
| 1273 |
|
| 1274 |
+
def on_provider_change(provider):
|
| 1275 |
+
return provider
|
| 1276 |
+
|
| 1277 |
+
provider_dropdown.change(
|
| 1278 |
+
on_provider_change,
|
| 1279 |
+
inputs=provider_dropdown,
|
| 1280 |
+
outputs=provider_state
|
| 1281 |
+
)
|
| 1282 |
+
|
| 1283 |
if __name__ == "__main__":
|
| 1284 |
demo.queue(api_open=False, default_concurrency_limit=20).launch(ssr_mode=True, mcp_server=False, show_api=False)
|