Spaces:
Runtime error
Runtime error
| use crate::validation::ValidationError::{BestOfSampling, BestOfSeed, EmptyInput}; | |
| /// Payload validation logic | |
| use crate::{GenerateParameters, GenerateRequest}; | |
| use rand::{thread_rng, Rng}; | |
| use text_generation_client::{NextTokenChooserParameters, StoppingCriteriaParameters}; | |
| use thiserror::Error; | |
| use tokenizers::tokenizer::Tokenizer; | |
| use tokenizers::TruncationDirection; | |
| use tokio::sync::oneshot; | |
| use tracing::{instrument, Span}; | |
/// Payload validation logic.
///
/// Holds the server-wide request limits and, when a fast tokenizer is
/// available, a channel to a pool of background tokenization workers.
pub struct Validation {
    /// Validation parameters: maximum value accepted for `best_of`
    max_best_of: usize,
    /// Maximum number of stop sequences a request may define
    max_stop_sequences: usize,
    /// Maximum number of input tokens
    max_input_length: usize,
    /// Maximum of input tokens + `max_new_tokens`
    max_total_tokens: usize,
    /// Channel to communicate with the background tokenization task;
    /// `None` when no tokenizer was provided (lengths are then estimated)
    sender: Option<flume::Sender<TokenizerRequest>>,
}
| impl Validation { | |
| pub(crate) fn new( | |
| workers: usize, | |
| tokenizer: Option<Tokenizer>, | |
| max_best_of: usize, | |
| max_stop_sequences: usize, | |
| max_input_length: usize, | |
| max_total_tokens: usize, | |
| ) -> Self { | |
| if max_input_length >= max_total_tokens { | |
| panic!("`max_input_length` must be < `max_total_tokens`"); | |
| } | |
| // If we have a fast tokenizer | |
| let sender = if let Some(tokenizer) = tokenizer { | |
| // Create channel | |
| let (validation_sender, validation_receiver) = flume::unbounded(); | |
| // Create workers | |
| for _ in 0..workers { | |
| let tokenizer_clone = tokenizer.clone(); | |
| let receiver_clone = validation_receiver.clone(); | |
| // Spawn worker | |
| tokio::task::spawn_blocking(move || { | |
| tokenizer_worker(tokenizer_clone, receiver_clone) | |
| }); | |
| } | |
| Some(validation_sender) | |
| } else { | |
| None | |
| }; | |
| Self { | |
| max_best_of, | |
| sender, | |
| max_stop_sequences, | |
| max_input_length, | |
| max_total_tokens, | |
| } | |
| } | |
    /// Compute the token length of `inputs` (truncating it if requested)
    /// and check it against the configured limits.
    ///
    /// Returns the (possibly truncated) input string together with its
    /// length in tokens. Without a tokenizer the length is only estimated
    /// from `truncate` / `max_input_length`; the python servers are
    /// expected to perform the real truncation.
    async fn validate_input(
        &self,
        inputs: String,
        truncate: Option<usize>,
        max_new_tokens: u32,
    ) -> Result<(String, usize), ValidationError> {
        // If we have a fast tokenizer
        if let Some(sender) = &self.sender {
            // Create response channel
            let (response_sender, response_receiver) = oneshot::channel();
            // Send request to the background validation task
            // Unwrap is safe here: the worker tasks hold receiver clones,
            // so the channel cannot be disconnected while `self` is alive
            sender
                .send(((inputs, truncate), response_sender, Span::current()))
                .unwrap();

            // Await on response channel
            // Unwrap is safe here: the worker always sends exactly one response
            let (inputs, input_length) = response_receiver.await.unwrap()?;

            // Get total tokens
            let total_tokens = input_length + max_new_tokens as usize;

            // Validate MaxTotalTokens
            if total_tokens > self.max_total_tokens {
                return Err(ValidationError::MaxTotalTokens(
                    self.max_total_tokens,
                    input_length,
                    max_new_tokens,
                ));
            }

            // Validate InputLength
            if input_length > self.max_input_length {
                return Err(ValidationError::InputLength(
                    self.max_input_length,
                    input_length,
                ));
            }

            metrics::histogram!("tgi_request_input_length", input_length as f64);

            Ok((inputs, input_length))
        }
        // Return inputs without validation
        else {
            // In this case, we don't know the real length in tokens of the inputs
            // However, the inputs will be truncated by the python servers
            // We make sure that truncate + max_new_tokens <= self.max_total_tokens
            let input_length = truncate.unwrap_or(self.max_input_length);

            // Validate MaxNewTokens
            if (input_length as u32 + max_new_tokens) > self.max_total_tokens as u32 {
                return Err(ValidationError::MaxNewTokens(
                    self.max_total_tokens - self.max_input_length,
                    max_new_tokens,
                ));
            }

            Ok((inputs, input_length))
        }
    }
    /// Validate a payload and get the number of tokens in the input.
    ///
    /// Checks every generation parameter against the configured limits,
    /// resolves defaults (temperature, penalties, seed, ...), tokenizes
    /// the input, and returns a request ready to be scheduled.
    pub(crate) async fn validate(
        &self,
        request: GenerateRequest,
    ) -> Result<ValidGenerateRequest, ValidationError> {
        let GenerateParameters {
            best_of,
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            max_new_tokens,
            stop: stop_sequences,
            truncate,
            seed,
            watermark,
            ..
        } = request.parameters;

        // sampling must be true when best_of > 1: any explicit sampling
        // parameter (or do_sample) counts as enabling sampling
        let best_of = best_of.unwrap_or(1);
        let sampling = do_sample
            || temperature.is_some()
            || top_k.is_some()
            || top_p.is_some()
            || typical_p.is_some();

        if best_of > 1 && !sampling {
            return Err(BestOfSampling);
        }

        let temperature = temperature.unwrap_or(1.0);
        if temperature <= 0.0 {
            return Err(ValidationError::Temperature);
        }

        let repetition_penalty = repetition_penalty.unwrap_or(1.0);
        if repetition_penalty <= 0.0 {
            return Err(ValidationError::RepetitionPenalty);
        }

        // Different because the proto default value is not a valid value
        // for the user: a user-supplied top_p must lie strictly inside
        // (0, 1), while the resolved default forwarded to the backend is 1.0
        let top_p = top_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TopP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        // Same pattern as top_p: strict open interval for user values
        let typical_p = typical_p
            .map(|value| {
                if value <= 0.0 || value >= 1.0 {
                    return Err(ValidationError::TypicalP);
                }
                Ok(value)
            })
            .unwrap_or(Ok(1.0))?;

        // 0 is the backend sentinel for "disabled"
        let top_k: u32 = top_k
            .map(|value| {
                if value <= 0 {
                    return Err(ValidationError::TopK);
                }
                Ok(value as u32)
            })
            .unwrap_or(Ok(0))?;

        if max_new_tokens == 0 {
            return Err(ValidationError::NegativeMaxNewTokens);
        }

        if stop_sequences.len() > self.max_stop_sequences {
            return Err(ValidationError::StopSequence(
                self.max_stop_sequences,
                stop_sequences.len(),
            ));
        }

        // If seed is None, assign a random one
        // (an explicit seed is rejected when best_of > 1: every candidate
        // sequence must sample differently)
        let seed = match seed {
            None => thread_rng().gen(),
            Some(seed) => {
                if best_of > 1 {
                    return Err(BestOfSeed);
                }
                seed
            }
        };

        // Check if inputs is empty
        if request.inputs.is_empty() {
            return Err(EmptyInput);
        }

        // Check if truncate is strictly positive and less than max_input_length
        let truncate = truncate
            .map(|value| {
                if value == 0 || value > self.max_input_length {
                    return Err(ValidationError::Truncate(self.max_input_length, value));
                }
                Ok(Some(value))
            })
            .unwrap_or(Ok(None))?;

        // Validate inputs: tokenize (and maybe truncate) to get the real length
        let (inputs, input_length) = self
            .validate_input(request.inputs, truncate, max_new_tokens)
            .await?;

        let parameters = NextTokenChooserParameters {
            temperature,
            repetition_penalty,
            top_k,
            top_p,
            typical_p,
            do_sample,
            seed,
            watermark,
        };
        let stopping_parameters = StoppingCriteriaParameters {
            max_new_tokens,
            stop_sequences,
            ignore_eos_token: false,
        };

        metrics::histogram!("tgi_request_max_new_tokens", max_new_tokens as f64);

        Ok(ValidGenerateRequest {
            inputs,
            input_length: input_length as u32,
            truncate: truncate.unwrap_or(self.max_input_length) as u32,
            parameters,
            stopping_parameters,
        })
    }
| /// Validate the best_of parameter | |
| pub(crate) fn validate_best_of(&self, best_of: usize) -> Result<usize, ValidationError> { | |
| if self.max_best_of == 1 && best_of != 1 { | |
| return Err(ValidationError::BestOfDisabled); | |
| } | |
| if best_of > self.max_best_of { | |
| return Err(ValidationError::BestOf(self.max_best_of, best_of)); | |
| } | |
| Ok(best_of) | |
| } | |
| } | |
| /// Start tokenization workers | |
| fn tokenizer_worker(tokenizer: Tokenizer, receiver: flume::Receiver<TokenizerRequest>) { | |
| // Loop over requests | |
| while let Ok(((inputs, truncate), response_tx, parent_span)) = receiver.recv() { | |
| parent_span.in_scope(|| { | |
| response_tx | |
| .send(prepare_input(inputs, truncate, &tokenizer)) | |
| .unwrap_or(()) | |
| }) | |
| } | |
| } | |
| /// Get input length and optionally truncate it | |
| fn prepare_input( | |
| inputs: String, | |
| truncate: Option<usize>, | |
| tokenizer: &Tokenizer, | |
| ) -> Result<(String, usize), ValidationError> { | |
| // Get the number of tokens in the input | |
| let mut encoding = tokenizer | |
| .encode(inputs.clone(), true) | |
| .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; | |
| // Optionally truncate | |
| let (inputs, input_length) = match truncate { | |
| // Truncate is some and < encoding length | |
| Some(truncate) if truncate < encoding.len() => { | |
| // truncate encoding and decode new inputs | |
| encoding.truncate(truncate, 0, TruncationDirection::Left); | |
| let inputs = tokenizer | |
| .decode(Vec::from(encoding.get_ids()), false) | |
| .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; | |
| (inputs, encoding.len()) | |
| } | |
| // Nothing to do | |
| _ => (inputs, encoding.len()), | |
| }; | |
| Ok((inputs, input_length)) | |
| } | |
/// Request sent to a tokenization worker:
/// ((input string, optional truncation), one-shot response channel,
/// tracing span of the originating request).
type TokenizerRequest = (
    (String, Option<usize>),
    oneshot::Sender<Result<(String, usize), ValidationError>>,
    Span,
);
/// A fully validated generation request, ready to be scheduled.
pub(crate) struct ValidGenerateRequest {
    /// Input text, possibly truncated by the tokenizer
    pub inputs: String,
    /// Number of tokens in `inputs`
    pub input_length: u32,
    /// Truncation limit forwarded to the backend
    pub truncate: u32,
    /// Sampling parameters for the next-token chooser
    pub parameters: NextTokenChooserParameters,
    /// Stopping criteria (max tokens, stop sequences)
    pub stopping_parameters: StoppingCriteriaParameters,
}
/// Errors returned when a payload fails validation.
///
/// NOTE(review): this enum previously implemented neither `Debug`,
/// `Display` nor `std::error::Error` (the `thiserror::Error` import at the
/// top of the file was unused), so it could not be logged or surfaced to
/// users. Implemented here with the standard library only.
#[derive(Debug)]
pub enum ValidationError {
    /// `best_of` above the configured maximum: (max, given)
    BestOf(usize, usize),
    BestOfDisabled,
    BestOfSampling,
    BestOfSeed,
    BestOfStream,
    Temperature,
    RepetitionPenalty,
    TopP,
    TopK,
    /// Invalid `truncate` value: (max, given)
    Truncate(usize, usize),
    TypicalP,
    NegativeMaxNewTokens,
    /// `max_new_tokens` too large: (max, given)
    MaxNewTokens(usize, u32),
    /// Input + new tokens over budget: (max total, input tokens, max_new_tokens)
    MaxTotalTokens(usize, usize, u32),
    /// Too many input tokens: (max, given)
    InputLength(usize, usize),
    EmptyInput,
    /// Too many stop sequences: (max, given)
    StopSequence(usize, usize),
    /// Underlying tokenizer failure
    Tokenizer(String),
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ValidationError::BestOf(max, given) => {
                write!(f, "`best_of` must be > 0 and <= {}. Given: {}", max, given)
            }
            ValidationError::BestOfDisabled => {
                write!(f, "`best_of` != 1 is not allowed for this endpoint")
            }
            ValidationError::BestOfSampling => {
                write!(f, "you must use sampling when `best_of` is > 1")
            }
            ValidationError::BestOfSeed => {
                write!(f, "`seed` must not be set when `best_of` > 1")
            }
            ValidationError::BestOfStream => {
                write!(f, "`best_of` != 1 is not supported when streaming tokens")
            }
            ValidationError::Temperature => {
                write!(f, "`temperature` must be strictly positive")
            }
            ValidationError::RepetitionPenalty => {
                write!(f, "`repetition_penalty` must be strictly positive")
            }
            ValidationError::TopP => write!(f, "`top_p` must be > 0.0 and < 1.0"),
            ValidationError::TopK => write!(f, "`top_k` must be strictly positive"),
            ValidationError::Truncate(max, given) => write!(
                f,
                "`truncate` must be strictly positive and less than {}. Given: {}",
                max, given
            ),
            ValidationError::TypicalP => write!(f, "`typical_p` must be > 0.0 and < 1.0"),
            ValidationError::NegativeMaxNewTokens => {
                write!(f, "`max_new_tokens` must be strictly positive")
            }
            ValidationError::MaxNewTokens(max, given) => {
                write!(f, "`max_new_tokens` must be <= {}. Given: {}", max, given)
            }
            ValidationError::MaxTotalTokens(max, input, new) => write!(
                f,
                "`inputs` tokens + `max_new_tokens` must be <= {}. Given: {} `inputs` tokens and {} `max_new_tokens`",
                max, input, new
            ),
            ValidationError::InputLength(max, given) => write!(
                f,
                "`inputs` must have less than {} tokens. Given: {}",
                max, given
            ),
            ValidationError::EmptyInput => write!(f, "`inputs` cannot be empty"),
            ValidationError::StopSequence(max, given) => write!(
                f,
                "`stop` supports up to {} stop sequences. Given: {}",
                max, given
            ),
            ValidationError::Tokenizer(err) => write!(f, "tokenizer error {}", err),
        }
    }
}

impl std::error::Error for ValidationError {}
/// Unit tests for the validation logic.
///
/// NOTE(review): this module was missing `#[cfg(test)]` and the async test
/// functions had no `#[tokio::test]` harness attribute, so they were
/// compiled into production builds and never executed. Restored here;
/// requires tokio's `macros` feature — confirm it is enabled for tests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::default_parameters;
    use crate::tests::get_tokenizer;

    /// Without a tokenizer, `validate_input` must reject requests whose
    /// estimated length + max_new_tokens exceeds the total-token budget.
    #[tokio::test]
    async fn test_validation_max_new_tokens() {
        let tokenizer = None;
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, max_new_tokens)
            .await
        {
            Err(ValidationError::MaxNewTokens(1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    /// With a tokenizer, the real token count + max_new_tokens must fit
    /// inside max_total_tokens.
    #[tokio::test]
    async fn test_validation_input_length() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_input_length,
            max_total_tokens,
        );

        let max_new_tokens = 10;
        match validation
            .validate_input("Hello".to_string(), None, max_new_tokens)
            .await
        {
            Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (),
            _ => panic!("Unexpected not max new tokens"),
        }
    }

    /// `best_of > 1` without sampling enabled must be rejected.
    #[tokio::test]
    async fn test_validation_best_of_sampling() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_input_length,
            max_total_tokens,
        );

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    best_of: Some(2),
                    do_sample: false,
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::BestOfSampling) => (),
            _ => panic!("Unexpected not best of sampling"),
        }
    }

    /// User-supplied top_p must lie strictly inside (0, 1); the resolved
    /// default, however, is exactly 1.0.
    #[tokio::test]
    async fn test_validation_top_p() {
        let tokenizer = Some(get_tokenizer().await);
        let max_best_of = 2;
        let max_stop_sequence = 3;
        let max_input_length = 4;
        let max_total_tokens = 5;
        let workers = 1;
        let validation = Validation::new(
            workers,
            tokenizer,
            max_best_of,
            max_stop_sequence,
            max_input_length,
            max_total_tokens,
        );

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(1.0),
                    ..default_parameters()
                },
            })
            .await
        {
            Err(ValidationError::TopP) => (),
            _ => panic!("Unexpected top_p"),
        }

        match validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: Some(0.99),
                    max_new_tokens: 1,
                    ..default_parameters()
                },
            })
            .await
        {
            Ok(_) => (),
            _ => panic!("Unexpected top_p error"),
        }

        let valid_request = validation
            .validate(GenerateRequest {
                inputs: "Hello".to_string(),
                parameters: GenerateParameters {
                    top_p: None,
                    max_new_tokens: 1,
                    ..default_parameters()
                },
            })
            .await
            .unwrap();

        // top_p == 1.0 is invalid for users to ask for but it's the default resolved value.
        assert_eq!(valid_request.parameters.top_p, 1.0);
    }
}