import os
from openai import OpenAI
import streamlit as st

st.title("Trillion-7B-Preview")
client = OpenAI(
    api_key=os.getenv("OPENROUTER_API_KEY"),
    # api_key=os.getenv("OpenAI"),
    # base_url=os.getenv("BASE_URL"),
    # base_url="https://api.openai.com/v1",
    # os.getenv() expects an environment variable name, not a URL, so the
    # OpenRouter endpoint is passed directly
    base_url="https://openrouter.ai/api/v1",
)
| if "openai_model" not in st.session_state: | |
| # st.session_state["openai_model"] = "trillionlabs/Trillion-7B-preview" | |
| st.session_state["openai_model"] = "deepseek/deepseek-chat-v3-0324:free" | |
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| for message in st.session_state.messages: | |
| with st.chat_message(message["role"]): | |
| st.markdown(message["content"]) | |
| if prompt := st.chat_input("Message"): | |
| st.session_state.messages.append({"role": "user", "content": prompt}) | |
| with st.chat_message("user"): | |
| st.markdown(prompt) | |
| with st.chat_message("assistant"): | |
| stream = client.chat.completions.create( | |
| model=st.session_state["openai_model"], | |
| messages=[ | |
| {"role": m["role"], "content": m["content"]} | |
| for m in st.session_state.messages | |
| ], | |
| stream=True, | |
| extra_body={ | |
| "topP": 0.95, | |
| "maxTokens": 3072, | |
| "temperature": 0.6, | |
| }, | |
| ) | |
| response = st.write_stream(stream) | |
| st.session_state.messages.append({"role": "assistant", "content": response}) | |
# import os
# import torch
# import time
# import warnings
# from fastapi import FastAPI, Request
# from fastapi.responses import JSONResponse
# from fastapi.middleware.cors import CORSMiddleware
# import gradio as gr
# from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
#
# # Suppress specific warnings
# warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.hub")
#
# # Configure environment variables for cache
# os.environ["HF_HOME"] = os.getenv("HF_HOME", "/app/cache/huggingface")
# os.environ["MPLCONFIGDIR"] = os.getenv("MPLCONFIGDIR", "/app/cache/matplotlib")
#
# # Ensure cache directories exist
# os.makedirs(os.environ["HF_HOME"], exist_ok=True)
# os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)
#
# # Initialize FastAPI app
# app = FastAPI()
#
# def log_message(message: str):
#     """Helper function for logging"""
#     print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")
#
# def load_model():
#     """Load the model with CPU optimization"""
#     model_name = "trillionlabs/Trillion-7B-preview-AWQ"
#     log_message("Loading tokenizer...")
#     try:
#         tokenizer = AutoTokenizer.from_pretrained(
#             model_name,
#             trust_remote_code=True
#         )
#     except Exception as e:
#         log_message(f"Tokenizer loading failed: {e}")
#         # Fallback to LlamaTokenizer if available
#         from transformers import LlamaTokenizer
#         tokenizer = LlamaTokenizer.from_pretrained(model_name)
#
#     log_message("Loading model...")
#     try:
#         model = AutoModelForCausalLM.from_pretrained(
#             model_name,
#             torch_dtype=torch.float32,
#             trust_remote_code=True
#         )
#         # Explicitly move to CPU
#         model = model.to("cpu")
#         model.eval()
#     except Exception as e:
#         log_message(f"Model loading failed: {e}")
#         raise
#
#     log_message("Creating pipeline...")
#     text_generator = pipeline(
#         "text-generation",
#         model=model,
#         tokenizer=tokenizer,
#         device="cpu"
#     )
#     return text_generator, tokenizer
#
# # Load model
# try:
#     log_message("Starting model loading process...")
#     text_generator, tokenizer = load_model()
#     log_message("Model loaded successfully")
# except Exception as e:
#     log_message(f"Critical error loading model: {e}")
#     raise
#
# # API endpoints
# @app.post("/api/generate")
# async def api_generate(request: Request):
#     """API endpoint for text generation"""
#     try:
#         data = await request.json()
#         prompt = data.get("prompt", "").strip()
#         if not prompt:
#             return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)
#         max_length = min(int(data.get("max_length", 100)), 300)  # Conservative limit
#         start_time = time.time()
#         outputs = text_generator(
#             prompt,
#             max_length=max_length,
#             do_sample=True,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             pad_token_id=tokenizer.eos_token_id
#         )
#         generation_time = time.time() - start_time
#         response_data = {
#             "generated_text": outputs[0]["generated_text"],
#             "time_seconds": round(generation_time, 2),
#             "tokens_generated": len(tokenizer.tokenize(outputs[0]["generated_text"])),
#             "model": "Trillion-7B-preview-AWQ",
#             "device": "cpu"
#         }
#         return JSONResponse(response_data)
#     except Exception as e:
#         log_message(f"API Error: {e}")
#         return JSONResponse({"error": str(e)}, status_code=500)
#
# @app.get("/health")
# async def health_check():
#     """Health check endpoint"""
#     return {
#         "status": "healthy",
#         "model_loaded": text_generator is not None,
#         "device": "cpu",
#         "cache_path": os.environ["HF_HOME"]
#     }
#
# # Gradio Interface
# def gradio_generate(prompt, max_length=100):
#     """Function for Gradio interface generation"""
#     try:
#         max_length = min(int(max_length), 300)  # Same conservative limit as API
#         if not prompt.strip():
#             return "Please enter a prompt"
#         outputs = text_generator(
#             prompt,
#             max_length=max_length,
#             do_sample=True,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             pad_token_id=tokenizer.eos_token_id
#         )
#         return outputs[0]["generated_text"]
#     except Exception as e:
#         log_message(f"Gradio Error: {e}")
#         return f"Error generating text: {str(e)}"
#
# with gr.Blocks(title="Trillion-7B CPU Demo", theme=gr.themes.Default()) as gradio_app:
#     gr.Markdown("""
#     # 🚀 Trillion-7B-preview-AWQ (CPU Version)
#     *Running on CPU with optimized settings - responses may be slower than GPU versions*
#     """)
#     with gr.Row():
#         with gr.Column():
#             input_prompt = gr.Textbox(
#                 label="Your Prompt",
#                 placeholder="Enter text here...",
#                 lines=5,
#                 max_lines=10
#             )
#             with gr.Row():
#                 max_length = gr.Slider(
#                     label="Max Length",
#                     minimum=20,
#                     maximum=300,
#                     value=100,
#                     step=10
#                 )
#                 generate_btn = gr.Button("Generate", variant="primary")
#         with gr.Column():
#             output_text = gr.Textbox(
#                 label="Generated Text",
#                 lines=10,
#                 interactive=False
#             )
#     # Examples
#     gr.Examples(
#         examples=[
#             ["Explain quantum computing in simple terms"],
#             ["Write a haiku about artificial intelligence"],
#             ["What are the main benefits of renewable energy?"],
#             ["Suggest three ideas for a science fiction story"]
#         ],
#         inputs=input_prompt,
#         label="Example Prompts"
#     )
#     generate_btn.click(
#         fn=gradio_generate,
#         inputs=[input_prompt, max_length],
#         outputs=output_text
#     )
#
# # Mount Gradio app
# app = gr.mount_gradio_app(app, gradio_app, path="/")
#
# # CORS configuration
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
#
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)
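
# If the FastAPI version above were re-enabled, a minimal client sketch for its
# /api/generate endpoint could look like the following (assumes the `requests`
# package and a server listening on localhost:7860, as configured above):
#
# import requests
#
# resp = requests.post(
#     "http://localhost:7860/api/generate",
#     json={"prompt": "Explain quantum computing in simple terms", "max_length": 150},
#     timeout=300,  # CPU generation can be slow
# )
# resp.raise_for_status()
# print(resp.json()["generated_text"])
#
# # The /health endpoint reports whether the model loaded:
# # requests.get("http://localhost:7860/health").json()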