🔁 SimplePromptClassifier — классификатор промптов (русский)
Кратко: модель классифицирует входные промпты/вопросы на три действия:
- 0 — Поиск в локальной базе знаний (RAG): сначала ищем релевантные документы в локальном индексе и формируем контекст для генерации.
- 1 — Поиск в сети: триггер запуска обхода внешних поисковых систем/скрейпинга.
- 2 — Прямой запрос: сразу посылаем промпт в генеративную модель (например, LLM) для синтеза ответа.
Где используется
Подходит для систем, где нужно автоматически решать стратегию обработки пользовательского промпта:
- чат-боты со связкой Retrieval-Augmented Generation (RAG),
- голосовые ассистенты,
- интерфейсы поддержки, где часть запросов решается поиском, часть — генерацией.
Файлы в репозитории
- pytorch_model.bin— веса модели (state_dict).
- config.json— конфигурация (input_dim, num_classes, p_dropout, classes).
- modeling_simple_classifier.py— определение архитектуры.
- vectorizer.pkl— sklearn-векторизатор (TF-IDF/Count).
- svd.pkl— TruncatedSVD (опционально).
- label_encoder.pkl— sklearn.LabelEncoder (для декодирования метки).
- README.md— эта карточка.
Пример загрузки и инференса (без AutoModel)
# Пример: загрузка напрямую из репозитория HF (не требует локальной копии)
from huggingface_hub import hf_hub_download
import json, pickle, torch
import numpy as np
from types import SimpleNamespace
REPO = "Neweret/SimplePromptClassifier-85k"
config_path = hf_hub_download(REPO, "config.json")
weights_path = hf_hub_download(REPO, "pytorch_model.bin")
vec_path = hf_hub_download(REPO, "vectorizer.pkl")
svd_path = None
try:
    svd_path = hf_hub_download(REPO, "svd.pkl")
except Exception:
    svd_path = None
le_path = hf_hub_download(REPO, "label_encoder.pkl")
cfg = SimpleNamespace(**json.load(open(config_path, "r", encoding="utf-8")))
# --- Динамическая модель ---
class SimpleClassifier(torch.nn.Module):
    def __init__(self, input_dim, num_classes, p_dropout=0.3):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, 256)
        self.ln1 = torch.nn.LayerNorm(256)
        self.dropout = torch.nn.Dropout(p_dropout)
        self.linear2 = torch.nn.Linear(256, 128)
        self.ln2 = torch.nn.LayerNorm(128)
        self.linear_out = torch.nn.Linear(128, num_classes)
    def forward(self, x):
        x = torch.nn.functional.gelu(self.ln1(self.linear1(x)))
        x = self.dropout(x)
        x = torch.nn.functional.gelu(self.ln2(self.linear2(x)))
        x = self.dropout(x)
        return self.linear_out(x)
model = SimpleClassifier(cfg.input_dim, cfg.num_classes, cfg.p_dropout)
state = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state)
model.eval()
# препроцессинг
vectorizer = pickle.load(open(vec_path, "rb"))
svd = pickle.load(open(svd_path, "rb")) if svd_path else None
le = pickle.load(open(le_path, "rb"))
def preprocess(text):
    X = vectorizer.transform([text])
    if svd is not None:
        X = svd.transform(X)
    return X.astype(np.float32)
def predict(text):
    x = preprocess(text)
    xb = torch.from_numpy(x).float()
    with torch.inference_mode():
        logits = model(xb)
        pred = int(torch.argmax(logits, dim=1).cpu().numpy()[0])
    return pred, le.inverse_transform([pred])[0]
# пример
print(predict("Как мне найти документацию по нашей компании?"))
- Downloads last month
- 36
	Inference Providers
	NEW
	
	
	This model isn't deployed by any Inference Provider.
	🙋
			
		Ask for provider support

