rianders commited on
Commit
26b32a1
·
1 Parent(s): edbfc19

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -4,7 +4,27 @@ import torch
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
 
7
- # BERT Embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def plot_interactive_bert_embeddings(embeddings, words):
9
  if len(words) < 4:
10
  st.error("Please provide at least 4 words/phrases for effective visualization.")
@@ -49,7 +69,8 @@ def main():
49
  with st.spinner('Generating embeddings...'):
50
  embeddings = get_bert_embeddings(words)
51
  fig = plot_interactive_bert_embeddings(embeddings, words)
52
- st.plotly_chart(fig, use_container_width=True)
 
53
 
54
  if __name__ == "__main__":
55
- main()
 
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
 
7
+ # BERT embeddings function
8
+ def get_bert_embeddings(words):
9
+ # Load pre-trained BERT model and tokenizer
10
+ tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
+ model = BertModel.from_pretrained('bert-base-uncased')
12
+
13
+ embeddings = []
14
+
15
+ # Extract embeddings
16
+ for word in words:
17
+ inputs = tokenizer(word, return_tensors='pt')
18
+ outputs = model(**inputs)
19
+ embeddings.append(outputs.last_hidden_state[0][0].detach().numpy())
20
+
21
+ # Reduce dimensions to 3 using PCA
22
+ pca = PCA(n_components=3)
23
+ reduced_embeddings = pca.fit_transform(embeddings)
24
+
25
+ return reduced_embeddings
26
+
27
+ # Plotly plotting function
28
  def plot_interactive_bert_embeddings(embeddings, words):
29
  if len(words) < 4:
30
  st.error("Please provide at least 4 words/phrases for effective visualization.")
 
69
  with st.spinner('Generating embeddings...'):
70
  embeddings = get_bert_embeddings(words)
71
  fig = plot_interactive_bert_embeddings(embeddings, words)
72
+ if fig is not None: # Only plot if the figure is not None
73
+ st.plotly_chart(fig, use_container_width=True)
74
 
75
  if __name__ == "__main__":
76
+ main()