Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| ChatGPT Clone - 日本語対応チャットボット | |
| Hugging Face Spaces (ZeroGPU) 対応版 | |
| 使用モデル: | |
| - elyza/Llama-3-ELYZA-JP-8B | |
| - Fugaku-LLM/Fugaku-LLM-13B | |
| - openai/gpt-oss-20b | |
| """ | |
| import os | |
| from typing import List, Tuple | |
| # Check if running on ZeroGPU FIRST (before any CUDA initialization) | |
| try: | |
| import spaces | |
| IS_ZEROGPU = True | |
| print("ZeroGPU環境を検出しました。") | |
| except ImportError: | |
| IS_ZEROGPU = False | |
| print("通常のGPU/CPU環境で実行しています。") | |
| # Import after spaces check | |
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| # Hugging Face token from environment variable | |
| HF_TOKEN = os.getenv("HF_TOKEN") | |
| # トークンのチェック | |
| if not HF_TOKEN: | |
| print("警告: HF_TOKENが設定されていません。プライベートモデルへのアクセスが制限される場合があります。") | |
| class ChatBot: | |
| def __init__(self): | |
| self.model = None | |
| self.tokenizer = None | |
| self.pipeline = None | |
| self.current_model = None | |
| def is_gpt_oss_model(self, model_name: str) -> bool: | |
| """gpt-ossモデルかどうかを判定""" | |
| return "gpt-oss" in model_name.lower() | |
| def load_model(self, model_name: str): | |
| """モデルとトークナイザーをロード""" | |
| if self.current_model == model_name and (self.model is not None or self.pipeline is not None): | |
| return | |
| try: | |
| # メモリクリア | |
| if self.model is not None: | |
| del self.model | |
| del self.tokenizer | |
| if self.pipeline is not None: | |
| del self.pipeline | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| if self.is_gpt_oss_model(model_name): | |
| # gpt-ossモデルの場合はpipelineを使用 | |
| print(f"gpt-ossモデル {model_name} をpipelineでロードします...") | |
| self.pipeline = pipeline( | |
| "text-generation", | |
| model=model_name, | |
| torch_dtype="auto", # gpt-ossモデルに適した精度を自動選択 | |
| trust_remote_code=True, | |
| token=HF_TOKEN, | |
| device_map="auto" if not IS_ZEROGPU else None # ZeroGPU以外では自動マッピング | |
| ) | |
| self.model = None | |
| self.tokenizer = None | |
| else: | |
| # 通常のモデルの場合 | |
| print(f"通常のモデル {model_name} をロードします...") | |
| # トークナイザーロード | |
| self.tokenizer = AutoTokenizer.from_pretrained( | |
| model_name, | |
| token=HF_TOKEN, | |
| trust_remote_code=True, | |
| padding_side="left" | |
| ) | |
| # パッドトークンの設定 | |
| if self.tokenizer.pad_token is None: | |
| self.tokenizer.pad_token = self.tokenizer.eos_token | |
| self.tokenizer.pad_token_id = self.tokenizer.eos_token_id | |
| # モデルロード(ZeroGPU対応) | |
| self.model = AutoModelForCausalLM.from_pretrained( | |
| model_name, | |
| token=HF_TOKEN, | |
| torch_dtype="auto", # 自動精度選択 | |
| low_cpu_mem_usage=True, | |
| trust_remote_code=True, | |
| load_in_8bit=False, # ZeroGPU環境では8bit量子化は使わない | |
| device_map="auto" if not IS_ZEROGPU else None # ZeroGPU対応 | |
| ) | |
| self.pipeline = None | |
| self.current_model = model_name | |
| print(f"モデル {model_name} のロードが完了しました。") | |
| except Exception as e: | |
| print(f"モデルのロード中にエラーが発生しました: {str(e)}") | |
| # gpt-ossモデルでエラーが出た場合、使用不可と表示 | |
| if self.is_gpt_oss_model(model_name): | |
| raise Exception(f"gpt-ossモデルのロードに失敗しました。このモデルは現在の環境では使用できません: {str(e)}") | |
| else: | |
| raise | |
| def _generate_response_gpu(self, message: str, history: List[Tuple[str, str]], model_name: str, | |
| temperature: float = 0.7, max_tokens: int = 512) -> str: | |
| """GPU上で応答を生成する実際の処理""" | |
| try: | |
| # モデルロード | |
| self.load_model(model_name) | |
| if self.is_gpt_oss_model(model_name): | |
| # gpt-ossモデルの場合 | |
| return self._generate_with_pipeline(message, history, temperature, max_tokens) | |
| else: | |
| # 通常のモデルの場合 | |
| return self._generate_with_model(message, history, temperature, max_tokens) | |
| except Exception as e: | |
| return f"エラー: {str(e)}" | |
| def _generate_with_pipeline(self, message: str, history: List[Tuple[str, str]], | |
| temperature: float, max_tokens: int) -> str: | |
| """pipelineを使用した生成(gpt-oss用)""" | |
| # device_mapが設定されていない場合のみ手動でGPU移動 | |
| if IS_ZEROGPU and hasattr(self.pipeline, 'model') and hasattr(self.pipeline.model, 'to'): | |
| self.pipeline.model.to('cuda') | |
| # gpt-ossはchat format用のmessages形式を使用 | |
| messages = [] | |
| # 履歴を追加(最新3件のみ) | |
| for user_msg, assistant_msg in history[-3:]: | |
| messages.append({"role": "user", "content": user_msg}) | |
| messages.append({"role": "assistant", "content": assistant_msg}) | |
| # 現在のメッセージを追加 | |
| messages.append({"role": "user", "content": message}) | |
| # pipeline経由で生成 | |
| outputs = self.pipeline( | |
| messages, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| do_sample=True, | |
| top_p=0.95, | |
| return_full_text=False | |
| ) | |
| # ZeroGPU環境でのみCPUに戻す(メモリ節約) | |
| if IS_ZEROGPU and hasattr(self.pipeline, 'model') and hasattr(self.pipeline.model, 'to'): | |
| self.pipeline.model.to('cpu') | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| # レスポンスの処理 | |
| response_text = outputs[0]["generated_text"].strip() | |
| # gpt-ossモデルの場合、"assistantfinal"以降のみを抽出 | |
| if "assistantfinal" in response_text: | |
| response_text = response_text.split("assistantfinal", 1)[1].strip() | |
| return response_text | |
| def _generate_with_model(self, message: str, history: List[Tuple[str, str]], | |
| temperature: float, max_tokens: int) -> str: | |
| """通常のモデルを使用した生成""" | |
| # GPUに移動 | |
| self.model.to('cuda') | |
| # プロンプト構築 | |
| prompt = self._build_prompt(message, history) | |
| # トークナイズ | |
| inputs = self.tokenizer.encode(prompt, return_tensors="pt").to('cuda') | |
| # 生成 | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| inputs, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| do_sample=True, | |
| top_p=0.95, | |
| top_k=50, | |
| repetition_penalty=1.1, | |
| pad_token_id=self.tokenizer.pad_token_id, | |
| eos_token_id=self.tokenizer.eos_token_id | |
| ) | |
| # デコード | |
| response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
| # CPUに戻す(メモリ節約) | |
| self.model.to('cpu') | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| return response.strip() | |
| def generate_response(self, message: str, history: List[Tuple[str, str]], model_name: str, | |
| temperature: float = 0.7, max_tokens: int = 512) -> str: | |
| """メッセージに対する応答を生成""" | |
| if IS_ZEROGPU: | |
| # ZeroGPU環境の場合 | |
| return self._generate_response_gpu(message, history, model_name, temperature, max_tokens) | |
| else: | |
| # 通常環境の場合 | |
| try: | |
| self.load_model(model_name) | |
| if self.is_gpt_oss_model(model_name): | |
| # gpt-ossモデルの場合 | |
| # device_mapが設定されていない場合のみ手動でGPU移動 | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| if device == 'cuda' and hasattr(self.pipeline, 'model') and hasattr(self.pipeline.model, 'to'): | |
| # device_mapが"auto"でない場合のみ手動移動 | |
| self.pipeline.model.to(device) | |
| messages = [] | |
| for user_msg, assistant_msg in history[-3:]: | |
| messages.append({"role": "user", "content": user_msg}) | |
| messages.append({"role": "assistant", "content": assistant_msg}) | |
| messages.append({"role": "user", "content": message}) | |
| outputs = self.pipeline( | |
| messages, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| do_sample=True, | |
| top_p=0.95, | |
| return_full_text=False | |
| ) | |
| response_text = outputs[0]["generated_text"].strip() | |
| # gpt-ossモデルの場合、"assistantfinal"以降のみを抽出 | |
| if "assistantfinal" in response_text: | |
| response_text = response_text.split("assistantfinal", 1)[1].strip() | |
| return response_text | |
| # 通常のモデルの場合 | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
| if device == 'cuda': | |
| self.model.to(device) | |
| prompt = self._build_prompt(message, history) | |
| inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(device) | |
| with torch.no_grad(): | |
| outputs = self.model.generate( | |
| inputs, | |
| max_new_tokens=max_tokens, | |
| temperature=temperature, | |
| do_sample=True, | |
| top_p=0.95, | |
| top_k=50, | |
| repetition_penalty=1.1, | |
| pad_token_id=self.tokenizer.pad_token_id, | |
| eos_token_id=self.tokenizer.eos_token_id | |
| ) | |
| response = self.tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True) | |
| return response.strip() | |
| except Exception as e: | |
| return f"エラー: {str(e)}" | |
| def _build_prompt(self, message: str, history: List[Tuple[str, str]]) -> str: | |
| """会話履歴からプロンプトを構築(通常のモデル用)""" | |
| prompt = "" | |
| # 履歴を追加(最新3件のみ使用 - メモリ効率のため) | |
| for user_msg, assistant_msg in history[-3:]: | |
| prompt += f"User: {user_msg}\nAssistant: {assistant_msg}\n\n" | |
| # 現在のメッセージを追加 | |
| prompt += f"User: {message}\nAssistant: " | |
| return prompt | |
| # ChatBotインスタンス | |
| chatbot = ChatBot() | |
| # ZeroGPU環境の場合、GPUデコレータを適用 | |
| if IS_ZEROGPU: | |
| chatbot._generate_response_gpu = spaces.GPU(duration=120)(chatbot._generate_response_gpu) | |
| def respond(message: str, history: List[Tuple[str, str]], model_name: str, | |
| temperature: float, max_tokens: int) -> Tuple[List[Tuple[str, str]], str]: | |
| """Gradioのコールバック関数""" | |
| if not message: | |
| return history, "" | |
| try: | |
| # 応答生成 | |
| response = chatbot.generate_response(message, history, model_name, temperature, max_tokens) | |
| # 履歴に追加 | |
| history.append((message, response)) | |
| return history, "" | |
| except RuntimeError as e: | |
| if "out of memory" in str(e).lower(): | |
| error_msg = "メモリ不足エラー: より小さいモデルを使用するか、最大トークン数を減らしてください。" | |
| else: | |
| error_msg = f"実行時エラー: {str(e)}" | |
| history.append((message, error_msg)) | |
| return history, "" | |
| except Exception as e: | |
| error_msg = f"エラーが発生しました: {str(e)}" | |
| history.append((message, error_msg)) | |
| return history, "" | |
| def clear_chat() -> Tuple[List, str]: | |
| """チャット履歴をクリア""" | |
| return [], "" | |
| # Gradio UI | |
| with gr.Blocks(title="ChatGPT Clone", theme=gr.themes.Soft()) as app: | |
| gr.Markdown("# 🤖 ChatGPT Clone") | |
| gr.Markdown(""" | |
| 日本語対応のLLMを使用したチャットボットです。 | |
| **使用可能モデル:** | |
| - [elyza/Llama-3-ELYZA-JP-8B](https://huggingface.co/elyza/Llama-3-ELYZA-JP-8B) | |
| - [Fugaku-LLM/Fugaku-LLM-13B](https://huggingface.co/Fugaku-LLM/Fugaku-LLM-13B) | |
| - [openai/gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) - OpenAIの最新オープンウェイト推論モデル | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| chatbot_ui = gr.Chatbot( | |
| label="Chat", | |
| height=500, | |
| show_label=False, | |
| container=True | |
| ) | |
| with gr.Row(): | |
| msg_input = gr.Textbox( | |
| label="メッセージを入力", | |
| placeholder="ここにメッセージを入力してください...", | |
| lines=2, | |
| scale=4, | |
| show_label=False | |
| ) | |
| send_btn = gr.Button("送信", variant="primary", scale=1) | |
| with gr.Row(): | |
| clear_btn = gr.Button("🗑️ 新しい会話", variant="secondary") | |
| with gr.Column(scale=1): | |
| model_select = gr.Dropdown( | |
| choices=[ | |
| "elyza/Llama-3-ELYZA-JP-8B", | |
| "Fugaku-LLM/Fugaku-LLM-13B", | |
| "openai/gpt-oss-20b", | |
| ], | |
| value="elyza/Llama-3-ELYZA-JP-8B", | |
| label="モデル選択", | |
| interactive=True | |
| ) | |
| temperature = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.7, | |
| step=0.1, | |
| label="Temperature", | |
| info="生成の創造性を調整" | |
| ) | |
| max_tokens = gr.Slider( | |
| minimum=64, | |
| maximum=512, | |
| value=256, | |
| step=64, | |
| label="最大トークン数", | |
| info="生成する最大トークン数" | |
| ) | |
| gr.Markdown(""" | |
| ### 使い方 | |
| 1. モデルを選択 | |
| 2. メッセージを入力 | |
| 3. 送信ボタンをクリック | |
| ### 注意事項 | |
| - 初回のモデル読み込みには時間がかかります | |
| - ZeroGPU使用により高速推論が可能 | |
| - 1回の生成は120秒以内に完了します | |
| - 大きなモデル使用時は、短めの応答になる場合があります | |
| - gpt-oss-20bは推論専用モデルで、harmony formatを使用します | |
| """) | |
| # イベントハンドラ | |
| msg_input.submit( | |
| fn=respond, | |
| inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens], | |
| outputs=[chatbot_ui, msg_input] | |
| ) | |
| send_btn.click( | |
| fn=respond, | |
| inputs=[msg_input, chatbot_ui, model_select, temperature, max_tokens], | |
| outputs=[chatbot_ui, msg_input] | |
| ) | |
| clear_btn.click( | |
| fn=clear_chat, | |
| outputs=[chatbot_ui, msg_input] | |
| ) | |
| if __name__ == "__main__": | |
| # Hugging Face Spaces環境かどうかを確認 | |
| is_hf_spaces = os.getenv("SPACE_ID") is not None | |
| app.launch( | |
| share=False, | |
| show_error=True, | |
| server_name="0.0.0.0" if is_hf_spaces else "127.0.0.1", | |
| server_port=7860 | |
| ) |