Commit
·
b7e2104
1
Parent(s):
bdef5c4
Make it work
Browse files
views.py
CHANGED
|
@@ -8,7 +8,7 @@ from streamlit_plotly_events import plotly_events
|
|
| 8 |
import utils
|
| 9 |
import pandas as pd
|
| 10 |
from scipy.spatial import distance
|
| 11 |
-
|
| 12 |
dimensionality_reduction_model_name = "PCA"
|
| 13 |
|
| 14 |
def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
|
|
@@ -26,15 +26,22 @@ def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
|
|
| 26 |
with st.form(key="foo") as form:
|
| 27 |
submit_button = st.form_submit_button("Synthesize")
|
| 28 |
|
| 29 |
-
sent1 = st.text_input("Sentence 1")
|
| 30 |
st.latex("-")
|
| 31 |
-
sent2 = st.text_input("Sentence 2")
|
| 32 |
st.latex("+")
|
| 33 |
-
sent3 = st.text_input("Sentence 3")
|
| 34 |
st.latex("=")
|
| 35 |
|
| 36 |
if submit_button:
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
sent4 = st.text_input("Sentence 4", value=generated_sentence, disabled=True)
|
| 40 |
|
|
|
|
| 8 |
import utils
|
| 9 |
import pandas as pd
|
| 10 |
from scipy.spatial import distance
|
| 11 |
+
from resources import get_gtr_embeddings
|
| 12 |
dimensionality_reduction_model_name = "PCA"
|
| 13 |
|
| 14 |
def diffs(embeddings: np.ndarray, corrector, encoder, tokenizer):
|
|
|
|
| 26 |
with st.form(key="foo") as form:
|
| 27 |
submit_button = st.form_submit_button("Synthesize")
|
| 28 |
|
| 29 |
+
sent1 = st.text_input("Sentence 1", value="I am a king")
|
| 30 |
st.latex("-")
|
| 31 |
+
sent2 = st.text_input("Sentence 2", value="I am a man")
|
| 32 |
st.latex("+")
|
| 33 |
+
sent3 = st.text_input("Sentence 3", value="I am a woman")
|
| 34 |
st.latex("=")
|
| 35 |
|
| 36 |
if submit_button:
|
| 37 |
+
v1, v2, v3 = get_gtr_embeddings([sent1, sent2, sent3], encoder, tokenizer).to("cpu")
|
| 38 |
+
v4 = v1 - v2 + v3
|
| 39 |
+
generated_sentence, = vec2text.invert_embeddings(
|
| 40 |
+
embeddings=v4.unsqueeze(0).cuda(),
|
| 41 |
+
corrector=corrector,
|
| 42 |
+
num_steps=20,
|
| 43 |
+
)
|
| 44 |
+
generated_sentence = generated_sentence.strip()
|
| 45 |
|
| 46 |
sent4 = st.text_input("Sentence 4", value=generated_sentence, disabled=True)
|
| 47 |
|