Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from transformers import BertModel, BertTokenizer | |
| import torch | |
| from sklearn.decomposition import PCA | |
| import plotly.graph_objs as go | |
# BERT embeddings function
def get_bert_embeddings(words):
    """Embed each entry of *words* with BERT and project the vectors to 3-D.

    Every word/phrase is tokenized and run through the pre-trained
    ``bert-base-uncased`` model on its own; the last-layer hidden state at
    the first token position is used as that entry's embedding.  The
    embeddings are then reduced to three dimensions with PCA.

    Args:
        words: Iterable of strings (single words or short phrases).

    Returns:
        A NumPy array of shape ``(len(words), 3)``.  PCA cannot produce more
        components than there are samples, so with fewer than three inputs
        the unused trailing columns are filled with zeros; a single input
        maps to the origin (PCA centres the data).  An empty input yields an
        empty ``(0, 3)`` array without loading the model.
    """
    import numpy as np  # local import: numpy is not imported at file level

    words = list(words)
    if not words:
        # Nothing to embed: skip the (expensive) model load entirely.
        return np.empty((0, 3), dtype=np.float32)

    # Load pre-trained BERT model and tokenizer.
    # NOTE: both are loaded on every call of this function.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.eval()  # inference only: make sure dropout is disabled

    # Extract embeddings. no_grad() avoids building an autograd graph that
    # is never used, which saves memory and time.
    embeddings = []
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            embeddings.append(outputs.last_hidden_state[0][0].numpy())

    # Reduce dimensions to (at most) 3 using PCA, always returning 3 columns
    # so callers can index x/y/z unconditionally.
    matrix = np.stack(embeddings)
    n_samples = matrix.shape[0]
    reduced_embeddings = np.zeros((n_samples, 3), dtype=matrix.dtype)
    if n_samples >= 2:
        # n_components may not exceed the number of samples.
        n_components = min(3, n_samples)
        pca = PCA(n_components=n_components)
        reduced_embeddings[:, :n_components] = pca.fit_transform(matrix)
    return reduced_embeddings
# Plotly plotting function
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3-D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable collection where ``embeddings[i]`` holds the
            three coordinates (x, y, z) of the i-th word.
        words: Iterable of labels; the i-th label is used both as the
            marker text and as the trace (legend) name.

    Returns:
        A ``plotly.graph_objs.Figure`` containing one ``Scatter3d`` trace
        per word, with the axes titled after the PCA components.
    """
    scene_axes = {
        'xaxis': {'title': 'PCA Component 1'},
        'yaxis': {'title': 'PCA Component 2'},
        'zaxis': {'title': 'PCA Component 3'},
    }
    figure_layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=scene_axes,
    )

    # One trace per word so that each entry gets its own legend item.
    traces = [
        go.Scatter3d(
            x=[embeddings[idx][0]],
            y=[embeddings[idx][1]],
            z=[embeddings[idx][2]],
            mode='markers+text',
            text=[label],
            name=label,
        )
        for idx, label in enumerate(words)
    ]

    return go.Figure(data=traces, layout=figure_layout)
# Streamlit app
def main():
    """Run the Streamlit UI: read words, embed them with BERT, plot them.

    The user enters comma-separated words/phrases.  When the
    "Generate Embeddings" button is pressed, the entries are embedded via
    ``get_bert_embeddings`` and shown as an interactive 3-D scatter plot
    built by ``plot_interactive_bert_embeddings``.
    """
    st.title("BERT Embeddings Visualization")

    # Text input for words
    words_input = st.text_area(
        "Enter words/phrases separated by commas:",
        "Spider-Man, Rocket Racoon, Venom, Spider, Racoon, Snake",
    )

    # Drop blank entries (trailing/double commas, empty input) so they are
    # not embedded and plotted as unlabeled points.
    stripped = (part.strip() for part in words_input.split(','))
    words = [word for word in stripped if word]

    if st.button("Generate Embeddings"):
        if not words:
            # Nothing usable was entered: tell the user instead of running
            # the model on empty input.
            st.warning("Please enter at least one word or phrase.")
            return
        with st.spinner('Generating embeddings...'):
            embeddings = get_bert_embeddings(words)
            fig = plot_interactive_bert_embeddings(embeddings, words)
            st.plotly_chart(fig, use_container_width=True)


if __name__ == "__main__":
    main()