🔁 SimplePromptClassifier — классификатор промптов (русский)

Model banner

Кратко: модель классифицирует входные промпты/вопросы на три действия:

  • 0 — Поиск в локальной базе знаний (RAG): сначала ищем релевантные документы в локальном индексе и формируем контекст для генерации.
  • 1 — Поиск в сети: триггер запуска обхода внешних поисковых систем/скрейпинга.
  • 2 — Прямой запрос: сразу посылаем промпт в генеративную модель (например, LLM) для синтеза ответа.

Где используется

Подходит для систем, где нужно автоматически решать стратегию обработки пользовательского промпта:

  • чат-боты со связкой Retrieval-Augmented Generation (RAG),
  • голосовые ассистенты,
  • интерфейсы поддержки, где часть запросов решается поиском, часть — генерацией.

Файлы в репозитории

  • pytorch_model.bin — веса модели (state_dict).
  • config.json — конфигурация (input_dim, num_classes, p_dropout, classes).
  • modeling_simple_classifier.py — определение архитектуры.
  • vectorizer.pkl — sklearn-векторизатор (TF-IDF/Count).
  • svd.pkl — TruncatedSVD (опционально).
  • label_encoder.pkl — sklearn.LabelEncoder (для декодирования метки).
  • README.md — эта карточка.

Пример загрузки и инференса (без AutoModel)

# Пример: загрузка напрямую из репозитория HF (не требует локальной копии)
from huggingface_hub import hf_hub_download
import json, pickle, torch
import numpy as np
from types import SimpleNamespace

REPO = "Neweret/SimplePromptClassifier-85k"

config_path = hf_hub_download(REPO, "config.json")
weights_path = hf_hub_download(REPO, "pytorch_model.bin")
vec_path = hf_hub_download(REPO, "vectorizer.pkl")
svd_path = None
try:
    svd_path = hf_hub_download(REPO, "svd.pkl")
except Exception:
    svd_path = None
le_path = hf_hub_download(REPO, "label_encoder.pkl")

cfg = SimpleNamespace(**json.load(open(config_path, "r", encoding="utf-8")))

# --- Динамическая модель ---
class SimpleClassifier(torch.nn.Module):
    def __init__(self, input_dim, num_classes, p_dropout=0.3):
        super().__init__()
        self.linear1 = torch.nn.Linear(input_dim, 256)
        self.ln1 = torch.nn.LayerNorm(256)
        self.dropout = torch.nn.Dropout(p_dropout)
        self.linear2 = torch.nn.Linear(256, 128)
        self.ln2 = torch.nn.LayerNorm(128)
        self.linear_out = torch.nn.Linear(128, num_classes)
    def forward(self, x):
        x = torch.nn.functional.gelu(self.ln1(self.linear1(x)))
        x = self.dropout(x)
        x = torch.nn.functional.gelu(self.ln2(self.linear2(x)))
        x = self.dropout(x)
        return self.linear_out(x)

model = SimpleClassifier(cfg.input_dim, cfg.num_classes, cfg.p_dropout)
state = torch.load(weights_path, map_location="cpu")
model.load_state_dict(state)
model.eval()

# препроцессинг
vectorizer = pickle.load(open(vec_path, "rb"))
svd = pickle.load(open(svd_path, "rb")) if svd_path else None
le = pickle.load(open(le_path, "rb"))

def preprocess(text):
    X = vectorizer.transform([text])
    if svd is not None:
        X = svd.transform(X)
    return X.astype(np.float32)

def predict(text):
    x = preprocess(text)
    xb = torch.from_numpy(x).float()
    with torch.inference_mode():
        logits = model(xb)
        pred = int(torch.argmax(logits, dim=1).cpu().numpy()[0])
    return pred, le.inverse_transform([pred])[0]

# пример
print(predict("Как мне найти документацию по нашей компании?"))
Downloads last month
36
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support