rianders committed on
Commit
5f86ced
·
1 Parent(s): bd6ce4a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ from transformers import BertModel, BertTokenizer
3
+ import torch
4
+ from sklearn.decomposition import PCA
5
+ import plotly.graph_objs as go
6
+
7
+ # BERT embeddings function
8
def get_bert_embeddings(words):
    """Embed each entry of *words* with BERT and project the vectors to 3-D.

    Each word/phrase is tokenized and run through ``bert-base-uncased`` on
    its own; the hidden state at the first token position of the last layer
    (the ``[CLS]`` slot when the tokenizer adds special tokens, which is its
    default) is used as that entry's embedding.  The embeddings are then
    reduced with PCA so they can be drawn in a 3-D scatter plot.

    Args:
        words: Iterable of strings (words or short phrases) to embed.

    Returns:
        A NumPy array of shape ``(len(words), 3)``.  PCA cannot produce more
        components than there are samples, so with fewer than three inputs
        the unavailable trailing components are filled with zeros; with no
        inputs an empty ``(0, 3)`` array is returned.
    """
    # Local import keeps this block self-contained; numpy is already a hard
    # dependency of both torch's ``.numpy()`` and scikit-learn.
    import numpy as np

    words = list(words)
    if not words:
        # Nothing to embed: skip the (expensive) model load entirely.
        return np.zeros((0, 3))

    # Load pre-trained BERT model and tokenizer.
    # NOTE(review): this reloads the weights on every call; if the installed
    # Streamlit version provides a resource cache, consider using it.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')
    model.eval()  # make sure dropout is disabled for inference

    embeddings = []
    # Inference only: no_grad avoids building autograd graphs per word.
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            # Batch item 0, token position 0 of the final hidden layer.
            embeddings.append(outputs.last_hidden_state[0, 0].numpy())
    matrix = np.stack(embeddings)

    n_samples = matrix.shape[0]
    if n_samples == 1:
        # A single centred sample has no variance; it sits at the origin.
        return np.zeros((1, 3), dtype=matrix.dtype)

    # PCA requires n_components <= min(n_samples, n_features).
    n_components = min(3, n_samples)
    reduced_embeddings = PCA(n_components=n_components).fit_transform(matrix)
    if n_components < 3:
        # Pad missing components with zeros so callers always get 3 columns.
        reduced_embeddings = np.pad(
            reduced_embeddings, ((0, 0), (0, 3 - n_components))
        )

    return reduced_embeddings
26
+
27
+ # Plotly plotting function
28
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3-D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable collection where ``embeddings[i]`` holds at
            least three coordinates for ``words[i]``.
        words: Sequence of labels; each becomes its own trace so it gets a
            separate legend entry.

    Returns:
        A ``plotly.graph_objs.Figure`` containing the traces and axis titles.
    """
    # One single-point trace per word, labelled both on the plot and in the
    # legend.
    traces = [
        go.Scatter3d(
            x=[embeddings[idx][0]],
            y=[embeddings[idx][1]],
            z=[embeddings[idx][2]],
            mode='markers+text',
            text=[label],
            name=label,
        )
        for idx, label in enumerate(words)
    ]

    scene = dict(
        xaxis=dict(title='PCA Component 1'),
        yaxis=dict(title='PCA Component 2'),
        zaxis=dict(title='PCA Component 3'),
    )
    figure_layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=scene,
    )

    return go.Figure(data=traces, layout=figure_layout)
53
+
54
+ # Streamlit app
55
def main():
    """Render the Streamlit page: collect words, embed them, show the plot.

    The text area takes a comma-separated list of words/phrases.  When the
    button is pressed, the non-blank entries are embedded with
    ``get_bert_embeddings`` and drawn with
    ``plot_interactive_bert_embeddings``.
    """
    st.title("BERT Embeddings Visualization")

    # Text input for words
    words_input = st.text_area(
        "Enter words/phrases separated by commas:",
        "Spider-Man, Rocket Racoon, Venom, Spider, Racoon, Snake",
    )
    # Strip whitespace and drop blank entries (e.g. from a trailing comma or
    # an empty text area) so they are not embedded and plotted as words.
    words = [word.strip() for word in words_input.split(',') if word.strip()]

    if st.button("Generate Embeddings"):
        if not words:
            # Nothing usable was entered; tell the user instead of running
            # the model on an empty list.
            st.warning("Please enter at least one word or phrase.")
            return

        with st.spinner('Generating embeddings...'):
            embeddings = get_bert_embeddings(words)
            fig = plot_interactive_bert_embeddings(embeddings, words)
            st.plotly_chart(fig, use_container_width=True)
67
+
68
# Launch the app only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()