File size: 2,105 Bytes
5f86ced
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
from transformers import BertModel, BertTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.graph_objs as go

# BERT embeddings function
def get_bert_embeddings(words):
    """Embed each word/phrase with BERT and project the vectors to 3-D.

    Every entry of ``words`` is tokenized on its own and passed through
    ``bert-base-uncased``; the hidden state at the first token position
    (the ``[CLS]`` token for BERT) is used as that entry's embedding. The
    embeddings are then reduced with PCA.

    Args:
        words: Iterable of strings to embed.

    Returns:
        A NumPy array of shape ``(len(words), 3)``. When fewer than three
        entries are given, PCA cannot produce three components, so the
        missing trailing columns are filled with zeros.
    """
    import numpy as np

    words = list(words)

    # PCA centers the data, so zero or one sample always projects to the
    # origin. Answer directly instead of loading the model and then failing
    # inside PCA (n_components may not exceed the number of samples).
    if len(words) < 2:
        return np.zeros((len(words), 3))

    # Load pre-trained BERT model and tokenizer.
    # NOTE(review): this reloads the model on every call; if the installed
    # Streamlit version offers resource caching, consider using it here.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertModel.from_pretrained('bert-base-uncased')

    embeddings = []

    # Inference only: skip autograd bookkeeping to save time and memory.
    with torch.no_grad():
        for word in words:
            inputs = tokenizer(word, return_tensors='pt')
            outputs = model(**inputs)
            embeddings.append(outputs.last_hidden_state[0, 0].numpy())

    matrix = np.stack(embeddings)

    # PCA requires n_components <= min(n_samples, n_features).
    n_components = min(3, matrix.shape[0], matrix.shape[1])
    reduced = PCA(n_components=n_components).fit_transform(matrix)

    if n_components == 3:
        return reduced

    # Keep the (n, 3) contract expected by the 3-D plot.
    padded = np.zeros((reduced.shape[0], 3), dtype=reduced.dtype)
    padded[:, :n_components] = reduced
    return padded

# Plotly plotting function
def plot_interactive_bert_embeddings(embeddings, words):
    """Build a 3-D Plotly scatter figure with one labelled marker per word.

    Args:
        embeddings: Indexable collection where ``embeddings[i]`` holds the
            three coordinates to plot for ``words[i]``.
        words: Labels to plot; each one becomes its own trace, so it is
            shown both as marker text and as a legend entry.

    Returns:
        A ``go.Figure`` containing the traces and the axis layout.
    """
    def _trace_for(index, label):
        # One single-point trace per word, named after the word.
        point = embeddings[index]
        return go.Scatter3d(
            x=[point[0]],
            y=[point[1]],
            z=[point[2]],
            mode='markers+text',
            text=[label],
            name=label,
        )

    traces = [_trace_for(index, label) for index, label in enumerate(words)]

    scene = dict(
        xaxis=dict(title='PCA Component 1'),
        yaxis=dict(title='PCA Component 2'),
        zaxis=dict(title='PCA Component 3'),
    )
    figure_layout = go.Layout(
        title='3D Scatter Plot of BERT Embeddings',
        scene=scene,
    )

    return go.Figure(data=traces, layout=figure_layout)

# Streamlit app
def main():
    """Render the Streamlit page for visualizing BERT embeddings.

    Shows a text area for comma-separated words/phrases and a button. When
    the button is pressed, the entries are embedded, reduced to 3-D and
    displayed as an interactive Plotly chart.
    """
    st.title("BERT Embeddings Visualization")

    # Text input for words
    words_input = st.text_area("Enter words/phrases separated by commas:", "Spider-Man, Rocket Racoon, Venom, Spider, Racoon, Snake")

    # Drop blank entries (trailing comma, ", ,", or an empty box) so they are
    # not embedded and plotted as unlabeled points.
    stripped = (part.strip() for part in words_input.split(','))
    words = [word for word in stripped if word]

    if st.button("Generate Embeddings"):
        if not words:
            st.warning("Please enter at least one word or phrase.")
            return

        with st.spinner('Generating embeddings...'):
            embeddings = get_bert_embeddings(words)
            fig = plot_interactive_bert_embeddings(embeddings, words)
            st.plotly_chart(fig, use_container_width=True)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()