Spaces:
Runtime error
Runtime error
rvian
commited on
Commit
·
d9956d6
1
Parent(s):
4e3e82b
ajuste
Browse files
app.py
CHANGED
|
@@ -22,11 +22,11 @@ def carregar_modelo_e_tokenizador_mbart(modelo):
|
|
| 22 |
|
| 23 |
# TODO:batch?
|
| 24 |
def traduzir_en_pt(text):
|
| 25 |
-
inputs =
|
| 26 |
input_ids = inputs.input_ids
|
| 27 |
attention_mask = inputs.attention_mask
|
| 28 |
-
output =
|
| 29 |
-
return
|
| 30 |
|
| 31 |
## streamlit ##
|
| 32 |
def carregar_dataset():
|
|
@@ -77,9 +77,9 @@ dataset = carregar_dataset()
|
|
| 77 |
if dataset is not None:
|
| 78 |
mostrar_dataset()
|
| 79 |
if st.button("Carregar modelo"):
|
| 80 |
-
|
| 81 |
|
| 82 |
|
| 83 |
-
if st.button("Traduzir dataset") and
|
| 84 |
traduzir_dataset(dataset)
|
| 85 |
resultado()
|
|
|
|
| 22 |
|
| 23 |
# TODO:batch?
|
| 24 |
def traduzir_en_pt(text):
|
| 25 |
+
inputs = tokenizador(text, return_tensors='pt')
|
| 26 |
input_ids = inputs.input_ids
|
| 27 |
attention_mask = inputs.attention_mask
|
| 28 |
+
output = modelo.generate(input_ids, attention_mask=attention_mask, forced_bos_token_id=tokenizador.lang_code_to_id['pt_XX'])
|
| 29 |
+
return tokenizador.decode(output[0], skip_special_tokens=True)
|
| 30 |
|
| 31 |
## streamlit ##
|
| 32 |
def carregar_dataset():
|
|
|
|
| 77 |
if dataset is not None:
|
| 78 |
mostrar_dataset()
|
| 79 |
if st.button("Carregar modelo"):
|
| 80 |
+
modelo, tokenizador = carregar_modelo()
|
| 81 |
|
| 82 |
|
| 83 |
+
if st.button("Traduzir dataset") and modelo is not None:
|
| 84 |
traduzir_dataset(dataset)
|
| 85 |
resultado()
|