rianders commited on
Commit
edbfc19
·
1 Parent(s): ee365c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -23
app.py CHANGED
@@ -4,30 +4,13 @@ import torch
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
 
7
- # BERT embeddings function
8
- def get_bert_embeddings(words):
9
- # Load pre-trained BERT model and tokenizer
10
- tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
11
- model = BertModel.from_pretrained('bert-base-uncased')
12
-
13
- embeddings = []
14
-
15
- # Extract embeddings
16
- for word in words:
17
- inputs = tokenizer(word, return_tensors='pt')
18
- outputs = model(**inputs)
19
- embeddings.append(outputs.last_hidden_state[0][0].detach().numpy())
20
-
21
- # Reduce dimensions to 3 using PCA
22
- pca = PCA(n_components=3)
23
- reduced_embeddings = pca.fit_transform(embeddings)
24
-
25
- return reduced_embeddings
26
-
27
- # Plotly plotting function
28
  def plot_interactive_bert_embeddings(embeddings, words):
29
- data = []
 
 
30
 
 
31
  for i, word in enumerate(words):
32
  trace = go.Scatter3d(
33
  x=[embeddings[i][0]],
@@ -45,7 +28,10 @@ def plot_interactive_bert_embeddings(embeddings, words):
45
  xaxis=dict(title='PCA Component 1'),
46
  yaxis=dict(title='PCA Component 2'),
47
  zaxis=dict(title='PCA Component 3')
48
- )
 
 
 
49
  )
50
 
51
  fig = go.Figure(data=data, layout=layout)
 
4
  from sklearn.decomposition import PCA
5
  import plotly.graph_objs as go
6
 
7
+ # BERT Embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def plot_interactive_bert_embeddings(embeddings, words):
9
+ if len(words) < 4:
10
+ st.error("Please provide at least 4 words/phrases for effective visualization.")
11
+ return None
12
 
13
+ data = []
14
  for i, word in enumerate(words):
15
  trace = go.Scatter3d(
16
  x=[embeddings[i][0]],
 
28
  xaxis=dict(title='PCA Component 1'),
29
  yaxis=dict(title='PCA Component 2'),
30
  zaxis=dict(title='PCA Component 3')
31
+ ),
32
+ autosize=False,
33
+ width=800, # Width of the plot
34
+ height=600 # Height of the plot
35
  )
36
 
37
  fig = go.Figure(data=data, layout=layout)