# src/AI_Models/wave2vec_inference.py
# import torch
# from transformers import (
# AutoModelForCTC,
# AutoProcessor,
# Wav2Vec2Processor,
# Wav2Vec2ForCTC,
# )
# import onnxruntime as rt
# import numpy as np
# import librosa
# import warnings
# import os
# warnings.filterwarnings("ignore")
# # Available Wave2Vec2 models
# WAVE2VEC2_MODELS = {
# "english_large": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
# "multilingual": "facebook/wav2vec2-large-xlsr-53",
# "english_960h": "facebook/wav2vec2-large-960h-lv60-self",
# "base_english": "facebook/wav2vec2-base-960h",
# "large_english": "facebook/wav2vec2-large-960h",
# "xlsr_english": "jonatasgrosman/wav2vec2-large-xlsr-53-english",
# "xlsr_multilingual": "facebook/wav2vec2-large-xlsr-53"
# }
# # Default model
# DEFAULT_MODEL = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
# def get_available_models():
# """Return dictionary of available Wave2Vec2 models"""
# return WAVE2VEC2_MODELS.copy()
# def get_model_name(model_key=None):
# """
# Get model name from key or return default
# Args:
# model_key: Key from WAVE2VEC2_MODELS or full model name
# Returns:
# str: Full model name
# """
# if model_key is None:
# return DEFAULT_MODEL
# if model_key in WAVE2VEC2_MODELS:
# return WAVE2VEC2_MODELS[model_key]
# # If it's already a full model name, return as is
# return model_key
# class Wave2Vec2Inference:
# def __init__(self, model_name=None, use_gpu=True):
# # Get the actual model name using helper function
# self.model_name = get_model_name(model_name)
# # Auto-detect device
# if use_gpu:
# if torch.backends.mps.is_available():
# self.device = "mps"
# elif torch.cuda.is_available():
# self.device = "cuda"
# else:
# self.device = "cpu"
# else:
# self.device = "cpu"
# print(f"Using device: {self.device}")
# print(f"Loading model: {self.model_name}")
# # Check if model is XLSR and use appropriate processor/model
# is_xlsr = "xlsr" in self.model_name.lower()
# if is_xlsr:
# print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
# self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
# self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
# else:
# print("Using AutoProcessor and AutoModelForCTC")
# self.processor = AutoProcessor.from_pretrained(self.model_name)
# self.model = AutoModelForCTC.from_pretrained(self.model_name)
# self.model.to(self.device)
# self.model.eval()
# # Disable gradients for inference
# torch.set_grad_enabled(False)
# def buffer_to_text(self, audio_buffer):
# if len(audio_buffer) == 0:
# return ""
# # Convert to tensor
# if isinstance(audio_buffer, np.ndarray):
# audio_tensor = torch.from_numpy(audio_buffer).float()
# else:
# audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
# # Process audio
# inputs = self.processor(
# audio_tensor,
# sampling_rate=16_000,
# return_tensors="pt",
# padding=True,
# )
# # Move to device
# input_values = inputs.input_values.to(self.device)
# attention_mask = (
# inputs.attention_mask.to(self.device)
# if "attention_mask" in inputs
# else None
# )
# # Inference
# with torch.no_grad():
# if attention_mask is not None:
# logits = self.model(input_values, attention_mask=attention_mask).logits
# else:
# logits = self.model(input_values).logits
# # Decode
# predicted_ids = torch.argmax(logits, dim=-1)
# if self.device != "cpu":
# predicted_ids = predicted_ids.cpu()
# transcription = self.processor.batch_decode(predicted_ids)[0]
# return transcription.lower().strip()
# def file_to_text(self, filename):
# try:
# audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
# return self.buffer_to_text(audio_input)
# except Exception as e:
# print(f"Error loading audio file {filename}: {e}")
# return ""
# class Wave2Vec2ONNXInference:
# def __init__(self, model_name=None, onnx_path=None, use_gpu=True):
# # Get the actual model name using helper function
# self.model_name = get_model_name(model_name)
# print(f"Loading ONNX model: {self.model_name}")
# # Always use Wav2Vec2Processor for ONNX (works for all models)
# self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
# # Setup ONNX Runtime
# options = rt.SessionOptions()
# options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
# # Choose providers based on GPU availability
# providers = []
# if use_gpu and rt.get_available_providers():
# if "CUDAExecutionProvider" in rt.get_available_providers():
# providers.append("CUDAExecutionProvider")
# providers.append("CPUExecutionProvider")
# self.model = rt.InferenceSession(onnx_path, options, providers=providers)
# self.input_name = self.model.get_inputs()[0].name
# print(f"ONNX model loaded with providers: {self.model.get_providers()}")
# def buffer_to_text(self, audio_buffer):
# if len(audio_buffer) == 0:
# return ""
# # Convert to tensor
# if isinstance(audio_buffer, np.ndarray):
# audio_tensor = torch.from_numpy(audio_buffer).float()
# else:
# audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
# # Process audio
# inputs = self.processor(
# audio_tensor,
# sampling_rate=16_000,
# return_tensors="np",
# padding=True,
# )
# # ONNX inference
# input_values = inputs.input_values.astype(np.float32)
# onnx_outputs = self.model.run(None, {self.input_name: input_values})[0]
# # Decode
# prediction = np.argmax(onnx_outputs, axis=-1)
# transcription = self.processor.decode(prediction.squeeze().tolist())
# return transcription.lower().strip()
# def file_to_text(self, filename):
# try:
# audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
# return self.buffer_to_text(audio_input)
# except Exception as e:
# print(f"Error loading audio file {filename}: {e}")
# return ""
# def convert_to_onnx(model_id_or_path, onnx_model_name):
# """Convert PyTorch model to ONNX format"""
# print(f"Converting {model_id_or_path} to ONNX...")
# model = Wav2Vec2ForCTC.from_pretrained(model_id_or_path)
# model.eval()
# # Create dummy input
# audio_len = 250000
# dummy_input = torch.randn(1, audio_len, requires_grad=True)
# torch.onnx.export(
# model,
# dummy_input,
# onnx_model_name,
# export_params=True,
# opset_version=14,
# do_constant_folding=True,
# input_names=["input"],
# output_names=["output"],
# dynamic_axes={
# "input": {1: "audio_len"},
# "output": {1: "audio_len"},
# },
# )
# print(f"ONNX model saved to: {onnx_model_name}")
# def quantize_onnx_model(onnx_model_path, quantized_model_path):
# """Quantize ONNX model for faster inference"""
# print("Starting quantization...")
# from onnxruntime.quantization import quantize_dynamic, QuantType
# quantize_dynamic(
# onnx_model_path, quantized_model_path, weight_type=QuantType.QUInt8
# )
# print(f"Quantized model saved to: {quantized_model_path}")
# def export_to_onnx(model_name, quantize=False):
# """
# Export model to ONNX format with optional quantization
# Args:
# model_name: HuggingFace model name
# quantize: Whether to also create quantized version
# Returns:
# tuple: (onnx_path, quantized_path or None)
# """
# onnx_filename = f"{model_name.split('/')[-1]}.onnx"
# convert_to_onnx(model_name, onnx_filename)
# quantized_path = None
# if quantize:
# quantized_path = onnx_filename.replace(".onnx", ".quantized.onnx")
# quantize_onnx_model(onnx_filename, quantized_path)
# return onnx_filename, quantized_path
# def create_inference(
# model_name=None, use_onnx=False, onnx_path=None, use_gpu=True, use_onnx_quantize=False
# ):
# """
# Create optimized inference instance
# Args:
# model_name: Model key from WAVE2VEC2_MODELS or full HuggingFace model name (default: uses DEFAULT_MODEL)
# use_onnx: Whether to use ONNX runtime
# onnx_path: Path to ONNX model file
# use_gpu: Whether to use GPU if available
# use_onnx_quantize: Whether to use quantized ONNX model
# Returns:
# Inference instance
# """
# # Get the actual model name
# actual_model_name = get_model_name(model_name)
# if use_onnx:
# if not onnx_path or not os.path.exists(onnx_path):
# # Convert to ONNX if path not provided or doesn't exist
# onnx_filename = f"{actual_model_name.split('/')[-1]}.onnx"
# convert_to_onnx(actual_model_name, onnx_filename)
# onnx_path = onnx_filename
# if use_onnx_quantize:
# quantized_path = onnx_path.replace(".onnx", ".quantized.onnx")
# if not os.path.exists(quantized_path):
# quantize_onnx_model(onnx_path, quantized_path)
# onnx_path = quantized_path
# print(f"Using ONNX model: {onnx_path}")
# return Wave2Vec2ONNXInference(model_name, onnx_path, use_gpu)
# else:
# print("Using PyTorch model")
# return Wave2Vec2Inference(model_name, use_gpu)
# if __name__ == "__main__":
# import time
# # Display available models
# print("Available Wave2Vec2 models:")
# for key, model_name in get_available_models().items():
# print(f" {key}: {model_name}")
# print(f"\nDefault model: {DEFAULT_MODEL}")
# print()
# # Test with different models
# test_models = ["english_large", "multilingual", "english_960h"]
# test_file = "test.wav"
# if not os.path.exists(test_file):
# print(f"Test file {test_file} not found. Please provide a valid audio file.")
# print("Creating example usage without actual file...")
# # Example usage without file
# print("\n=== Example Usage ===")
# # Using default model
# print("1. Using default model:")
# asr_default = create_inference()
# print(f" Model loaded: {asr_default.model_name}")
# # Using model key
# print("\n2. Using model key 'english_large':")
# asr_key = create_inference("english_large")
# print(f" Model loaded: {asr_key.model_name}")
# # Using full model name
# print("\n3. Using full model name:")
# asr_full = create_inference("facebook/wav2vec2-base-960h")
# print(f" Model loaded: {asr_full.model_name}")
# exit(0)
# # Test different model configurations
# for model_key in test_models:
# print(f"\n=== Testing model: {model_key} ===")
# # Test different configurations
# configs = [
# {"use_onnx": False, "use_gpu": True},
# {"use_onnx": True, "use_gpu": True, "use_onnx_quantize": False},
# ]
# for config in configs:
# print(f"\nConfig: {config}")
# # Create inference instance with model selection
# asr = create_inference(model_key, **config)
# # Warm up
# asr.file_to_text(test_file)
# # Test performance
# times = []
# for i in range(3):
# start_time = time.time()
# text = asr.file_to_text(test_file)
# end_time = time.time()
# execution_time = end_time - start_time
# times.append(execution_time)
# print(f"Run {i+1}: {execution_time:.3f}s - {text[:50]}...")
# avg_time = sum(times) / len(times)
# print(f"Average time: {avg_time:.3f}s")
from typing import List, Optional, Union

import deepspeed
import librosa
import numpy as np
import torch
from transformers import (
    AutoModelForCTC,
    AutoProcessor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)
def get_model_name(model_name: Optional[str] = None) -> str:
"""Helper function to get model name with default fallback"""
if model_name is None:
return "facebook/wav2vec2-large-robust-ft-libri-960h"
return model_name
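# Illustrative note: get_model_name() falls back to the default checkpoint above,
# while an explicit name such as "jonatasgrosman/wav2vec2-large-xlsr-53-english"
# (also listed in the legacy block) is returned unchanged and later selects the
# XLSR-specific processor/model classes in Wave2Vec2Inference.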
class Wave2Vec2Inference:
def __init__(
self,
model_name: Optional[str] = None,
use_gpu: bool = True,
use_deepspeed: bool = True,
):
"""
Initialize Wav2Vec2 model for inference with optional DeepSpeed optimization.
Args:
model_name: HuggingFace model name or None for default
use_gpu: Whether to use GPU acceleration
use_deepspeed: Whether to use DeepSpeed optimization
"""
# Get the actual model name using helper function
self.model_name = get_model_name(model_name)
self.use_deepspeed = use_deepspeed
# Auto-detect device
if use_gpu:
if torch.backends.mps.is_available():
self.device = "mps"
elif torch.cuda.is_available():
self.device = "cuda"
else:
self.device = "cpu"
else:
self.device = "cpu"
print(f"Using device: {self.device}")
print(f"Loading model: {self.model_name}")
print(f"DeepSpeed enabled: {self.use_deepspeed}")
# Check if model is XLSR and use appropriate processor/model
is_xlsr = "xlsr" in self.model_name.lower()
if is_xlsr:
print("Using Wav2Vec2Processor and Wav2Vec2ForCTC for XLSR model")
self.processor = Wav2Vec2Processor.from_pretrained(self.model_name)
self.model = Wav2Vec2ForCTC.from_pretrained(self.model_name)
else:
print("Using AutoProcessor and AutoModelForCTC")
self.processor = AutoProcessor.from_pretrained(self.model_name)
self.model = AutoModelForCTC.from_pretrained(self.model_name)
        # Put the model in eval mode before (optionally) wrapping it with DeepSpeed,
        # so dropout is disabled on every inference path
        self.model.eval()
        # Initialize DeepSpeed if enabled
        if self.use_deepspeed:
            self._init_deepspeed()
        else:
            self.model.to(self.device)
            self.ds_engine = None
# Disable gradients for inference
torch.set_grad_enabled(False)
def _init_deepspeed(self):
"""Initialize DeepSpeed inference engine"""
try:
            # DeepSpeed configuration based on device: kernel injection is only
            # supported on CUDA, so disable it on CPU/MPS
            ds_config = {
                "tensor_parallel": {"tp_size": 1},
                "dtype": torch.float32,
                "replace_with_kernel_inject": self.device == "cuda",
                "enable_cuda_graph": False,
            }
print("Initializing DeepSpeed inference engine...")
self.ds_engine = deepspeed.init_inference(self.model, **ds_config)
self.ds_engine.module.to(self.device)
except Exception as e:
print(f"DeepSpeed initialization failed: {e}")
print("Falling back to standard PyTorch inference...")
self.use_deepspeed = False
self.ds_engine = None
self.model.to(self.device)
self.model.eval()
def _get_model(self):
"""Get the appropriate model for inference"""
if self.use_deepspeed and self.ds_engine is not None:
return self.ds_engine.module
return self.model
def buffer_to_text(
self, audio_buffer: Union[np.ndarray, torch.Tensor, List]
) -> str:
"""
Convert audio buffer to text transcription.
Args:
audio_buffer: Audio data as numpy array, tensor, or list
Returns:
str: Transcribed text
"""
if len(audio_buffer) == 0:
return ""
# Convert to tensor
if isinstance(audio_buffer, np.ndarray):
audio_tensor = torch.from_numpy(audio_buffer).float()
elif isinstance(audio_buffer, list):
audio_tensor = torch.tensor(audio_buffer, dtype=torch.float32)
else:
audio_tensor = audio_buffer.float()
# Process audio
inputs = self.processor(
audio_tensor,
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
# Move to device
input_values = inputs.input_values.to(self.device)
attention_mask = (
inputs.attention_mask.to(self.device)
if "attention_mask" in inputs
else None
)
# Get the appropriate model
model = self._get_model()
# Inference
with torch.no_grad():
if attention_mask is not None:
outputs = model(input_values, attention_mask=attention_mask)
else:
outputs = model(input_values)
# Handle different output formats
if hasattr(outputs, "logits"):
logits = outputs.logits
else:
logits = outputs
# Decode
predicted_ids = torch.argmax(logits, dim=-1)
if self.device != "cpu":
predicted_ids = predicted_ids.cpu()
transcription = self.processor.batch_decode(predicted_ids)[0]
return transcription.lower().strip()
def file_to_text(self, filename: str) -> str:
"""
Transcribe audio file to text.
Args:
filename: Path to audio file
Returns:
str: Transcribed text
"""
try:
audio_input, _ = librosa.load(filename, sr=16000, dtype=np.float32)
return self.buffer_to_text(audio_input)
except Exception as e:
print(f"Error loading audio file {filename}: {e}")
return ""
def batch_file_to_text(self, filenames: List[str]) -> List[str]:
"""
Transcribe multiple audio files to text.
Args:
filenames: List of audio file paths
Returns:
List[str]: List of transcribed texts
"""
results = []
for i, filename in enumerate(filenames):
print(f"Processing file {i+1}/{len(filenames)}: {filename}")
transcription = self.file_to_text(filename)
results.append(transcription)
if transcription:
print(f"Transcription: {transcription}")
else:
print("Failed to transcribe")
return results
def transcribe_with_confidence(
self, audio_buffer: Union[np.ndarray, torch.Tensor]
) -> tuple:
"""
Transcribe audio and return confidence scores.
Args:
audio_buffer: Audio data
Returns:
tuple: (transcription, confidence_scores)
"""
if len(audio_buffer) == 0:
return "", []
# Convert to tensor
if isinstance(audio_buffer, np.ndarray):
audio_tensor = torch.from_numpy(audio_buffer).float()
else:
audio_tensor = audio_buffer.float()
# Process audio
inputs = self.processor(
audio_tensor,
sampling_rate=16_000,
return_tensors="pt",
padding=True,
)
input_values = inputs.input_values.to(self.device)
attention_mask = (
inputs.attention_mask.to(self.device)
if "attention_mask" in inputs
else None
)
model = self._get_model()
# Inference
with torch.no_grad():
if attention_mask is not None:
outputs = model(input_values, attention_mask=attention_mask)
else:
outputs = model(input_values)
if hasattr(outputs, "logits"):
logits = outputs.logits
else:
logits = outputs
# Get probabilities and confidence scores
probs = torch.nn.functional.softmax(logits, dim=-1)
predicted_ids = torch.argmax(logits, dim=-1)
        # Confidence per output frame = max softmax probability of the predicted token;
        # batch size is 1 here, so take the first row to return a flat list of floats
        max_probs = torch.max(probs, dim=-1)[0]
        confidence_scores = max_probs[0].cpu().numpy().tolist()
if self.device != "cpu":
predicted_ids = predicted_ids.cpu()
transcription = self.processor.batch_decode(predicted_ids)[0]
return transcription.lower().strip(), confidence_scores
def cleanup(self):
"""Clean up resources"""
if hasattr(self, "ds_engine") and self.ds_engine is not None:
del self.ds_engine
if hasattr(self, "model"):
del self.model
if hasattr(self, "processor"):
del self.processor
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    def __del__(self):
        """Destructor to clean up resources"""
        try:
            self.cleanup()
        except Exception:
            # Attributes or torch itself may already be gone during interpreter
            # shutdown; never raise from a destructor
            pass
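# Minimal usage sketch, kept under a __main__ guard so importing this module stays
# side-effect free. "test.wav" is a placeholder path (the same name the legacy test
# block above used); DeepSpeed is turned off for the demo so it does not require a
# working DeepSpeed inference engine.
if __name__ == "__main__":
    import os

    asr = Wave2Vec2Inference(use_gpu=True, use_deepspeed=False)
    sample_file = "test.wav"  # placeholder: point this at a real audio file
    if os.path.exists(sample_file):
        # Plain transcription from a file on disk
        print(asr.file_to_text(sample_file))
        # Transcription plus per-frame confidence scores from a raw audio buffer
        audio, _ = librosa.load(sample_file, sr=16000, dtype=np.float32)
        text, confidences = asr.transcribe_with_confidence(audio)
        print(f"{text} (mean confidence: {sum(confidences) / len(confidences):.3f})")
    else:
        print(f"Sample file {sample_file} not found; skipping demo transcription.")
    asr.cleanup()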