```go
package config

import (
	"os"
	"regexp"
	"slices"
	"strings"

	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/functions"
	"gopkg.in/yaml.v3"
)

const (
	RAND_SEED = -1
)

type TTSConfig struct {
	// Voice wav path or id
	Voice string `yaml:"voice"`

	// Vall-e-x
	VallE VallE `yaml:"vall-e"`
}

type BackendConfig struct {
	schema.PredictionOptions `yaml:"parameters"`
	Name string `yaml:"name"`

	F16 *bool `yaml:"f16"`
	Threads *int `yaml:"threads"`
	Debug *bool `yaml:"debug"`
	Roles map[string]string `yaml:"roles"`
	Embeddings *bool `yaml:"embeddings"`
	Backend string `yaml:"backend"`

	TemplateConfig TemplateConfig `yaml:"template"`

	KnownUsecaseStrings []string `yaml:"known_usecases"`
	KnownUsecases *BackendConfigUsecases `yaml:"-"`

	PromptStrings, InputStrings []string `yaml:"-"`
	InputToken [][]int `yaml:"-"`
	functionCallString, functionCallNameString string `yaml:"-"`
	ResponseFormat string `yaml:"-"`
	ResponseFormatMap map[string]interface{} `yaml:"-"`

	FunctionsConfig functions.FunctionsConfig `yaml:"function"`

	FeatureFlag FeatureFlag `yaml:"feature_flags"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.

	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline"`

	// AutoGPTQ specifics
	AutoGPTQ AutoGPTQ `yaml:"autogptq"`

	// Diffusers
	Diffusers Diffusers `yaml:"diffusers"`
	Step int `yaml:"step"`

	// GRPC Options
	GRPC GRPC `yaml:"grpc"`

	// TTS specifics
	TTSConfig `yaml:"tts"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
	CUDA bool `yaml:"cuda"`

	DownloadFiles []File `yaml:"download_files"`

	Description string `yaml:"description"`
	Usage string `yaml:"usage"`
}
```
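To make the field tags above concrete, here is a minimal sketch of how a model YAML file maps onto them. It mirrors only a handful of the tags (`name`, `backend`, `f16`, `context_size`, `template.chat`) into a local `demoConfig` struct so the snippet stays self-contained; the struct name and the sample values are illustrative, not part of LocalAI.

```go
package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// demoConfig mirrors a small subset of BackendConfig's YAML tags,
// for illustration only.
type demoConfig struct {
	Name string `yaml:"name"`
	Backend string `yaml:"backend"`
	F16 *bool `yaml:"f16"`
	ContextSize *int `yaml:"context_size"` // inlined from LLMConfig, hence top-level
	Template struct {
		Chat string `yaml:"chat"`
	} `yaml:"template"`
}

func main() {
	// Sample model definition using keys taken from the struct tags above;
	// the values themselves are made up.
	doc := []byte(`
name: my-model
backend: llama-cpp
f16: true
context_size: 4096
template:
  chat: my-chat-template
`)

	var cfg demoConfig
	if err := yaml.Unmarshal(doc, &cfg); err != nil {
		panic(err)
	}

	// Pointer fields stay nil when their key is absent, which is how
	// SetDefaults later distinguishes "unset" from an explicit zero value.
	fmt.Println(cfg.Name, cfg.Backend, *cfg.F16, *cfg.ContextSize, cfg.Template.Chat)
}
```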
```go
type File struct {
	Filename string `yaml:"filename" json:"filename"`
	SHA256 string `yaml:"sha256" json:"sha256"`
	URI downloader.URI `yaml:"uri" json:"uri"`
}

type VallE struct {
	AudioPath string `yaml:"audio_path"`
}

type FeatureFlag map[string]*bool

func (ff FeatureFlag) Enabled(s string) bool {
	v, exist := ff[s]
	return exist && v != nil && *v
}
```
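Because the map values are `*bool`, `Enabled` distinguishes a flag that is absent, one that is present but unset (nil), one that is explicitly disabled, and one that is enabled. A self-contained sketch of that behaviour, with made-up flag names:

```go
package main

import "fmt"

type FeatureFlag map[string]*bool

func (ff FeatureFlag) Enabled(s string) bool {
	v, exist := ff[s]
	return exist && v != nil && *v
}

func main() {
	t, f := true, false
	ff := FeatureFlag{
		"experimental_thing": &t,  // explicitly enabled
		"legacy_thing": &f,        // explicitly disabled
		"unset_thing": nil,        // present but unset
	}

	fmt.Println(ff.Enabled("experimental_thing")) // true
	fmt.Println(ff.Enabled("legacy_thing"))       // false
	fmt.Println(ff.Enabled("unset_thing"))        // false (nil pointer)
	fmt.Println(ff.Enabled("missing_thing"))      // false (key absent)
}
```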
```go
type GRPC struct {
	Attempts int `yaml:"attempts"`
	AttemptsSleepTime int `yaml:"attempts_sleep_time"`
}

type Diffusers struct {
	CUDA bool `yaml:"cuda"`
	PipelineType string `yaml:"pipeline_type"`
	SchedulerType string `yaml:"scheduler_type"`
	EnableParameters string `yaml:"enable_parameters"` // A list of comma-separated parameters to specify
	CFGScale float32 `yaml:"cfg_scale"` // Classifier-Free Guidance Scale
	IMG2IMG bool `yaml:"img2img"` // Image to Image Diffuser
	ClipSkip int `yaml:"clip_skip"` // Number of CLIP layers to skip
	ClipModel string `yaml:"clip_model"` // Clip model to use
	ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
	ControlNet string `yaml:"control_net"`
}

// LLMConfig is a struct that holds the configuration that is
// generic to most of the LLM backends.
type LLMConfig struct {
	SystemPrompt string `yaml:"system_prompt"`
	TensorSplit string `yaml:"tensor_split"`
	MainGPU string `yaml:"main_gpu"`
	RMSNormEps float32 `yaml:"rms_norm_eps"`
	NGQA int32 `yaml:"ngqa"`
	PromptCachePath string `yaml:"prompt_cache_path"`
	PromptCacheAll bool `yaml:"prompt_cache_all"`
	PromptCacheRO bool `yaml:"prompt_cache_ro"`
	MirostatETA *float64 `yaml:"mirostat_eta"`
	MirostatTAU *float64 `yaml:"mirostat_tau"`
	Mirostat *int `yaml:"mirostat"`
	NGPULayers *int `yaml:"gpu_layers"`
	MMap *bool `yaml:"mmap"`
	MMlock *bool `yaml:"mmlock"`
	LowVRAM *bool `yaml:"low_vram"`
	Grammar string `yaml:"grammar"`
	StopWords []string `yaml:"stopwords"`
	Cutstrings []string `yaml:"cutstrings"`
	ExtractRegex []string `yaml:"extract_regex"`
	TrimSpace []string `yaml:"trimspace"`
	TrimSuffix []string `yaml:"trimsuffix"`

	ContextSize *int `yaml:"context_size"`
	NUMA bool `yaml:"numa"`
	LoraAdapter string `yaml:"lora_adapter"`
	LoraBase string `yaml:"lora_base"`
	LoraAdapters []string `yaml:"lora_adapters"`
	LoraScales []float32 `yaml:"lora_scales"`
	LoraScale float32 `yaml:"lora_scale"`
	NoMulMatQ bool `yaml:"no_mulmatq"`
	DraftModel string `yaml:"draft_model"`
	NDraft int32 `yaml:"n_draft"`
	Quantization string `yaml:"quantization"`
	LoadFormat string `yaml:"load_format"`

	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
	TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
	EnforceEager bool `yaml:"enforce_eager"` // vLLM
	SwapSpace int `yaml:"swap_space"` // vLLM
	MaxModelLen int `yaml:"max_model_len"` // vLLM
	TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM

	MMProj string `yaml:"mmproj"`

	FlashAttention bool `yaml:"flash_attention"`
	NoKVOffloading bool `yaml:"no_kv_offloading"`

	RopeScaling string `yaml:"rope_scaling"`
	ModelType string `yaml:"type"`

	YarnExtFactor float32 `yaml:"yarn_ext_factor"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor"`
	YarnBetaFast float32 `yaml:"yarn_beta_fast"`
	YarnBetaSlow float32 `yaml:"yarn_beta_slow"`
}
```
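Note the `yaml:",inline"` tag on `LLMConfig` back in `BackendConfig`: inlining means these keys live at the top level of a model file rather than under a nested mapping. A self-contained sketch of the mechanism, using made-up types (`CommonOptions`, `ModelConfig`):

```go
package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// CommonOptions and ModelConfig are illustrative stand-ins for
// LLMConfig and BackendConfig.
type CommonOptions struct {
	ContextSize *int `yaml:"context_size"`
	GPULayers *int `yaml:"gpu_layers"`
}

type ModelConfig struct {
	Name string `yaml:"name"`
	CommonOptions `yaml:",inline"` // same mechanism BackendConfig uses for LLMConfig
}

func main() {
	doc := []byte(`
name: my-model
context_size: 2048
gpu_layers: 35
`)

	var m ModelConfig
	if err := yaml.Unmarshal(doc, &m); err != nil {
		panic(err)
	}

	// The inlined fields are read from the top level of the document,
	// not from a nested "commonoptions:" mapping.
	fmt.Println(m.Name, *m.ContextSize, *m.GPULayers)
}
```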
```go
// AutoGPTQ is a struct that holds the configuration specific to the AutoGPTQ backend
type AutoGPTQ struct {
	ModelBaseName string `yaml:"model_base_name"`
	Device string `yaml:"device"`
	Triton bool `yaml:"triton"`
	UseFastTokenizer bool `yaml:"use_fast_tokenizer"`
}

// TemplateConfig is a struct that holds the configuration of the templating system
type TemplateConfig struct {
	// Chat is the template used in the chat completion endpoint
	Chat string `yaml:"chat"`

	// ChatMessage is the template used for chat messages
	ChatMessage string `yaml:"chat_message"`

	// Completion is the template used for completion requests
	Completion string `yaml:"completion"`

	// Edit is the template used for edit completion requests
	Edit string `yaml:"edit"`

	// Functions is the template used when tools are present in the client requests
	Functions string `yaml:"function"`

	// UseTokenizerTemplate is a flag that indicates if the tokenizer template should be used.
	// Note: this is mostly consumed for backends such as vllm and transformers
	// that can use the tokenizers specified in the JSON config files of the models
	UseTokenizerTemplate bool `yaml:"use_tokenizer_template"`

	// JoinChatMessagesByCharacter is a string that will be used to join chat messages together.
	// It defaults to \n
	JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character"`

	Multimodal string `yaml:"multimodal"`
}

func (c *BackendConfig) UnmarshalYAML(value *yaml.Node) error {
	type BCAlias BackendConfig
	var aux BCAlias
	if err := value.Decode(&aux); err != nil {
		return err
	}
	*c = BackendConfig(aux)

	c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings)
	return nil
}
```
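`UnmarshalYAML` uses the alias-type trick: `BCAlias` has the same fields as `BackendConfig` but none of its methods, so `value.Decode` performs plain field decoding instead of re-invoking `UnmarshalYAML` (which would recurse forever), and the derived `KnownUsecases` field is filled in afterwards. A self-contained sketch of the same pattern, with an illustrative `wrapper` type:

```go
package main

import (
	"fmt"
	"strings"

	"gopkg.in/yaml.v3"
)

// wrapper is an illustrative stand-in for BackendConfig.
type wrapper struct {
	Names []string `yaml:"names"`
	Joined string `yaml:"-"` // derived after decoding, like KnownUsecases
}

func (w *wrapper) UnmarshalYAML(value *yaml.Node) error {
	// The alias type has the same fields but none of the methods, so
	// value.Decode does plain field decoding instead of calling this
	// UnmarshalYAML again (which would recurse forever).
	type alias wrapper
	var aux alias
	if err := value.Decode(&aux); err != nil {
		return err
	}
	*w = wrapper(aux)

	// Post-process the decoded values, mirroring GetUsecasesFromYAML.
	w.Joined = strings.Join(w.Names, ",")
	return nil
}

func main() {
	var w wrapper
	if err := yaml.Unmarshal([]byte("names: [chat, completion]"), &w); err != nil {
		panic(err)
	}
	fmt.Println(w.Names, w.Joined) // [chat completion] chat,completion
}
```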
```go
func (c *BackendConfig) SetFunctionCallString(s string) {
	c.functionCallString = s
}

func (c *BackendConfig) SetFunctionCallNameString(s string) {
	c.functionCallNameString = s
}

func (c *BackendConfig) ShouldUseFunctions() bool {
	return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction())
}

func (c *BackendConfig) ShouldCallSpecificFunction() bool {
	return len(c.functionCallNameString) > 0
}

// MMProjFileName returns the filename of the MMProj file
// If the MMProj is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) MMProjFileName() string {
	uri := downloader.URI(c.MMProj)
	if uri.LooksLikeURL() {
		f, _ := uri.FilenameFromUrl()
		return f
	}
	return c.MMProj
}

func (c *BackendConfig) IsMMProjURL() bool {
	uri := downloader.URI(c.MMProj)
	return uri.LooksLikeURL()
}

func (c *BackendConfig) IsModelURL() bool {
	uri := downloader.URI(c.Model)
	return uri.LooksLikeURL()
}

// ModelFileName returns the filename of the model
// If the model is a URL, it will return the MD5 of the URL which is the filename
func (c *BackendConfig) ModelFileName() string {
	uri := downloader.URI(c.Model)
	if uri.LooksLikeURL() {
		f, _ := uri.FilenameFromUrl()
		return f
	}
	return c.Model
}

func (c *BackendConfig) FunctionToCall() string {
	if c.functionCallNameString != "" &&
		c.functionCallNameString != "none" && c.functionCallNameString != "auto" {
		return c.functionCallNameString
	}

	return c.functionCallString
}
```
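A short usage sketch of the precedence implemented by `FunctionToCall`: a concrete function name (anything other than `""`, `"none"` or `"auto"`) wins over the generic function-call string. The import path is assumed from the repository layout (core/config) and is not confirmed by this file; `"get_weather"` is a made-up function name.

```go
package main

import (
	"fmt"

	// Assumed import path, based on the repository layout.
	"github.com/mudler/LocalAI/core/config"
)

func main() {
	var c config.BackendConfig

	// Only the generic function-call string is set: FunctionToCall falls back to it.
	c.SetFunctionCallString("auto")
	fmt.Println(c.FunctionToCall()) // "auto"

	// A concrete function name (not "", "none" or "auto") takes precedence.
	c.SetFunctionCallNameString("get_weather")
	fmt.Println(c.FunctionToCall())             // "get_weather"
	fmt.Println(c.ShouldCallSpecificFunction()) // true
}
```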
```go
func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {
	lo := &LoadOptions{}
	lo.Apply(opts...)

	ctx := lo.ctxSize
	threads := lo.threads
	f16 := lo.f16
	debug := lo.debug

	// https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22
	defaultTopP := 0.95
	defaultTopK := 40
	defaultTemp := 0.9
	defaultMirostat := 2
	defaultMirostatTAU := 5.0
	defaultMirostatETA := 0.1
	defaultTypicalP := 1.0
	defaultTFZ := 1.0
	defaultZero := 0

	// Try to offload all GPU layers (if GPU is found)
	defaultHigh := 99999999

	trueV := true
	falseV := false

	if cfg.Seed == nil {
		// random number generator seed
		defaultSeed := RAND_SEED
		cfg.Seed = &defaultSeed
	}

	if cfg.TopK == nil {
		cfg.TopK = &defaultTopK
	}

	if cfg.TypicalP == nil {
		cfg.TypicalP = &defaultTypicalP
	}

	if cfg.TFZ == nil {
		cfg.TFZ = &defaultTFZ
	}

	if cfg.MMap == nil {
		// MMap is enabled by default
		// The only exception is Intel GPUs
		if os.Getenv("XPU") != "" {
			cfg.MMap = &falseV
		} else {
			cfg.MMap = &trueV
		}
	}

	if cfg.MMlock == nil {
		// MMlock is disabled by default
		cfg.MMlock = &falseV
	}

	if cfg.TopP == nil {
		cfg.TopP = &defaultTopP
	}

	if cfg.Temperature == nil {
		cfg.Temperature = &defaultTemp
	}

	if cfg.Maxtokens == nil {
		cfg.Maxtokens = &defaultZero
	}

	if cfg.Mirostat == nil {
		cfg.Mirostat = &defaultMirostat
	}

	if cfg.MirostatETA == nil {
		cfg.MirostatETA = &defaultMirostatETA
	}

	if cfg.MirostatTAU == nil {
		cfg.MirostatTAU = &defaultMirostatTAU
	}

	if cfg.NGPULayers == nil {
		cfg.NGPULayers = &defaultHigh
	}

	if cfg.LowVRAM == nil {
		cfg.LowVRAM = &falseV
	}

	if cfg.Embeddings == nil {
		cfg.Embeddings = &falseV
	}

	// Values passed from the top level are treated as defaults (there are no
	// implicit defaults); the defaults are set by the user.
	if ctx == 0 {
		ctx = 1024
	}

	if cfg.ContextSize == nil {
		cfg.ContextSize = &ctx
	}

	if threads == 0 {
		// Threads can't be 0
		threads = 4
	}

	if cfg.Threads == nil {
		cfg.Threads = &threads
	}

	if cfg.F16 == nil {
		cfg.F16 = &f16
	}

	if cfg.Debug == nil {
		cfg.Debug = &falseV
	}

	if debug {
		cfg.Debug = &trueV
	}

	guessDefaultsFromFile(cfg, lo.modelPath)
}
```
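`SetDefaults` leans on pointer fields to tell "never set" apart from "explicitly set to the zero value": only nil pointers receive a default. A self-contained sketch of that pattern, with an illustrative `demoOpts` type:

```go
package main

import "fmt"

// demoOpts is illustrative only; the real code does this with
// BackendConfig's *int/*bool/*float64 fields.
type demoOpts struct {
	Threads *int
	Debug *bool
}

func (o *demoOpts) setDefaults() {
	defaultThreads := 4
	falseV := false
	if o.Threads == nil { // key missing from the YAML -> apply the default
		o.Threads = &defaultThreads
	}
	if o.Debug == nil {
		o.Debug = &falseV
	}
}

func main() {
	zero := 0
	explicit := demoOpts{Threads: &zero} // the user explicitly asked for 0
	explicit.setDefaults()
	fmt.Println(*explicit.Threads) // 0: an explicit value is preserved

	var unset demoOpts // nothing provided
	unset.setDefaults()
	fmt.Println(*unset.Threads, *unset.Debug) // 4 false: defaults applied
}
```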
```go
func (c *BackendConfig) Validate() bool {
	downloadedFileNames := []string{}
	for _, f := range c.DownloadFiles {
		downloadedFileNames = append(downloadedFileNames, f.Filename)
	}
	validationTargets := []string{c.Backend, c.Model, c.MMProj}
	validationTargets = append(validationTargets, downloadedFileNames...)

	// Simple validation to make sure the model can be correctly loaded
	for _, n := range validationTargets {
		if n == "" {
			continue
		}
		if strings.HasPrefix(n, string(os.PathSeparator)) ||
			strings.Contains(n, "..") {
			return false
		}
	}

	if c.Backend != "" {
		// a regex that checks that the backend is a plain name with no special characters, except '-' and '_'
		re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`)
		return re.MatchString(c.Backend)
	}

	return true
}

func (c *BackendConfig) HasTemplate() bool {
	return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != ""
}
```
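A usage sketch of what `Validate` accepts and rejects, following the checks above (no absolute paths or `..` in the backend, model, mmproj, or download targets, plus the backend-name regex `^[a-zA-Z0-9-_]+$`). The import path is assumed from the repository layout, and the file names are illustrative.

```go
package main

import (
	"fmt"

	// Assumed import path, based on the repository layout.
	"github.com/mudler/LocalAI/core/config"
)

func main() {
	var c config.BackendConfig
	c.Backend = "llama-cpp"
	c.Model = "ggml-model-q4_0.gguf"
	fmt.Println(c.Validate()) // true

	c.Model = "../outside-the-model-dir.gguf"
	fmt.Println(c.Validate()) // false: ".." is rejected

	c.Model = "ggml-model-q4_0.gguf"
	c.Backend = "llama.cpp"
	fmt.Println(c.Validate()) // false: '.' fails the backend-name regex
}
```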
```go
type BackendConfigUsecases int

const (
	FLAG_ANY BackendConfigUsecases = 0b000000000
	FLAG_CHAT BackendConfigUsecases = 0b000000001
	FLAG_COMPLETION BackendConfigUsecases = 0b000000010
	FLAG_EDIT BackendConfigUsecases = 0b000000100
	FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
	FLAG_RERANK BackendConfigUsecases = 0b000010000
	FLAG_IMAGE BackendConfigUsecases = 0b000100000
	FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
	FLAG_TTS BackendConfigUsecases = 0b010000000
	FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000

	// Common Subsets
	FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
)

func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
	return map[string]BackendConfigUsecases{
		"FLAG_ANY": FLAG_ANY,
		"FLAG_CHAT": FLAG_CHAT,
		"FLAG_COMPLETION": FLAG_COMPLETION,
		"FLAG_EDIT": FLAG_EDIT,
		"FLAG_EMBEDDINGS": FLAG_EMBEDDINGS,
		"FLAG_RERANK": FLAG_RERANK,
		"FLAG_IMAGE": FLAG_IMAGE,
		"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
		"FLAG_TTS": FLAG_TTS,
		"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
		"FLAG_LLM": FLAG_LLM,
	}
}

func GetUsecasesFromYAML(input []string) *BackendConfigUsecases {
	if len(input) == 0 {
		return nil
	}
	result := FLAG_ANY
	flags := GetAllBackendConfigUsecases()
	for _, str := range input {
		flag, exists := flags["FLAG_"+strings.ToUpper(str)]
		if exists {
			result |= flag
		}
	}
	return &result
}
```
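A self-contained illustration of how the `known_usecases` strings are folded into a bitmask with `|=` and later tested with a mask-and-compare. It also shows why `FLAG_LLM`, being defined with `&` over disjoint bits, evaluates to 0 and is therefore indistinguishable from `FLAG_ANY`. The local `usecases` type and its constants are stand-ins for `BackendConfigUsecases`:

```go
package main

import "fmt"

type usecases int

const (
	chat usecases = 0b001
	completion usecases = 0b010
	edit usecases = 0b100
)

func main() {
	// GetUsecasesFromYAML folds each recognised string into the result with |=.
	known := chat | completion // e.g. known_usecases: [chat, completion]

	// HasUsecases-style test: every bit of the query must be present in "known".
	query := chat
	fmt.Println(known&query == query) // true

	query = chat | edit
	fmt.Println(known&query == query) // false: edit was not declared

	// AND-ing disjoint flags yields zero, which is why FLAG_LLM above
	// evaluates to 0 and behaves exactly like FLAG_ANY.
	fmt.Println(chat & completion & edit) // 0
}
```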
```go
// HasUsecases examines a BackendConfig and determines which endpoints have a chance of success.
func (c *BackendConfig) HasUsecases(u BackendConfigUsecases) bool {
	if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) {
		return true
	}
	return c.GuessUsecases(u)
}

// GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at.
// Ideally, this function should check properties of the config (such as templates) rather than matching backend names directly, as the lower half does;
// that would avoid the maintenance burden of updating this list for each new backend - but unfortunately, name matching is currently the best option for some services.
func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
	if (u & FLAG_CHAT) == FLAG_CHAT {
		if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" {
			return false
		}
	}
	if (u & FLAG_COMPLETION) == FLAG_COMPLETION {
		if c.TemplateConfig.Completion == "" {
			return false
		}
	}
	if (u & FLAG_EDIT) == FLAG_EDIT {
		if c.TemplateConfig.Edit == "" {
			return false
		}
	}
	if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS {
		if c.Embeddings == nil || !*c.Embeddings {
			return false
		}
	}
	if (u & FLAG_IMAGE) == FLAG_IMAGE {
		imageBackends := []string{"diffusers", "tinydream", "stablediffusion"}
		if !slices.Contains(imageBackends, c.Backend) {
			return false
		}

		if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" {
			return false
		}
	}
	if (u & FLAG_RERANK) == FLAG_RERANK {
		if c.Backend != "rerankers" {
			return false
		}
	}
	if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT {
		if c.Backend != "whisper" {
			return false
		}
	}
	if (u & FLAG_TTS) == FLAG_TTS {
		ttsBackends := []string{"piper", "transformers-musicgen", "parler-tts"}
		if !slices.Contains(ttsBackends, c.Backend) {
			return false
		}
	}
	if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION {
		if c.Backend != "transformers-musicgen" {
			return false
		}
	}

	return true
}
```
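Finally, a usage sketch of the two-step check: `HasUsecases` honours an explicit `known_usecases` declaration first and only falls back to the `GuessUsecases` heuristics (templates and backend names) when none is declared. The import path is assumed from the repository layout; the backend and template names are illustrative.

```go
package main

import (
	"fmt"

	// Assumed import path, based on the repository layout.
	"github.com/mudler/LocalAI/core/config"
)

func main() {
	// A config with only a chat template: the heuristic allows chat but not TTS.
	var c config.BackendConfig
	c.Backend = "llama-cpp"
	c.TemplateConfig.Chat = "my-chat-template"

	fmt.Println(c.HasUsecases(config.FLAG_CHAT)) // true: a chat template is set
	fmt.Println(c.HasUsecases(config.FLAG_TTS))  // false: not a known TTS backend

	// Declaring known usecases short-circuits the heuristic entirely.
	tts := config.FLAG_TTS
	c.KnownUsecases = &tts
	fmt.Println(c.HasUsecases(config.FLAG_TTS)) // true: explicitly declared
}
```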