Tiago Caldeira committed

Commit 9f37a6e · Parent(s): 44e14ac

different approach using unsloth
Files changed:
- _app.py  +75 -0
- app.py  +29 -36
- requirements.txt  +5 -4
    	
_app.py ADDED

@@ -0,0 +1,75 @@
+import gradio as gr
+from transformers import AutoProcessor, Gemma3nForConditionalGeneration
+import torch
+import textwrap
+from huggingface_hub import login
+import os
+
+# Log in using the HF token (automatically read from secret)
+hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
+login(token=hf_token)
+
+
+
+# Load model and processor
+model_id = "google/gemma-3n-e2b-it"
+model_id = "google/gemma-3n-E2B"
+model_id = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit"
+model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"
+
+
+processor = AutoProcessor.from_pretrained(model_id)
+model = Gemma3nForConditionalGeneration.from_pretrained(
+    model_id,
+    torch_dtype=torch.float32,
+    device_map="cpu"
+).eval()
+
+# Helper to format output
+def print_response(text: str) -> str:
+    return "\n".join(textwrap.fill(line, 100) for line in text.split("\n"))
+
+# Inference function for text-only input
+def predict_text(system_prompt: str, user_prompt: str) -> str:
+    messages = [
+        {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]},
+        {"role": "user", "content": [{"type": "text", "text": user_prompt.strip()}]},
+    ]
+
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt"
+    ).to(model.device)
+
+    input_len = inputs["input_ids"].shape[-1]
+
+    with torch.inference_mode():
+        output = model.generate(
+            **inputs,
+            max_new_tokens=500,
+            do_sample=False,
+            use_cache=False  # Fixes CPU bug
+        )
+
+    gen = output[0][input_len:]
+    decoded = processor.decode(gen, skip_special_tokens=True)
+    return print_response(decoded)
+
+# Gradio Interface
+demo = gr.Interface(
+    fn=predict_text,
+    inputs=[
+        gr.Textbox(lines=2, label="System Prompt", value="You are a helpful assistant."),
+        gr.Textbox(lines=4, label="User Prompt", placeholder="Ask something..."),
+    ],
+    outputs=gr.Textbox(label="Gemma 3n Response"),
+    title="Gemma 3n Text-Only Chat",
+    description="Interact with the Gemma 3n language model using plain text. Image input not required.",
+)
+
+if __name__ == "__main__":
+    demo.launch()
+
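Usage note (not part of the commit): with _app.py running as a Gradio app, its predict_text endpoint can be exercised from Python via gradio_client. A minimal sketch, assuming a public Space; "username/gemma-3n-demo" is a placeholder id, and "/predict" is the default endpoint name for a gr.Interface.

from gradio_client import Client

client = Client("username/gemma-3n-demo")  # placeholder Space id
result = client.predict(
    "You are a helpful assistant.",                 # System Prompt
    "Summarise what Gemma 3n is in one sentence.",  # User Prompt
    api_name="/predict",                            # default gr.Interface endpoint
)
print(result)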
    	
app.py CHANGED

@@ -1,63 +1,56 @@
-import gradio as gr
-from transformers import AutoProcessor, Gemma3nForConditionalGeneration
 import torch
+import gradio as gr
+from unsloth import FastModel
+from transformers import TextStreamer, AutoTokenizer
 import textwrap
-from huggingface_hub import login
-import os
-
-# Log in using the HF token (automatically read from secret)
-hf_token = os.getenv("HUGGINGFACE_HUB_TOKEN")
-login(token=hf_token)
-
 
+# Load model (4-bit quantized)
+model, tokenizer = FastModel.from_pretrained(
+    model_name = "unsloth/gemma-3n-E4B-it",
+    dtype = None,  # Auto-detect FP16/32
+    max_seq_length = 1024,
+    load_in_4bit = True,
+    full_finetuning = False,
+    # token = "hf_..."  # Uncomment if model is gated
+)
 
-
-
-
-model_id = "lmstudio-community/gemma-3n-E2B-it-MLX-4bit"
-model_id = "unsloth/gemma-3n-E2B-it-unsloth-bnb-4bit"
-
-processor = AutoProcessor.from_pretrained(model_id)
-model = Gemma3nForConditionalGeneration.from_pretrained(
-    model_id,
-    torch_dtype=torch.float32,
-    device_map="cpu"
-).eval()
+model.eval()
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
 
-# Helper to format output
+# Format output
 def print_response(text: str) -> str:
     return "\n".join(textwrap.fill(line, 100) for line in text.split("\n"))
 
-# Inference function for text-only input
+# Inference function for Gradio
 def predict_text(system_prompt: str, user_prompt: str) -> str:
     messages = [
         {"role": "system", "content": [{"type": "text", "text": system_prompt.strip()}]},
         {"role": "user", "content": [{"type": "text", "text": user_prompt.strip()}]},
     ]
 
-    inputs = processor.apply_chat_template(
+    inputs = tokenizer.apply_chat_template(
         messages,
         add_generation_prompt=True,
         tokenize=True,
         return_dict=True,
-        return_tensors="pt"
-    ).to(model.device)
-
-    input_len = inputs["input_ids"].shape[-1]
+        return_tensors="pt",
+    ).to(device)
 
     with torch.inference_mode():
-        output = model.generate(
+        outputs = model.generate(
             **inputs,
-            max_new_tokens=500,
-            do_sample=False,
-            use_cache=False  # Fixes CPU bug
+            max_new_tokens=256,
+            temperature=1.0,
+            top_p=0.95,
+            top_k=64,
         )
 
-    gen = output[0][input_len:]
-    decoded = processor.decode(gen, skip_special_tokens=True)
+    generated = outputs[0][inputs["input_ids"].shape[-1]:]
+    decoded = tokenizer.decode(generated, skip_special_tokens=True)
     return print_response(decoded)
 
-# Gradio Interface
+# Gradio UI
 demo = gr.Interface(
     fn=predict_text,
     inputs=[

@@ -66,7 +59,7 @@ demo = gr.Interface(
     ],
     outputs=gr.Textbox(label="Gemma 3n Response"),
     title="Gemma 3n Text-Only Chat",
-    description="Interact with the Gemma 3n language model using plain text. Image input not required.",
+    description="Interact with the Gemma 3n language model using plain text. 4-bit quantized for efficiency.",
 )
 
 if __name__ == "__main__":
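Note: the rewritten app.py imports TextStreamer and AutoTokenizer but uses neither (the tokenizer comes from FastModel.from_pretrained). If token-by-token console output is wanted later, a minimal sketch of wiring the streamer into the existing generate call; names match the new app.py, and streamer= is a standard transformers generate argument.

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

with torch.inference_mode():
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        temperature=1.0,
        top_p=0.95,
        top_k=64,
        streamer=streamer,  # prints decoded tokens to stdout as they are generated
    )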
    	
requirements.txt CHANGED

@@ -1,6 +1,7 @@
-transformers>=4.42.0
-torch
+#transformers>=4.42.0
+#torch
 gradio
-accelerate
+#accelerate
 timm
-bitsandbytes
+#bitsandbytes
+unsloth
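Environment note (assumption, not part of the commit): the unsloth bnb-4bit loading path generally expects a CUDA GPU, while the new app.py still falls back to CPU for device placement. A quick sanity check before launching the Space locally:

import torch

print("CUDA available:", torch.cuda.is_available())  # 4-bit bnb loading is typically GPU-only
try:
    from unsloth import FastModel  # noqa: F401
    print("unsloth import OK")
except ImportError as exc:
    print("unsloth missing:", exc)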