Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import pandas as pd | |
| import plotly.express as px | |
| import torch.nn.functional as F | |
| from sentence_transformers import SentenceTransformer | |
| import numpy as np | |
| from sklearn.manifold import TSNE | |
| import plotly.express as px | |
| import torch | |
| import plotly.io as pio | |
# Render every Plotly figure below with the stock "plotly" template.
pio.templates.default = "plotly"

# Page chrome: wide layout, a title with a coloured divider, and a pointer
# to background reading on Matryoshka Representation Learning.
st.set_page_config(layout="wide")
st.header(
    "Explore the Russian Dolls :nesting_dolls: - _ :green[Nomic Embed 1.5] _",
    divider='violet',
)
st.write(
    "Matryoshka Representation Learning : to learn more :https://aniketrege.github.io/blog/2024/mrl/"
)
def get_df():
    """Read ``./sample_products.csv`` and return it as a pandas DataFrame.

    The path is relative to the process working directory.
    """
    return pd.read_csv("./sample_products.csv")
def get_nomicModel():
    """Build and return the ``nomic-ai/nomic-embed-text-v1.5`` SentenceTransformer.

    NOTE: ``trust_remote_code=True`` lets the model repository execute its
    own Python code while loading.
    """
    return SentenceTransformer(
        "nomic-ai/nomic-embed-text-v1.5",
        trust_remote_code=True,
    )
def get_searchQueryEmbedding(query):
    """Embed a single search query with the module-level ``model``.

    The query is prefixed with ``"search_query: "`` before encoding, and the
    resulting tensor is layer-normalised across its second dimension.

    Args:
        query: The raw query text.

    Returns:
        The layer-normalised embedding tensor for the one-item batch.
    """
    prefixed_query = "search_query: "+query
    raw = model.encode([prefixed_query], convert_to_tensor=True)
    feature_count = raw.shape[1]
    return F.layer_norm(raw, normalized_shape=(feature_count,))
def get_normEmbed(query_embedding,loaded_embed,matryoshka_dim):
    """Truncate both embedding matrices and L2-normalise their rows.

    Args:
        query_embedding: 2-D tensor of query embeddings (one row per query).
        loaded_embed: 2-D tensor of document embeddings (one row per document).
        matryoshka_dim: Number of leading columns to keep from each matrix.

    Returns:
        A ``(query, documents)`` tuple of the truncated, row-normalised tensors.
    """
    def _shrink(matrix):
        # Keep the leading columns, then rescale every row to unit L2 norm.
        return F.normalize(matrix[:, :matryoshka_dim], p=2, dim=1)

    return _shrink(query_embedding), _shrink(loaded_embed)
def insert_line_breaks(text, interval=30):
    """Wrap ``text`` for Plotly hover labels by inserting ``<br>`` tags.

    Words (split on single spaces) are accumulated onto a line; once the
    running length of the line, counting one separator per word, reaches
    ``interval`` the line is closed. Words are never split, so a line may
    exceed ``interval`` by the length of its last word.

    Args:
        text: The text to wrap.
        interval: Minimum line length (in characters) that triggers a break.

    Returns:
        The wrapped text with lines joined by ``<br>``. There is no space
        before a break and no trailing ``<br>`` at the end of the result.
    """
    lines = []
    current_words = []
    line_length = 0
    for word in text.split(' '):
        current_words.append(word)
        line_length += len(word) + 1
        if line_length >= interval:
            lines.append(' '.join(current_words))
            current_words = []
            line_length = 0
    # Flush the unfinished last line; skipping this step when it is empty is
    # what prevents a dangling "<br>" at the end of the output.
    if current_words:
        lines.append(' '.join(current_words))
    return '<br>'.join(lines).strip()
# Streamlit re-executes this whole script on every widget interaction, so the
# expensive loads go through Streamlit's caches instead of being repeated on
# each rerun: the model is a shared resource, the CSV is plain data.
model = st.cache_resource(get_nomicModel)()
# NOTE: despite its name, this holds the "Description" text column of the
# product CSV, not embedding vectors.
bigDollEmbedding = st.cache_data(get_df)()["Description"]
# Presumably the precomputed embeddings of those descriptions, one row per
# CSV row in the same order -- TODO confirm against how the .npy was built.
docEmbedding = torch.Tensor(np.load("./prodBigDollEmbeddings.npy"))
# Switches the query widget between canned sample queries and free text.
toggle = st.toggle('sample queries')
# Search form: pick (or type) a query and a Matryoshka dimension, then show
# the best-matching products next to a 2-D t-SNE map of the query and all
# product embeddings truncated to that dimension.
with st.form("my_form"):
    if toggle:
        query_input = st.selectbox('select a query:',
            ('Pack of two assorted boxers, has two pockets, an elasticated waistbandDisclaimer: The final product delivered might vary in colour and prints from the display here.',
             'Beige self design shoulder bag, has a zip closure1 main compartment, 3 inner pocketsTwo Handles',
             'Set Content: 1 photo frameColour: Black and whiteFrame Pattern: SolidShape: SquareMaterial: Acrylic',
             'A pair of dark grey solid boxers, has a slip-on closure with an elasticated waistband and drawstring, two pocket',
             'Red & Black solid sweatshirt, has a hood, two pockets, long sleeves, zip closure, straight hem'))
    else:
        query_input = st.text_input("")
    Matry_dim = st.slider('Matryoshka Dimension', 64, 768, 64)
    submitted = st.form_submit_button("Submit")
    if submitted and not query_input.strip():
        # An empty query would only embed the bare "search_query: " prefix.
        st.warning("Please enter a search query.")
    elif submitted:
        # Materialise the description list once; it is reused for both the
        # result table and the hover labels.
        descriptions = bigDollEmbedding.tolist()

        # docEmbedding lives on the CPU, so bring the query embedding there
        # too in case the model ran on another device.
        queryEmbedding = get_searchQueryEmbedding(query_input).cpu()
        query_embedNorm, loaded_embedNorm = get_normEmbed(queryEmbedding, docEmbedding, Matry_dim)

        # Rows are unit-normalised, so the dot product is cosine similarity.
        similarity_scores = torch.matmul(query_embedNorm, loaded_embedNorm.T)
        # Never ask for more hits than there are documents.
        top_k = min(10, similarity_scores.shape[1])
        top_values, top_indices = torch.topk(similarity_scores, top_k, dim=1)
        top_items_per_query = [descriptions[index] for index in top_indices[0].tolist()]
        # Hand pandas plain Python floats rather than a torch tensor.
        df = pd.DataFrame({"Product": top_items_per_query, "Score": top_values[0].tolist()})
        df["Product"] = df["Product"].str.replace("search_document:", "", regex=False)

        # Project the query (row 0) followed by every document into 2-D.
        allEmbedd = torch.concat([query_embedNorm, loaded_embedNorm])
        n_docs = loaded_embedNorm.shape[0]
        # t-SNE requires perplexity < n_samples; 30.0 is scikit-learn's default.
        tsne = TSNE(n_components=2, random_state=0,
                    perplexity=min(30.0, float(allEmbedd.shape[0] - 1)))
        projections = tsne.fit_transform(allEmbedd.numpy())

        listHover = [insert_line_breaks(hover_text, 30) for hover_text in descriptions]
        fig = px.scatter(
            projections, x=0, y=1,
            hover_name=[query_input] + listHover,
            # One label per plotted point, derived from the data instead of
            # a hard-coded document count.
            color=["search_query"] + (["search_document"] * n_docs)
        )
        col1, col2 = st.columns([2, 2])
        col2.plotly_chart(fig, use_container_width=True)
        col1.dataframe(df)
st.caption("Dataset Credit : kaggle")