Spaces:
Sleeping
Sleeping
| import json | |
| from Levenshtein import distance | |
| import streamlit as st | |
| import numpy as np | |
| import plotly.express as px | |
| from sklearn.decomposition import PCA | |
def load_data(
    embeddings_path="data/simplesegmentT5_embeddings.npy",
    words_path="data/words.json",
):
    """Load the embedding array and the word list from disk.

    Args:
        embeddings_path: Path to a ``.npy`` file read with ``np.load``.
        words_path: Path to a JSON file holding the word list.

    Returns:
        An ``(embeddings, words)`` tuple: the array returned by ``np.load``
        and the object decoded from the JSON file. Callers index both in
        parallel, so row ``i`` of ``embeddings`` is presumably the vector for
        ``words[i]`` -- TODO confirm against the data-generation script.
    """
    embeddings = np.load(embeddings_path)
    # Use a context manager so the handle is closed deterministically rather
    # than left for the garbage collector; JSON is UTF-8 by specification.
    with open(words_path, "r", encoding="utf-8") as f:
        words = json.load(f)
    return embeddings, words
def project_embeddings(embeddings, n_components=3):
    """Project embeddings onto their leading principal components.

    Args:
        embeddings: 2-D array-like of shape ``(n_samples, n_features)``.
        n_components: Number of output columns (3 for the 3-D scatter plot).

    Returns:
        An array of shape ``(n_samples, n_components)``. When the data has
        fewer samples or features than ``n_components`` (e.g. after strict
        filtering in the UI), PCA is fitted with as many components as the
        data supports. The missing columns are zero-filled, so callers can
        always index every column.
    """
    embeddings = np.asarray(embeddings)
    n_samples, n_features = embeddings.shape
    usable = min(n_components, n_samples, n_features)
    # PCA needs at least two samples to estimate variance; zero or one point
    # projects to the origin anyway once centred.
    if n_samples < 2 or usable < 1:
        return np.zeros((n_samples, n_components))
    proj = PCA(n_components=usable).fit_transform(embeddings)
    if usable < n_components:
        # Zero-pad so downstream code indexing proj[:, 2] keeps working.
        proj = np.pad(proj, ((0, 0), (0, n_components - usable)))
    return proj
def filter_words(words, remove_capitalized, length):
    """Return the positions of words that pass the sidebar filters.

    A word is kept when it is neither shorter than ``length[0]`` nor longer
    than ``length[1]``. If ``remove_capitalized`` is true, it must also
    contain no character that ``str.lower`` would change. That means no
    uppercase letters anywhere in the word, not just at the start.
    """

    def _keep(word):
        # Capitalisation is checked first, exactly as before.
        if remove_capitalized and word != word.lower():
            return False
        size = len(word)
        return not (size < length[0] or size > length[1])

    return [pos for pos, word in enumerate(words) if _keep(word)]
def color_length(words):
    """Map each word to its character count, used as the colour value."""
    return list(map(len, words))
def color_first_letter(words):
    """Map each word's first letter to a value in [0, 1].

    The word is lowercased, then its first character is taken. 'a' maps to 0
    and each later letter adds 1/26, so 'z' maps to 25/26. Characters below
    'a' clamp to 0, and characters past 'z' map to 1. An empty word raises
    IndexError.
    """
    values = []
    for word in words:
        offset = ord(word.lower()[0]) - ord("a")
        values.append(min(1, max(0, offset / 26)))
    return values
def color_levenshtein(words, reference_index=4):
    """Map each word to its Levenshtein edit distance from a reference word.

    The reference is ``words[reference_index]``, index 4 by default. It is a
    fixed position, not a random word. If the list is too short for that
    index, the last word is used instead, so small filtered selections no
    longer raise IndexError.

    Returns:
        A list with one distance per word; an empty list for empty input.
    """
    if not words:
        return []
    # Look the reference up once instead of re-indexing on every iteration.
    reference = words[min(reference_index, len(words) - 1)]
    return [distance(word, reference) for word in words]
def plot_scatter(words, embeddings, remove_capitalized, length, color_select):
    """Build the 3-D PCA scatter figure for the words that pass the filters.

    Args:
        words: Vocabulary, indexed in parallel with ``embeddings``.
        embeddings: Array whose rows are selected by word position.
        remove_capitalized: Drop words containing uppercase characters.
        length: ``(min_len, max_len)`` inclusive word-length range.
        color_select: ``"Word length"`` colours points by length; any other
            value colours them by Levenshtein distance to a reference word.

    Returns:
        The plotly figure produced by ``px.scatter_3d``.
    """
    idx = filter_words(words, remove_capitalized, length)
    filtered_words = [words[i] for i in idx]
    # PCA is refitted on the filtered subset, so the axes reflect only the
    # words currently displayed.
    proj = project_embeddings(embeddings[idx])

    if color_select == "Word length":
        color = color_length(filtered_words)
    else:
        color = color_levenshtein(filtered_words)

    fig = px.scatter_3d(
        x=proj[:, 0],
        y=proj[:, 1],
        z=proj[:, 2],
        width=800,
        height=600,
        color=color,
        color_continuous_scale=px.colors.sequential.Viridis,
        hover_name=filtered_words,
        # Without labels the axes read "x"/"y"/"z" and the colour bar reads
        # "color"; name them after what they actually show.
        labels={"x": "PC1", "y": "PC2", "z": "PC3", "color": color_select},
        title="SimpleSegmentT5 Embeddings",
    )
    fig.update_traces(
        marker={"size": 6, "line": {"width": 2}},
        selector={"mode": "markers"},
    )
    return fig
def main():
    """Render the app: sidebar filter controls plus the 3-D embedding scatter.

    The previous version also ran PCA over the whole dataset and built a
    figure that was never displayed. That work ran again on every Streamlit
    rerun and has been removed.
    """
    embeddings, words = load_data()

    st.sidebar.title("Settings")
    remove_checkbox = st.sidebar.checkbox(
        "Remove capitalized words",
        value=True,
        # NOTE(review): the key says "include" while the checkbox removes
        # capitalized words; left unchanged to keep the widget identity.
        key="include_capitalized",
    )
    length_slider = st.sidebar.slider("Word length", 3, 9, (3, 9))
    # NOTE(review): the reference word is not random; color_levenshtein picks
    # it by fixed index.
    color_select = st.sidebar.radio(
        "Color by", ["Word length", "Levenshtein distance to random word"]
    )
    st.plotly_chart(
        plot_scatter(words, embeddings, remove_checkbox, length_slider, color_select)
    )
# Build the app only when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()