Spaces:
Runtime error
Runtime error
| use float_eq::assert_float_eq; | |
| use serde::Deserialize; | |
| use serde_json::Value; | |
| use std::fs::File; | |
| use std::io::{BufRead, BufReader}; | |
| use std::path::PathBuf; | |
| use std::thread; | |
| use std::thread::sleep; | |
| use std::time::Duration; | |
| use subprocess::{Popen, PopenConfig, Redirection}; | |
| pub struct Token { | |
| id: u32, | |
| text: String, | |
| logprob: Option<f32>, | |
| special: bool, | |
| } | |
| struct Details { | |
| finish_reason: String, | |
| generated_tokens: u32, | |
| tokens: Vec<Token>, | |
| } | |
| struct GeneratedText { | |
| generated_text: String, | |
| details: Details, | |
| } | |
| fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen { | |
| let argv = vec![ | |
| "text-generation-launcher".to_string(), | |
| "--model-id".to_string(), | |
| model_id.clone(), | |
| "--num-shard".to_string(), | |
| num_shard.to_string(), | |
| "--port".to_string(), | |
| port.to_string(), | |
| "--master-port".to_string(), | |
| master_port.to_string(), | |
| "--shard-uds-path".to_string(), | |
| format!("/tmp/test-{}-{}-{}", num_shard, port, master_port), | |
| ]; | |
| let mut launcher = Popen::create( | |
| &argv, | |
| PopenConfig { | |
| stdout: Redirection::Pipe, | |
| stderr: Redirection::Merge, | |
| ..Default::default() | |
| }, | |
| ) | |
| .expect("Could not start launcher"); | |
| // Redirect STDOUT and STDERR to the console | |
| // (STDERR is merged into STDOUT) | |
| let launcher_stdout = launcher.stdout.take().unwrap(); | |
| thread::spawn(move || { | |
| let stdout = BufReader::new(launcher_stdout); | |
| for line in stdout.lines() { | |
| println!("{}", line.unwrap()); | |
| } | |
| }); | |
| for _ in 0..60 { | |
| let health = reqwest::blocking::get(format!("http://localhost:{}/health", port)); | |
| if health.is_ok() { | |
| return launcher; | |
| } | |
| sleep(Duration::from_secs(2)); | |
| } | |
| launcher.terminate().unwrap(); | |
| launcher.wait().unwrap(); | |
| panic!("failed to launch {}", model_id) | |
| } | |
| fn test_model( | |
| model_id: String, | |
| num_shard: usize, | |
| port: usize, | |
| master_port: usize, | |
| ) -> GeneratedText { | |
| let mut launcher = start_launcher(model_id, num_shard, port, master_port); | |
| let data = r#" | |
| { | |
| "inputs": "Test request", | |
| "parameters": { | |
| "details": true | |
| } | |
| }"#; | |
| let req: Value = serde_json::from_str(data).unwrap(); | |
| let client = reqwest::blocking::Client::new(); | |
| let res = client | |
| .post(format!("http://localhost:{}/generate", port)) | |
| .json(&req) | |
| .send(); | |
| launcher.terminate().unwrap(); | |
| launcher.wait().unwrap(); | |
| let result: GeneratedText = res.unwrap().json().unwrap(); | |
| result | |
| } | |
| fn read_json(name: &str) -> GeneratedText { | |
| let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR")); | |
| d.push("tests/"); | |
| d.push(name); | |
| let file = File::open(d).unwrap(); | |
| let reader = BufReader::new(file); | |
| let result: GeneratedText = serde_json::from_reader(reader).unwrap(); | |
| result | |
| } | |
| fn compare_results(result: GeneratedText, expected: GeneratedText) { | |
| assert_eq!(result.generated_text, expected.generated_text); | |
| assert_eq!(result.details.finish_reason, expected.details.finish_reason); | |
| assert_eq!( | |
| result.details.generated_tokens, | |
| expected.details.generated_tokens | |
| ); | |
| for (token, expected_token) in result | |
| .details | |
| .tokens | |
| .into_iter() | |
| .zip(expected.details.tokens.into_iter()) | |
| { | |
| assert_eq!(token.id, expected_token.id); | |
| assert_eq!(token.text, expected_token.text); | |
| assert_eq!(token.special, expected_token.special); | |
| if let Some(logprob) = token.logprob { | |
| let expected_logprob = expected_token.logprob.unwrap(); | |
| assert_float_eq!(logprob, expected_logprob, abs <= 0.001); | |
| } else { | |
| assert_eq!(token.logprob, expected_token.logprob); | |
| } | |
| } | |
| } | |
/// Runs `bigscience/bloom-560m` with one shard (port 3000, master port 29500)
/// and compares the output with the `bloom_560m.json` fixture.
#[test]
fn test_bloom_560m() {
    let expected = read_json("bloom_560m.json");
    let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
    compare_results(result, expected);
}
/// Runs `bigscience/bloom-560m` with two shards (port 3001, master port 29501)
/// and compares the output with the same `bloom_560m.json` fixture used by the
/// single-shard test.
#[test]
fn test_bloom_560m_distributed() {
    let expected = read_json("bloom_560m.json");
    let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
    compare_results(result, expected);
}
/// Runs `bigscience/mt0-base` with one shard (port 3002, master port 29502)
/// and compares the output with the `mt0_base.json` fixture.
#[test]
fn test_mt0_base() {
    let expected = read_json("mt0_base.json");
    let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
    compare_results(result, expected);
}