"""Streamlit app that plots BERT embeddings of words/phrases in 3D via PCA."""

from typing import List, Optional, Sequence

import plotly.graph_objs as go
import streamlit as st
import torch
from sklearn.decomposition import PCA
from transformers import BertModel, BertTokenizer

BERT_MODEL_NAME = "bert-base-uncased"
# Number of principal components kept; one per plot axis.
PCA_COMPONENTS = 3
# Minimum number of inputs the app accepts before it will draw a plot.
MIN_WORDS = 4
TOO_FEW_WORDS_MESSAGE = (
    "Please provide at least 4 words/phrases for effective visualization."
)


def _load_bert(model_name: str = BERT_MODEL_NAME):
    """Return the ``(tokenizer, model)`` pair for the pre-trained *model_name*."""
    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    return tokenizer, model


# Streamlit re-executes this script on every widget interaction, so without a
# cross-rerun cache the model would be reloaded on each button click.
# ``st.cache_resource`` is only present on newer Streamlit releases; on older
# ones the loader simply stays uncached (the original behaviour).
if hasattr(st, "cache_resource"):
    _load_bert = st.cache_resource(_load_bert)


def _parse_words(raw_text: str) -> List[str]:
    """Split comma-separated *raw_text* into stripped, non-empty entries."""
    stripped = (piece.strip() for piece in raw_text.split(','))
    # Drop blanks produced by trailing or doubled commas.
    return [piece for piece in stripped if piece]


def get_bert_embeddings(words: Sequence[str]):
    """Embed each entry of *words* with BERT and project the result to 3D.

    For every word/phrase the hidden state of the first token of the model's
    last layer is taken as its embedding. All embeddings are then reduced to
    ``PCA_COMPONENTS`` dimensions with PCA fitted on this same set.

    Args:
        words: Words or phrases to embed.

    Returns:
        The array returned by ``PCA.fit_transform``: one row per input, one
        column per principal component.

    Raises:
        ValueError: If fewer than ``PCA_COMPONENTS`` inputs are given, since
            PCA cannot extract more components than there are samples.
    """
    words = list(words)
    if len(words) < PCA_COMPONENTS:
        raise ValueError(
            f"Need at least {PCA_COMPONENTS} words/phrases to compute "
            f"{PCA_COMPONENTS} PCA components, got {len(words)}."
        )

    tokenizer, model = _load_bert()

    embeddings = []
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            # [0, 0] -> first (only) sequence in the batch, first token.
            embeddings.append(outputs.last_hidden_state[0, 0].numpy())

    pca = PCA(n_components=PCA_COMPONENTS)
    return pca.fit_transform(embeddings)


def plot_interactive_bert_embeddings(embeddings, words) -> Optional[go.Figure]:
    """Build a 3D scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable collection of 3D points, aligned with *words*.
        words: Labels for the points.

    Returns:
        The Plotly figure, or ``None`` (after reporting the problem through
        ``st.error``) when fewer than ``MIN_WORDS`` labels are supplied.
    """
    if len(words) < MIN_WORDS:
        st.error(TOO_FEW_WORDS_MESSAGE)
        return None

    # One trace per word so each gets its own legend entry.
    data = [
        go.Scatter3d(
            x=[embeddings[i][0]],
            y=[embeddings[i][1]],
            z=[embeddings[i][2]],
            mode='markers+text',
            text=[word],
            name=word,
        )
        for i, word in enumerate(words)
    ]

    layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=dict(
            xaxis=dict(title='PCA Component 1'),
            yaxis=dict(title='PCA Component 2'),
            zaxis=dict(title='PCA Component 3'),
        ),
        autosize=False,
        width=800,   # Width of the plot
        height=600,  # Height of the plot
    )

    return go.Figure(data=data, layout=layout)


def main() -> None:
    """Run the Streamlit UI: read words, embed them, and show the 3D plot."""
    st.title("BERT Embeddings Visualization")

    # Text input for words
    words_input = st.text_area(
        "Enter words/phrases separated by commas:",
        "Spider-Man, Rocket Racoon, Venom, Spider, Racoon, Snake",
    )
    words = _parse_words(words_input)

    if not st.button("Generate Embeddings"):
        return

    # Validate before running the model: with too few inputs PCA would raise
    # and the user would see a traceback instead of this message.
    if len(words) < MIN_WORDS:
        st.error(TOO_FEW_WORDS_MESSAGE)
        return

    with st.spinner('Generating embeddings...'):
        embeddings = get_bert_embeddings(words)
        fig = plot_interactive_bert_embeddings(embeddings, words)

    if fig is not None:  # Only plot if the figure is not None
        st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
    main()