|
|
|
|
|
|
|
|
import os, sys, types, pathlib |
|
|
|
|
|
# Environment defaults. Each value is applied only when the variable is not
# already present in the environment (setdefault semantics).
for env_name, fallback in (
    ("PORT", "7860"),
    ("STREAMLIT_SERVER_PORT", "7860"),
    ("STREAMLIT_SERVER_ADDRESS", "0.0.0.0"),
    ("STREAMLIT_SERVER_HEADLESS", "true"),
    ("STREAMLIT_SERVER_ENABLECORS", "false"),
    ("STREAMLIT_SERVER_ENABLE_XSRF_PROTECTION", "false"),
    ("STREAMLIT_BROWSER_GATHERUSAGESTATS", "false"),
    ("HOME", "/app"),
    ("STREAMLIT_CONFIG_DIR", "/app/.streamlit"),
    ("MPLCONFIGDIR", "/tmp/matplotlib"),
):
    os.environ.setdefault(env_name, fallback)

# Make sure the config directories exist; widening their permissions is
# best-effort and skipped when the process is not allowed to chmod.
for dir_var in ("STREAMLIT_CONFIG_DIR", "MPLCONFIGDIR"):
    target_dir = pathlib.Path(os.environ[dir_var])
    target_dir.mkdir(parents=True, exist_ok=True)
    try:
        target_dir.chmod(0o777)
    except PermissionError:
        pass
|
|
|
|
|
|
|
|
|
|
|
# Optionally replace `pynndescent` with an in-memory stub so later imports of
# it (and of its `distances` / `sparse` submodules) resolve without the real
# package. Enabled unless DISABLE_PYNNDESCENT is set to something other than "1".
if os.environ.get("DISABLE_PYNNDESCENT", "1") == "1":

    class NNDescent:
        """Placeholder that refuses to be instantiated."""

        def __init__(self, *a, **k):
            raise RuntimeError("PyNNDescent disabled; using precomputed_knn.")

    pkg = types.ModuleType("pynndescent")
    pkg.__path__ = []  # a module with __path__ is treated as a package
    pkg.NNDescent = NNDescent

    dist = types.ModuleType("pynndescent.distances")
    dist.named_distances = {}

    sparse = types.ModuleType("pynndescent.sparse")
    sparse.sparse_named_distances = {}

    # Bind the submodules on the parent as well: modules found directly in
    # sys.modules are not attached to their parent by the import system, so
    # `pynndescent.distances` attribute access would otherwise fail.
    pkg.distances = dist
    pkg.sparse = sparse

    sys.modules["pynndescent"] = pkg
    sys.modules["pynndescent.distances"] = dist
    sys.modules["pynndescent.sparse"] = sparse
|
|
|
|
|
|
|
|
|
|
|
import matplotlib |
|
|
matplotlib.use("Agg") |
|
|
|
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import seaborn as sns |
|
|
import matplotlib.pyplot as plt |
|
|
import streamlit as st |
|
|
from typing import List, Dict |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from sklearn.neighbors import NearestNeighbors |
|
|
import umap |
|
|
import re |
|
|
import io |
|
|
|
|
|
|
|
|
def precompute_knn(X: np.ndarray, n_neighbors: int = 15, metric: str = "cosine"):
    """Compute a k-nearest-neighbour table for the rows of *X* with scikit-learn.

    Args:
        X: 2-D array with one sample per row.
        n_neighbors: Requested number of neighbours per sample. Because *X* is
            also passed as the query set, each row's own entry is part of the
            result. The value is clamped to the number of samples, since
            scikit-learn raises ``ValueError`` when asked for more neighbours
            than fitted samples.
        metric: Distance metric name forwarded to ``NearestNeighbors``.

    Returns:
        ``(indices, distances)`` as ``int32`` and ``float32`` arrays of shape
        ``(n_samples, min(n_neighbors, n_samples))``.
    """
    n_samples = X.shape[0]
    k = max(1, min(int(n_neighbors), n_samples))
    nn = NearestNeighbors(n_neighbors=k, metric=metric, n_jobs=-1)
    nn.fit(X)
    dists, indices = nn.kneighbors(X, n_neighbors=k, return_distance=True)
    return indices.astype(np.int32, copy=False), dists.astype(np.float32, copy=False)
|
|
|
|
|
def umap_embed(X: np.ndarray, n_neighbors=15, metric="cosine", **kwargs):
    """Project *X* with UMAP using a k-NN table computed by ``precompute_knn``.

    Args:
        X: 2-D array with one sample per row.
        n_neighbors: Requested UMAP neighbourhood size. It is clamped to
            ``n_samples - 1`` (with a floor of 2) so that small word lists
            combined with a large slider value do not ask the k-NN search for
            more neighbours than there are samples.
        metric: Metric name used for both the k-NN search and UMAP.
        **kwargs: Extra keyword arguments forwarded to ``umap.UMAP``.

    Returns:
        Whatever ``UMAP.fit_transform`` returns for *X*.
    """
    n_samples = X.shape[0]
    k = max(2, min(int(n_neighbors), n_samples - 1))

    # The same k is used for the table and for UMAP, so the table is never
    # narrower than the neighbourhood size UMAP is configured with.
    idx, dst = precompute_knn(X, n_neighbors=k, metric=metric)
    reducer = umap.UMAP(
        n_neighbors=k,
        metric=metric,
        precomputed_knn=(idx, dst),
        force_approximation_algorithm=False,
        **kwargs,
    )
    return reducer.fit_transform(X)
|
|
|
|
|
st.set_page_config(page_title="Embedding-Native Cognition — UMAP demo", layout="wide")

# Global CSS injected once per run: stable scrollbar gutter, forced vertical
# scrollbar, and a fixed-width main container with a fallback for narrower
# viewports.
_LAYOUT_CSS = """
<style>
  /* Keep stable scrollbar width */
  :root,
  html,
  body,
  [data-testid="stAppViewContainer"] {
    scrollbar-gutter: stable both-edges;
  }

  /* Force a scrollbar so layout doesn't jump */
  html {
    overflow-y: scroll;
  }

  /* Constrain the main content width and restore vertical breathing room */
  div.block-container {
    max-width: 1480px !important;
    min-width: 1480px !important;
    margin: 0 auto;
    padding: 1.5rem 1rem 1rem 1rem; /* <-- top padding restored */
  }

  /* Mobile fallback */
  @media (max-width: 1400px) {
    div.block-container {
      max-width: 100% !important;
      min-width: 0 !important;
    }
  }
</style>
"""
st.markdown(_LAYOUT_CSS, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
# Sidebar controls. The module-level names assigned here (model_name,
# n_neighbors, min_dist, metric, show_labels, random_state) are read later
# when the embedding / UMAP run is triggered.
_sidebar = st.sidebar
_sidebar.title("Settings")

_MODEL_CHOICES = [
    "sentence-transformers/all-mpnet-base-v2",
    "sentence-transformers/all-MiniLM-L6-v2",
    "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "sentence-transformers/nli-mpnet-base-v2",
    "Qwen/Qwen3-Embedding-0.6B",
    "Qwen/Qwen3-Embedding-4B",
]
model_name = _sidebar.selectbox(
    "Embedding model",
    options=_MODEL_CHOICES,
    index=0,
    help="MiniLM is small & fast; mpnet is stronger but heavier. Qwen is most powerful, but very large",
)

n_neighbors = _sidebar.slider(
    "UMAP n_neighbors", min_value=5, max_value=50, value=23, step=1
)
min_dist = _sidebar.slider(
    "UMAP min_dist", min_value=0.0, max_value=0.99, value=0.4, step=0.01
)
metric = _sidebar.selectbox("UMAP metric", options=["cosine", "euclidean"], index=0)
show_labels = _sidebar.checkbox("Show labels on plot", value=True)
random_state = _sidebar.number_input("Random seed", min_value=0, value=42, step=1)
|
|
|
|
|
|
|
|
from typing import Dict, List |
|
|
|
|
|
# Built-in lexicon: category name -> default words shown in the editor.
DEFAULT_LEX: Dict[str, List[str]] = {
    "animal": [
        "cat", "dog", "lion", "Pangolin", "tiger", "wolf", "elephant",
        "horse", "zebra", "bear", "monkey", "penguin", "sparrow", "raven",
        "crow", "cow", "rat", "mouse", "whale", "fish", "frog",
    ],
    "vehicle": [
        "car", "truck", "bus", "bicycle", "motorcycle", "train", "airplane",
        "boat", "ship", "submarine", "tractor", "ford_f150", "skateboard",
        "monorail", "scooter", "unicycle", "segway",
    ],
    "style": [
        "grungy", "Japanese-style", "oriental-style", "western-style",
        "Shakespearian-style", "old-English", "Victorian", "gothic",
        "plain/ordinary", "medieval", "garish", "fantasy", "magical",
        "naturalistic", "modern", "Art-Nouveau",
    ],
    "emotion": [
        "happy", "sad", "angry", "fearful", "joyful", "calm", "anxious",
        "surprised", "bored", "proud", "jealous", "ennui", "curious",
        "spiteful", "furious", "ecstatic", "horrified", "inquisitive",
    ],
    "scene": [
        "city", "forest", "woods", "nature", "town", "downtown-LA",
        "downtown", "tokyo-city", "sci-fi-city", "medieval-town",
        "trailer-park", "park", "plains", "prairie", "mountains",
    ],
}
|
|
|
|
|
|
|
|
# Seed the editable lexicon once per session with an independent copy of the
# defaults, so edits never mutate DEFAULT_LEX itself.
if "lex" not in st.session_state:
    st.session_state.lex = {
        category: [*default_words] for category, default_words in DEFAULT_LEX.items()
    }
|
|
|
|
|
|
|
|
import re |
|
|
|
|
|
def _canonicalize(name: str) -> str: |
|
|
name = (name or "").strip().lower() |
|
|
name = re.sub(r"[^a-z0-9 _-]+", "", name) |
|
|
name = re.sub(r"\s+", "_", name) |
|
|
return name[:32] |
|
|
|
|
|
# Sidebar panel for adding, removing and resetting lexicon categories.
# NOTE(review): nesting was reconstructed from an indentation-stripped source;
# the reset button is placed at the expander level — confirm.
with st.sidebar.expander("📂 Manage categories", expanded=False):
    new_cat_input = st.text_input(
        "New category name", placeholder="e.g., 'color'", key="new_cat_input"
    )
    add_clicked = st.button("➕ Add category", use_container_width=True, key="btn_add_cat")
    if add_clicked:
        cat = _canonicalize(new_cat_input)
        if cat and cat not in st.session_state.lex:
            st.session_state.lex[cat] = []
            # Pre-register the editor widget state for the new category.
            st.session_state.setdefault(f"ta_{cat}", "")
            st.success(f"Added '{cat}'. Scroll to the editor to add words.")
        elif cat:
            st.info(f"'{cat}' already exists.")
        else:
            st.warning("Please enter a category name.")

    if st.session_state.lex:
        removal_options = ["(none)"] + sorted(st.session_state.lex.keys())
        rm_sel = st.selectbox("Remove category", removal_options, key="rm_sel")
        remove_clicked = st.button(
            "🗑 Remove selected", use_container_width=True, key="btn_rm_cat"
        )
        if remove_clicked and rm_sel != "(none)":
            st.session_state.lex.pop(rm_sel, None)
            st.session_state.pop(f"ta_{rm_sel}", None)
            st.success(f"Removed '{rm_sel}'")

    if st.button("↺ Reset to defaults", use_container_width=True, key="btn_reset_lex"):
        st.session_state.lex = {
            category: list(default_words)
            for category, default_words in DEFAULT_LEX.items()
        }
        # Drop every per-category editor state so the text areas repopulate
        # from the restored lexicon.
        stale_keys = [
            key
            for key in list(st.session_state.keys())
            if isinstance(key, str) and key.startswith("ta_")
        ]
        for key in stale_keys:
            st.session_state.pop(key, None)
        st.success("Restored default categories.")
|
|
|
|
|
# Sidebar separator and the button that triggers the embedding + UMAP run.
st.sidebar.markdown("---")
run_btn = st.sidebar.button("▶️ Run (embed + UMAP)")

# Main-page heading and short usage instructions.
_PAGE_TITLE = "Vector Embeddings as a Cognitive-Inspired Representational Framework — UMAP Demo"
_PAGE_INTRO = (
    "Edit categories/words below (one per line). "
    "Click **Run** in the sidebar to compute embeddings and UMAP."
)
st.title(_PAGE_TITLE)
st.write(_PAGE_INTRO)
|
|
|
|
|
|
|
|
# Two-column editor: one text area per category (alphabetical order), one word
# per line. The parsed contents are written back to the session lexicon on
# every run.
with st.expander("Edit categories & words", expanded=True):
    cols = st.columns(2)
    for position, category in enumerate(sorted(st.session_state.lex.keys())):
        column = cols[position % 2]
        current_words = st.session_state.lex[category]
        edited = column.text_area(
            category, "\n".join(current_words), height=220, key=f"ta_{category}"
        )
        stripped_lines = (line.strip() for line in edited.splitlines())
        st.session_state.lex[category] = [word for word in stripped_lines if word]
|
|
|
|
|
|
|
|
# Snapshot of the session lexicon used for this run.
lex: Dict[str, List[str]] = {
    category: list(entries) for category, entries in st.session_state.lex.items()
}

# Flatten to parallel lists. A word that appears in several categories is kept
# once, under the first category that lists it (dicts preserve insertion order).
_first_category: Dict[str, str] = {}
for category, entries in lex.items():
    for entry in entries:
        _first_category.setdefault(entry, category)

words: List[str] = list(_first_category.keys())
cats: List[str] = list(_first_category.values())
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def load_model(name: str):
    """Build the ``SentenceTransformer`` for *name*.

    Wrapped in ``st.cache_resource`` so the instance is reused rather than
    rebuilt on each call with the same name.
    """
    model = SentenceTransformer(name)
    return model
|
|
|
|
|
@st.cache_data(show_spinner=False)
def compute_embeddings(name: str, tokens: List[str]) -> np.ndarray:
    """Encode *tokens* with the model called *name* and return float32 vectors.

    Args:
        name: Model identifier passed to ``load_model``.
        tokens: Strings to embed.

    Returns:
        The encoder output cast to ``float32``; normalised embeddings are
        requested from the encoder.
    """
    encoder = load_model(name)
    encode_options = {
        "normalize_embeddings": True,
        "show_progress_bar": False,
        "batch_size": 32,
    }
    vectors = encoder.encode(tokens, **encode_options)
    return vectors.astype(np.float32, copy=False)
|
|
|
|
|
def cosine_distance(a: np.ndarray, b: np.ndarray, eps: float = 1e-9) -> float:
    """Return the cosine distance ``1 - cos(a, b)`` as a Python float.

    The computation is done in float64 and the result is clipped to
    ``[0, 2]``: with float32 inputs, rounding in the dot product and norms can
    otherwise produce tiny negative distances for (near-)identical vectors.

    Args:
        a: First 1-D vector.
        b: Second 1-D vector, same length as *a*.
        eps: Small constant added to the denominator so an all-zero vector
            gives a distance of 1.0 instead of a division by zero.

    Returns:
        Distance in ``[0, 2]``; NaN inputs still propagate as NaN.
    """
    a64 = np.asarray(a, dtype=np.float64)
    b64 = np.asarray(b, dtype=np.float64)
    similarity = np.dot(a64, b64) / (np.linalg.norm(a64) * np.linalg.norm(b64) + eps)
    return float(np.clip(1.0 - similarity, 0.0, 2.0))
|
|
|
|
|
|
|
|
# Halt this run early when fewer than five unique words are available.
too_few_words = len(words) < 5
if too_few_words:
    st.warning("Please provide at least 5 unique words across categories.")
    st.stop()
|
|
|
|
|
|
|
|
st.markdown("---")

# Results: embed the words, project with UMAP, plot the 2-D map, tabulate each
# word's distance to its category centroid and report 2-D fidelity metrics.
# NOTE(review): nesting was reconstructed from an indentation-stripped source;
# the fidelity metrics are placed under the table column — confirm.
with st.expander("📈 Results", expanded=True):
    if not run_btn:
        st.info("Ready. Adjust settings, then click **Run** in the sidebar to compute embeddings and UMAP.")
        st.write(f"Current word count: **{len(words)}** across **{sum(1 for k,v in lex.items() if v)}** categories.")
        st.stop()

    with st.spinner("Embedding words (downloading model on first run)..."):
        X = compute_embeddings(model_name, words)

    with st.spinner("Running UMAP..."):
        np.random.seed(int(random_state))
        proj = umap_embed(
            X,
            n_neighbors=n_neighbors,
            metric=metric,
            random_state=int(random_state),
            min_dist=float(min_dist),
        )

    df = pd.DataFrame({"x": proj[:, 0], "y": proj[:, 1], "word": words, "cat": cats})

    # Distance of every word to the mean embedding of its own category.
    rows = []
    for c in lex:
        idx = [i for i, cat in enumerate(cats) if cat == c]
        if not idx:
            continue
        proto = X[idx].mean(axis=0)
        rows.extend(
            {
                "category": c,
                "word": words[i],
                "cosine_dist_to_centroid": cosine_distance(X[i], proto),
            }
            for i in idx
        )

    dist_df = (
        pd.DataFrame(rows)
        .sort_values(["category", "cosine_dist_to_centroid"])
        .reset_index(drop=True)
    )

    if dist_df.empty:
        st.info("No rows to display (did all categories end up empty?).")
    else:
        col_plot, col_table = st.columns([7, 5], gap="large")

        with col_plot:
            fig, ax = plt.subplots(figsize=(9, 7))

            sns.scatterplot(
                data=df,
                x="x",
                y="y",
                hue="cat",
                s=60,
                palette="tab10",
                ax=ax,
                linewidth=0.5,
                edgecolor="k",
            )

            # Star marker at the 2-D mean of each category.
            centroids2d = df.groupby("cat")[["x", "y"]].mean()
            ax.scatter(
                centroids2d["x"],
                centroids2d["y"],
                marker="*",
                s=400,
                edgecolors="yellow",
            )

            if show_labels and len(df) <= 200:
                for point in df.itertuples(index=False):
                    ax.text(
                        point.x + 0.01,
                        point.y + 0.01,
                        point.word,
                        fontsize=8,
                        alpha=0.7,
                    )

            ax.set_title("UMAP of tiny semantic set — clusters & categories")

            # Render to PNG before st.pyplot clears the figure, so the bytes
            # can be offered for download later.
            buf = io.BytesIO()
            fig.savefig(
                buf,
                format="png",
                dpi=240,
                bbox_inches="tight",
                facecolor="white",
            )
            buf.seek(0)
            plot_png_bytes = buf.getvalue()

            st.pyplot(fig, clear_figure=True)
            # clear_figure only empties the figure; close it so pyplot does
            # not keep one more figure alive after every rerun.
            plt.close(fig)

        with col_table:
            st.subheader("Prototype (centroid) distances")
            st.dataframe(
                dist_df,
                hide_index=True,
                use_container_width=True,
                height=220,
            )

            from sklearn.manifold import trustworthiness
            from sklearn.neighbors import NearestNeighbors

            def _knn_idx(data_matrix, k_neighbors, metric_name):
                """Return each row's k nearest *other* rows as an index array.

                ``kneighbors()`` is called without a query set, in which case
                scikit-learn does not count a point as its own neighbour. The
                result therefore already excludes self, and no leading column
                must be dropped (dropping it would discard the true nearest
                neighbour).
                """
                nn_local = NearestNeighbors(n_neighbors=k_neighbors, metric=metric_name)
                nn_local.fit(data_matrix)
                return nn_local.kneighbors(return_distance=False)

            # Neighbourhood size for the fidelity metrics, scaled to the
            # number of words and bounded to the range [2, 10].
            k = min(10, max(2, len(words) // 5))

            tw = float(
                trustworthiness(
                    X,
                    proj,
                    n_neighbors=k,
                    metric=metric,
                )
            )

            # Neighbourhood preservation: mean fraction of each point's k
            # high-dimensional neighbours that are also among its k 2-D ones.
            hi_idx = _knn_idx(X, k_neighbors=k, metric_name=metric)
            lo_idx = _knn_idx(proj, k_neighbors=k, metric_name="euclidean")
            pres = 0.0
            for hi_row, lo_row in zip(hi_idx, lo_idx):
                pres += len(set(hi_row) & set(lo_row)) / k
            pres /= len(words)

            st.markdown(
                f"**2D fidelity** · "
                f"Trustworthiness@{k}: **{tw:.3f}** · "
                f"Neighborhood Preservation@{k}: **{pres:.3f}**"
            )
            st.caption(
                "Higher is better. Trustworthiness penalizes false neighbors "
                "introduced by the projection; Neighborhood Preservation measures "
                "how many true neighbors survive in 2D."
            )
|
|
|
|
|
|
|
|
# Download buttons for the run's artefacts: raw embeddings, UMAP coordinates,
# centroid distances and the rendered plot.
# NOTE(review): nesting was reconstructed from an indentation-stripped source;
# this expander is placed at the top level of the script — confirm.
with st.expander("Downloads", expanded=False):
    embedding_columns = [f"e{j}" for j in range(X.shape[1])]
    emb_df = pd.DataFrame(X, columns=embedding_columns)
    emb_df.insert(0, "word", words)
    emb_df.insert(1, "category", cats)

    umap_df = df.copy()

    def _as_csv_bytes(frame):
        """Serialise *frame* to UTF-8 encoded CSV without the index column."""
        return frame.to_csv(index=False).encode("utf-8")

    # (label, payload, file name, MIME type), in left-to-right column order.
    download_specs = [
        ("Embeddings (.csv)", _as_csv_bytes(emb_df), "embeddings.csv", "text/csv"),
        ("UMAP coords (.csv)", _as_csv_bytes(umap_df), "umap_coords.csv", "text/csv"),
        (
            "Prototype distances (.csv)",
            _as_csv_bytes(dist_df),
            "prototype_distances.csv",
            "text/csv",
        ),
        ("Plot (.png)", plot_png_bytes, "umap_prototypes.png", "image/png"),
    ]

    for column, (label, payload, filename, mime_type) in zip(st.columns(4), download_specs):
        column.download_button(
            label,
            data=payload,
            file_name=filename,
            mime=mime_type,
            use_container_width=False,
        )