package grpc

import (
	"context"
	"fmt"
	"log"
	"net"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
)
// A gRPC server that runs LLM inference.
// It is used by the LLM services to expose the model functionalities that are called by the client.
// The gRPC service is deliberately general, trying to encompass all the possible LLM options and models.
// What can actually be done depends on the concrete implementer.
//
// The server is implemented as a gRPC service with, among others, the following methods:
//   - Predict: run inference with the given options
//   - PredictStream: run inference with the given options and stream the results

// server implements pb.BackendServer by delegating every call to an LLM implementation.
type server struct {
	pb.UnimplementedBackendServer
	llm LLM
}
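// newReply is used by the handlers below but is not defined in this listing;
// it lives elsewhere in the package. A minimal sketch, assuming pb.Reply
// carries the generated text as a bytes Message field (an assumption, not
// shown in this file):
//
//	func newReply(s string) *pb.Reply {
//		return &pb.Reply{Message: []byte(s)}
//	}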
// Health reports that the backend is up.
func (s *server) Health(ctx context.Context, in *pb.HealthMessage) (*pb.Reply, error) {
	return newReply("OK"), nil
}
// Embedding computes embeddings for the given options.
func (s *server) Embedding(ctx context.Context, in *pb.PredictOptions) (*pb.EmbeddingResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	embeds, err := s.llm.Embeddings(in)
	if err != nil {
		return nil, err
	}
	return &pb.EmbeddingResult{Embeddings: embeds}, nil
}
// LoadModel asks the backend to load the model described by the given options.
func (s *server) LoadModel(ctx context.Context, in *pb.ModelOptions) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.Load(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error loading model: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Loading succeeded", Success: true}, nil
}
// Predict runs a single, non-streaming inference and returns the full result.
func (s *server) Predict(ctx context.Context, in *pb.PredictOptions) (*pb.Reply, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	result, err := s.llm.Predict(in)
	return newReply(result), err
}
// GenerateImage asks the backend to generate an image for the given request.
func (s *server) GenerateImage(ctx context.Context, in *pb.GenerateImageRequest) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.GenerateImage(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error generating image: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Image generated", Success: true}, nil
}
// TTS asks the backend to synthesize speech for the given request.
func (s *server) TTS(ctx context.Context, in *pb.TTSRequest) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.TTS(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "TTS audio generated", Success: true}, nil
}
// SoundGeneration asks the backend to generate audio for the given request.
func (s *server) SoundGeneration(ctx context.Context, in *pb.SoundGenerationRequest) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.SoundGeneration(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error generating audio: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Sound Generation audio generated", Success: true}, nil
}
// AudioTranscription transcribes audio and converts the backend's segments
// into the protobuf representation.
func (s *server) AudioTranscription(ctx context.Context, in *pb.TranscriptRequest) (*pb.TranscriptResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	result, err := s.llm.AudioTranscription(in)
	if err != nil {
		return nil, err
	}
	tresult := &pb.TranscriptResult{}
	for _, s := range result.Segments {
		tks := []int32{}
		for _, t := range s.Tokens {
			tks = append(tks, int32(t))
		}
		tresult.Segments = append(tresult.Segments,
			&pb.TranscriptSegment{
				Text:   s.Text,
				Id:     int32(s.Id),
				Start:  int64(s.Start),
				End:    int64(s.End),
				Tokens: tks,
			})
	}
	tresult.Text = result.Text
	return tresult, nil
}
// PredictStream runs inference and forwards each partial result to the gRPC stream.
// The backend writes chunks to resultChan; a goroutine drains the channel and sends
// each chunk to the client until the backend closes the channel.
func (s *server) PredictStream(in *pb.PredictOptions, stream pb.Backend_PredictStreamServer) error {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	resultChan := make(chan string)
	done := make(chan bool)
	go func() {
		for result := range resultChan {
			stream.Send(newReply(result))
		}
		done <- true
	}()
	err := s.llm.PredictStream(in, resultChan)
	<-done
	return err
}
// TokenizeString tokenizes the prompt in the given options and returns the token IDs.
func (s *server) TokenizeString(ctx context.Context, in *pb.PredictOptions) (*pb.TokenizationResponse, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.TokenizeString(in)
	if err != nil {
		return nil, err
	}
	castTokens := make([]int32, len(res.Tokens))
	for i, v := range res.Tokens {
		castTokens[i] = int32(v)
	}
	return &pb.TokenizationResponse{
		Length: int32(res.Length),
		Tokens: castTokens,
	}, err
}
// Status returns the backend's current status.
func (s *server) Status(ctx context.Context, in *pb.HealthMessage) (*pb.StatusResponse, error) {
	res, err := s.llm.Status()
	if err != nil {
		return nil, err
	}
	return &res, nil
}
// StoresSet stores the given entries in the backend's store.
func (s *server) StoresSet(ctx context.Context, in *pb.StoresSetOptions) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.StoresSet(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error setting entry: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Set key", Success: true}, nil
}
// StoresDelete removes the given entries from the backend's store.
func (s *server) StoresDelete(ctx context.Context, in *pb.StoresDeleteOptions) (*pb.Result, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	err := s.llm.StoresDelete(in)
	if err != nil {
		return &pb.Result{Message: fmt.Sprintf("Error deleting entry: %s", err.Error()), Success: false}, err
	}
	return &pb.Result{Message: "Deleted key", Success: true}, nil
}
// StoresGet retrieves entries for the given keys from the backend's store.
func (s *server) StoresGet(ctx context.Context, in *pb.StoresGetOptions) (*pb.StoresGetResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.StoresGet(in)
	if err != nil {
		return nil, err
	}
	return &res, nil
}
// StoresFind searches the backend's store with the given options.
func (s *server) StoresFind(ctx context.Context, in *pb.StoresFindOptions) (*pb.StoresFindResult, error) {
	if s.llm.Locking() {
		s.llm.Lock()
		defer s.llm.Unlock()
	}
	res, err := s.llm.StoresFind(in)
	if err != nil {
		return nil, err
	}
	return &res, nil
}
// StartServer starts a gRPC server for the given backend on the given address
// and blocks until the server stops.
func StartServer(address string, model LLM) error {
	lis, err := net.Listen("tcp", address)
	if err != nil {
		return err
	}
	s := grpc.NewServer()
	pb.RegisterBackendServer(s, &server{llm: model})
	log.Printf("gRPC Server listening at %v", lis.Addr())
	if err := s.Serve(lis); err != nil {
		return err
	}
	return nil
}
// RunServer is like StartServer, but additionally returns a function that can
// be used to shut the server down.
func RunServer(address string, model LLM) (func() error, error) {
	lis, err := net.Listen("tcp", address)
	if err != nil {
		return nil, err
	}
	s := grpc.NewServer()
	pb.RegisterBackendServer(s, &server{llm: model})
	log.Printf("gRPC Server listening at %v", lis.Addr())
	if err = s.Serve(lis); err != nil {
		return func() error {
			return lis.Close()
		}, err
	}
	return func() error {
		s.GracefulStop()
		return nil
	}, nil
}
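For context, here is a minimal sketch of how a client might talk to a backend once it has been started with StartServer. It is illustrative only: it assumes the proto-generated pb.NewBackendClient constructor, that pb.Reply exposes its payload as a bytes Message field, and a hypothetical address 127.0.0.1:50051; the PredictOptions fields are left empty because they are not shown in the listing above.

package main

import (
	"context"
	"fmt"
	"log"
	"time"

	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
)

func main() {
	// Dial a backend started elsewhere, e.g. StartServer("127.0.0.1:50051", model).
	conn, err := grpc.Dial("127.0.0.1:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
	if err != nil {
		log.Fatalf("dial: %v", err)
	}
	defer conn.Close()

	client := pb.NewBackendClient(conn)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Health maps to the Health handler above, which returns newReply("OK").
	reply, err := client.Health(ctx, &pb.HealthMessage{})
	if err != nil {
		log.Fatalf("health: %v", err)
	}
	// Assumption: Reply carries its payload as a bytes Message field.
	fmt.Println(string(reply.Message))

	// PredictStream maps to the streaming handler above: each chunk the backend
	// writes to its result channel arrives here as a separate Reply message.
	stream, err := client.PredictStream(ctx, &pb.PredictOptions{ /* prompt and sampling options go here */ })
	if err != nil {
		log.Fatalf("predict stream: %v", err)
	}
	for {
		chunk, err := stream.Recv()
		if err != nil {
			break // io.EOF once the server finishes streaming
		}
		fmt.Print(string(chunk.Message))
	}
}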