Spaces:
Runtime error
Runtime error
| /// Single shard Client | |
| use crate::pb::generate::v1::text_generation_service_client::TextGenerationServiceClient; | |
| use crate::pb::generate::v1::*; | |
| use crate::Result; | |
| use grpc_metadata::InjectTelemetryContext; | |
| use tonic::transport::{Channel, Uri}; | |
| use tracing::instrument; | |
| /// Text Generation Inference gRPC client | |
| pub struct Client { | |
| stub: TextGenerationServiceClient<Channel>, | |
| } | |
| impl Client { | |
| /// Returns a client connected to the given url | |
| pub async fn connect(uri: Uri) -> Result<Self> { | |
| let channel = Channel::builder(uri).connect().await?; | |
| Ok(Self { | |
| stub: TextGenerationServiceClient::new(channel), | |
| }) | |
| } | |
| /// Returns a client connected to the given unix socket | |
| pub async fn connect_uds(path: String) -> Result<Self> { | |
| let channel = Channel::from_shared("http://[::]:50051".to_string()) | |
| .unwrap() | |
| .connect_with_connector(tower::service_fn(move |_: Uri| { | |
| tokio::net::UnixStream::connect(path.clone()) | |
| })) | |
| .await?; | |
| Ok(Self { | |
| stub: TextGenerationServiceClient::new(channel), | |
| }) | |
| } | |
| /// Returns a list of uris or unix sockets of all shards | |
| pub async fn service_discovery(&mut self) -> Result<Vec<String>> { | |
| let request = tonic::Request::new(ServiceDiscoveryRequest {}).inject_context(); | |
| let response = self.stub.service_discovery(request).await?; | |
| let urls = response | |
| .into_inner() | |
| .urls | |
| .into_iter() | |
| // Remove unix socket prefix | |
| .map(|url| match url.strip_prefix("unix://") { | |
| None => url, | |
| Some(stripped_url) => stripped_url.to_string(), | |
| }) | |
| .collect(); | |
| Ok(urls) | |
| } | |
| /// Get model info | |
| pub async fn info(&mut self) -> Result<InfoResponse> { | |
| let request = tonic::Request::new(InfoRequest {}).inject_context(); | |
| let response = self.stub.info(request).await?.into_inner(); | |
| Ok(response) | |
| } | |
| /// Get model health | |
| pub async fn health(&mut self) -> Result<HealthResponse> { | |
| let request = tonic::Request::new(HealthRequest {}).inject_context(); | |
| let response = self.stub.health(request).await?.into_inner(); | |
| Ok(response) | |
| } | |
| /// Clear the past generations cache | |
| pub async fn clear_cache(&mut self, batch_id: Option<u64>) -> Result<()> { | |
| let request = tonic::Request::new(ClearCacheRequest { id: batch_id }).inject_context(); | |
| self.stub.clear_cache(request).await?; | |
| Ok(()) | |
| } | |
| /// Filter a cached batch | |
| pub async fn filter_batch( | |
| &mut self, | |
| batch_id: u64, | |
| keep_requests: Vec<Request>, | |
| ) -> Result<Option<Batch>> { | |
| let request = tonic::Request::new(FilterBatchRequest { | |
| batch_id, | |
| keep_requests, | |
| }) | |
| .inject_context(); | |
| let filtered_batch = self.stub.filter_batch(request).await?.into_inner(); | |
| Ok(filtered_batch.batch) | |
| } | |
| /// Generate one token for each request in the given batch | |
| /// | |
| /// Returns Generation for each request in batch | |
| /// and the next cached batch | |
| pub async fn prefill(&mut self, batch: Batch) -> Result<(Vec<Generation>, Option<Batch>)> { | |
| let request = tonic::Request::new(PrefillRequest { batch: Some(batch) }).inject_context(); | |
| let response = self.stub.prefill(request).await?.into_inner(); | |
| Ok((response.generations, response.batch)) | |
| } | |
| /// Generate one token for each request in the given cached batches | |
| /// | |
| /// Returns Generation for each request in batches | |
| /// and the next cached batch | |
| pub async fn decode( | |
| &mut self, | |
| batches: Vec<Batch>, | |
| ) -> Result<(Vec<Generation>, Option<Batch>)> { | |
| let request = tonic::Request::new(DecodeRequest { batches }).inject_context(); | |
| let response = self.stub.decode(request).await?.into_inner(); | |
| Ok((response.generations, response.batch)) | |
| } | |
| } | |