"""Streamlit chat UI backed by an OpenAI-compatible endpoint (OpenRouter by default)."""

import os

import streamlit as st
from openai import OpenAI

st.title("Trillion-7B-Preview")

client = OpenAI(
    api_key=os.getenv("OPENROUTER_API_KEY"),
    # api_key=os.getenv("OpenAI"),
    # base_url=os.getenv("BASE_URL"),
    # base_url="https://api.openai.com/v1",
    base_url="https://openrouter.ai/api/v1",
)

if "openai_model" not in st.session_state:
    # st.session_state["openai_model"] = "trillionlabs/Trillion-7B-preview"
    st.session_state["openai_model"] = "deepseek/deepseek-chat-v3-0324:free"

if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation so it persists across Streamlit reruns.
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Message"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        # Stream the completion and render it incrementally; sampling settings are
        # passed as standard OpenAI-compatible parameters.
        stream = client.chat.completions.create(
            model=st.session_state["openai_model"],
            messages=[
                {"role": m["role"], "content": m["content"]}
                for m in st.session_state.messages
            ],
            stream=True,
            temperature=0.6,
            top_p=0.95,
            max_tokens=3072,
        )
        response = st.write_stream(stream)
    st.session_state.messages.append({"role": "assistant", "content": response})


# --- FastAPI + Gradio implementation (commented out) ---

# import os
# import torch
# import time
# import warnings
# from fastapi import FastAPI, Request
# from fastapi.responses import JSONResponse
# from fastapi.middleware.cors import CORSMiddleware
# import gradio as gr
# from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# # Suppress specific warnings
# warnings.filterwarnings("ignore", category=FutureWarning, module="transformers.utils.hub")

# # Configure environment variables for cache
# os.environ["HF_HOME"] = os.getenv("HF_HOME", "/app/cache/huggingface")
# os.environ["MPLCONFIGDIR"] = os.getenv("MPLCONFIGDIR", "/app/cache/matplotlib")

# # Ensure cache directories exist
# os.makedirs(os.environ["HF_HOME"], exist_ok=True)
# os.makedirs(os.environ["MPLCONFIGDIR"], exist_ok=True)

# # Initialize FastAPI app
# app = FastAPI()

# def log_message(message: str):
#     """Helper function for logging"""
#     print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}")

# def load_model():
#     """Load the model with CPU optimization"""
#     model_name = "trillionlabs/Trillion-7B-preview-AWQ"

#     log_message("Loading tokenizer...")
#     try:
#         tokenizer = AutoTokenizer.from_pretrained(
#             model_name,
#             trust_remote_code=True
#         )
#     except Exception as e:
#         log_message(f"Tokenizer loading failed: {e}")
#         # Fallback to LlamaTokenizer if available
#         from transformers import LlamaTokenizer
#         tokenizer = LlamaTokenizer.from_pretrained(model_name)

#     log_message("Loading model...")
#     try:
#         model = AutoModelForCausalLM.from_pretrained(
#             model_name,
#             torch_dtype=torch.float32,
#             trust_remote_code=True
#         )
#         # Explicitly move to CPU
#         model = model.to("cpu")
#         model.eval()
#     except Exception as e:
#         log_message(f"Model loading failed: {e}")
#         raise

#     log_message("Creating pipeline...")
#     text_generator = pipeline(
#         "text-generation",
#         model=model,
#         tokenizer=tokenizer,
#         device="cpu"
#     )
#     return text_generator, tokenizer

# # Load model
# try:
#     log_message("Starting model loading process...")
#     text_generator, tokenizer = load_model()
#     log_message("Model loaded successfully")
# except Exception as e:
#     log_message(f"Critical error loading model: {e}")
#     raise

# # API endpoints
# @app.post("/api/generate")
# async def api_generate(request: Request):
#     """API endpoint for text generation"""
#     try:
#         data = await request.json()
#         prompt = data.get("prompt", "").strip()
#         if not prompt:
#             return JSONResponse({"error": "Prompt cannot be empty"}, status_code=400)

#         max_length = min(int(data.get("max_length", 100)), 300)  # Conservative limit

#         start_time = time.time()
#         outputs = text_generator(
#             prompt,
#             max_length=max_length,
#             do_sample=True,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             pad_token_id=tokenizer.eos_token_id
#         )
#         generation_time = time.time() - start_time

#         response_data = {
#             "generated_text": outputs[0]["generated_text"],
#             "time_seconds": round(generation_time, 2),
#             "tokens_generated": len(tokenizer.tokenize(outputs[0]["generated_text"])),
#             "model": "Trillion-7B-preview-AWQ",
#             "device": "cpu"
#         }
#         return JSONResponse(response_data)
#     except Exception as e:
#         log_message(f"API Error: {e}")
#         return JSONResponse({"error": str(e)}, status_code=500)

# @app.get("/health")
# async def health_check():
#     """Health check endpoint"""
#     return {
#         "status": "healthy",
#         "model_loaded": text_generator is not None,
#         "device": "cpu",
#         "cache_path": os.environ["HF_HOME"]
#     }

# # Gradio Interface
# def gradio_generate(prompt, max_length=100):
#     """Function for Gradio interface generation"""
#     try:
#         max_length = min(int(max_length), 300)  # Same conservative limit as API
#         if not prompt.strip():
#             return "Please enter a prompt"

#         outputs = text_generator(
#             prompt,
#             max_length=max_length,
#             do_sample=True,
#             temperature=0.7,
#             top_k=50,
#             top_p=0.95,
#             pad_token_id=tokenizer.eos_token_id
#         )
#         return outputs[0]["generated_text"]
#     except Exception as e:
#         log_message(f"Gradio Error: {e}")
#         return f"Error generating text: {str(e)}"

# with gr.Blocks(title="Trillion-7B CPU Demo", theme=gr.themes.Default()) as gradio_app:
#     gr.Markdown("""
#     # 🚀 Trillion-7B-preview-AWQ (CPU Version)
#     *Running on CPU with optimized settings - responses may be slower than GPU versions*
#     """)

#     with gr.Row():
#         with gr.Column():
#             input_prompt = gr.Textbox(
#                 label="Your Prompt",
#                 placeholder="Enter text here...",
#                 lines=5,
#                 max_lines=10
#             )
#             with gr.Row():
#                 max_length = gr.Slider(
#                     label="Max Length",
#                     minimum=20,
#                     maximum=300,
#                     value=100,
#                     step=10
#                 )
#                 generate_btn = gr.Button("Generate", variant="primary")
#         with gr.Column():
#             output_text = gr.Textbox(
#                 label="Generated Text",
#                 lines=10,
#                 interactive=False
#             )

#     # Examples
#     gr.Examples(
#         examples=[
#             ["Explain quantum computing in simple terms"],
#             ["Write a haiku about artificial intelligence"],
#             ["What are the main benefits of renewable energy?"],
#             ["Suggest three ideas for a science fiction story"]
#         ],
#         inputs=input_prompt,
#         label="Example Prompts"
#     )

#     generate_btn.click(
#         fn=gradio_generate,
#         inputs=[input_prompt, max_length],
#         outputs=output_text
#     )

# # Mount Gradio app
# app = gr.mount_gradio_app(app, gradio_app, path="/")

# # CORS configuration
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=7860)