# Source: viewembedding / app.py -- commit 5f86ced ("Create app.py") by rianders, 2.11 kB (Hugging Face file-viewer header; "raw" / "history" / "blame" links omitted).
import streamlit as st
from transformers import BertModel, BertTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.graph_objs as go
# BERT embeddings function
def get_bert_embeddings(words):
    """Embed each word/phrase with BERT and project the vectors to 3-D.

    Each entry of ``words`` is tokenized on its own and passed through the
    pre-trained ``bert-base-uncased`` model. The hidden state at the first
    token position (``last_hidden_state[0][0]``, i.e. the ``[CLS]`` slot the
    BERT tokenizer prepends) is used as that entry's embedding. The stacked
    embeddings are then reduced with PCA.

    Args:
        words: Iterable of strings (single words or short phrases).

    Returns:
        A NumPy array of shape ``(len(words), 3)``. When fewer than three
        entries are supplied, PCA cannot produce three components, so the
        missing trailing components are filled with zeros; an empty input
        yields an array of shape ``(0, 3)``.
    """
    # Local import: numpy is not imported at file level, but it is always
    # present as a dependency of torch / scikit-learn.
    import numpy as np

    target_dims = 3
    words = list(words)
    if not words:
        # Nothing to embed; skip the (expensive) model load entirely.
        return np.empty((0, target_dims))

    # Load pre-trained BERT model and tokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    # Extract embeddings. Inference only, so disable autograd bookkeeping.
    vectors = []
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            vectors.append(outputs.last_hidden_state[0][0].numpy())
    embeddings = np.stack(vectors)

    n_samples, n_features = embeddings.shape
    if n_samples < 2:
        # A single point has no variance: after centering it sits at the
        # origin, so return that directly instead of running a degenerate PCA.
        return np.zeros((n_samples, target_dims), dtype=embeddings.dtype)

    # PCA requires n_components <= min(n_samples, n_features); clamp and then
    # zero-pad so callers can always index three coordinates per row.
    n_components = min(target_dims, n_samples, n_features)
    reduced_embeddings = PCA(n_components=n_components).fit_transform(embeddings)
    if n_components < target_dims:
        padding = ((0, 0), (0, target_dims - n_components))
        reduced_embeddings = np.pad(reduced_embeddings, padding)
    return reduced_embeddings
# Plotly plotting function
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3-D Plotly scatter figure with one labelled point per word.

    Args:
        embeddings: Indexable collection where ``embeddings[i]`` holds the
            three coordinates (x, y, z) for ``words[i]``.
        words: Sequence of labels; each becomes both the on-plot text and the
            legend name of its own trace.

    Returns:
        A ``plotly.graph_objs.Figure`` containing one ``Scatter3d`` trace per
        word and axes titled after the PCA components.
    """
    # One trace per word so every point gets its own legend entry.
    traces = [
        go.Scatter3d(
            x=[embeddings[idx][0]],
            y=[embeddings[idx][1]],
            z=[embeddings[idx][2]],
            mode='markers+text',
            text=[label],
            name=label,
        )
        for idx, label in enumerate(words)
    ]

    scene = {
        'xaxis': {'title': 'PCA Component 1'},
        'yaxis': {'title': 'PCA Component 2'},
        'zaxis': {'title': 'PCA Component 3'},
    }
    layout = go.Layout(title='3D Scatter Plot of BERT Embeddings', scene=scene)

    return go.Figure(data=traces, layout=layout)
# Streamlit app
def main():
    """Render the Streamlit page: collect words, embed them, and plot them.

    Shows a text area for a comma-separated list of words/phrases. When the
    "Generate Embeddings" button is pressed, the non-blank entries are passed
    to ``get_bert_embeddings`` and the result is drawn with
    ``plot_interactive_bert_embeddings``. If no usable entry was typed, a
    warning is shown instead of running the model.
    """
    st.title("BERT Embeddings Visualization")

    # Text input for words
    words_input = st.text_area(
        "Enter words/phrases separated by commas:",
        "Spider-Man, Rocket Racoon, Venom, Spider, Racoon, Snake",
    )
    # Strip whitespace and drop blank entries (e.g. from a trailing or doubled
    # comma) so empty strings are not embedded and plotted as unlabeled points.
    stripped = (word.strip() for word in words_input.split(','))
    words = [word for word in stripped if word]

    if st.button("Generate Embeddings"):
        if not words:
            st.warning("Please enter at least one word or phrase.")
            return
        with st.spinner('Generating embeddings...'):
            embeddings = get_bert_embeddings(words)
            fig = plot_interactive_bert_embeddings(embeddings, words)
            st.plotly_chart(fig, use_container_width=True)
# Launch the app only when this file is executed as a script (for example via
# `streamlit run app.py`), not when it is imported as a module.
if __name__ == "__main__":
    main()