#!/usr/bin/env python3
"""
Test Shisa v2 Qwen2.5-7B ONNX model
Includes QNN provider compatibility check
"""

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer
import time
import onnxruntime as ort

def check_qnn_support():
    """Check if QNN execution provider is available"""
    print("="*70)
    print("Checking ONNX Runtime Providers")
    print("="*70)

    providers = ort.get_available_providers()
    print("\nAvailable providers:")
    for provider in providers:
        print(f"  โœ“ {provider}")

    if "QNNExecutionProvider" in providers:
        print("\nโœ… QNN Execution Provider is available!")
        print("   Ready for Qualcomm hardware acceleration")
        return True
    else:
        print("\nโš ๏ธ  QNN Execution Provider not available")
        print("   Note: QNN requires Qualcomm's QNN SDK")
        print("   Model will run on CPU for now")
        return False
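
# Hedged sketch (not called by the test below): one way to prefer the QNN
# execution provider at load time, falling back to CPU. It assumes the
# installed optimum version exposes the `provider` keyword on
# ORTModelForCausalLM.from_pretrained(); the helper name is illustrative.
def load_with_best_provider(model_path):
    """Load the ONNX model on QNN if available, otherwise on CPU."""
    provider = (
        "QNNExecutionProvider"
        if "QNNExecutionProvider" in ort.get_available_providers()
        else "CPUExecutionProvider"
    )
    print(f"Loading with provider: {provider}")
    return ORTModelForCausalLM.from_pretrained(model_path, provider=provider)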

def test_onnx_model():
    """Test the ONNX model"""
    print("\n" + "="*70)
    print("Testing Shisa v2 Qwen2.5-7B ONNX Model")
    print("="*70)
    print()

    model_path = "models/Shisa_ONNX"

    print(f"๐Ÿ“ Loading model from: {model_path}")

    # Load model and tokenizer
    print("Loading ONNX model...")
    load_start = time.time()
    model = ORTModelForCausalLM.from_pretrained(model_path)
    load_time = time.time() - load_start
    print(f"โœ… Model loaded in {load_time:.2f} seconds")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    print("โœ… Tokenizer loaded")

    # Test cases (Japanese and English)
    test_cases = [
        {
            "prompt": "ใ“ใ‚“ใซใกใฏ๏ผ่ชฟๅญใฏใฉใ†ใงใ™ใ‹๏ผŸ",
            "description": "Japanese greeting",
            "max_tokens": 50
        },
        {
            "prompt": "Hello! How are you today?",
            "description": "English greeting",
            "max_tokens": 50
        },
        {
            "prompt": "ๆ—ฅๆœฌใฎ้ฆ–้ƒฝใฏไฝ•ใงใ™ใ‹๏ผŸ",
            "description": "Japanese factual question",
            "max_tokens": 30
        },
        {
            "prompt": "What is 2+2?",
            "description": "Simple math",
            "max_tokens": 20
        }
    ]

    print("\n" + "="*70)
    print("Running Test Cases")
    print("="*70)

    total_tokens = 0
    total_time = 0

    for i, test_case in enumerate(test_cases, 1):
        print(f"\n{'โ”€'*70}")
        print(f"Test {i}/{len(test_cases)}: {test_case['description']}")
        print(f"{'โ”€'*70}")

        prompt = test_case["prompt"]
        print(f"Prompt: {prompt}")

        # Apply chat template if available
        if hasattr(tokenizer, 'apply_chat_template') and tokenizer.chat_template:
            messages = [{"role": "user", "content": prompt}]
            formatted_prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
        else:
            formatted_prompt = prompt

        # Tokenize
        inputs = tokenizer(formatted_prompt, return_tensors="pt")
        input_length = len(inputs['input_ids'][0])
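        # Prompt length in tokens; used below to count only the newly generated tokens.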

        # Generate
        print("\nGenerating...")
        start_time = time.time()

        outputs = model.generate(
            **inputs,
            max_new_tokens=test_case["max_tokens"],
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )

        generation_time = time.time() - start_time

        # Decode
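        # Note: outputs[0] still includes the prompt tokens, so the decoded
        # text shows the (templated) prompt followed by the model's reply.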
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        tokens_generated = len(outputs[0]) - input_length

        print(f"\nResponse: {response}")
        print(f"\nโฑ๏ธ  Time: {generation_time:.2f}s")
        print(f"๐Ÿ“Š Tokens: {tokens_generated}")
        print(f"๐Ÿš€ Speed: {tokens_generated/generation_time:.2f} tokens/sec")

        total_tokens += tokens_generated
        total_time += generation_time

    # Summary
    print("\n" + "="*70)
    print("Test Summary")
    print("="*70)
    print(f"โœ… All {len(test_cases)} tests completed successfully!")
    print(f"๐Ÿ“Š Total tokens generated: {total_tokens}")
    print(f"โฑ๏ธ  Total generation time: {total_time:.2f}s")
    print(f"๐Ÿš€ Average speed: {total_tokens/total_time:.2f} tokens/sec")
    print(f"๐Ÿ’พ Model size: ~527MB (FP32 ONNX)")
    print("="*70)

    return True

if __name__ == "__main__":
    try:
        # Check QNN support
        qnn_available = check_qnn_support()

        # Test model
        test_onnx_model()

        print("\nโœ… ONNX model is ready for deployment!")
        if not qnn_available:
            print("\n๐Ÿ“ To enable QNN acceleration:")
            print("   1. Install Qualcomm QNN SDK")
            print("   2. Build ONNX Runtime with QNN support")
            print("   3. Deploy to Qualcomm hardware (NPU)")

    except Exception as e:
        print(f"\nโŒ Error: {e}")
        import traceback
        traceback.print_exc()