TiberiuCristianLeon commited on
Commit
bd61d8c
·
verified ·
1 Parent(s): ade20d4

Update src/Translate.py

Browse files
Files changed (1) hide show
  1. src/Translate.py +30 -1
src/Translate.py CHANGED
@@ -36,7 +36,7 @@ class Translators:
36
  def translationpipe(self):
37
  translation = pipeline('translation', model = self.model_name)
38
  return translation(self.input_text)[0]['translation_text'], self.message
39
- def mbartlarge(self):
40
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBartTokenizer
41
  src_lang = f"{self.sl}_XX"
42
  tgt_lang = f"{self.tl}_{self.tl.upper()}"
@@ -60,6 +60,35 @@ class Translators:
60
  print(src_lang, tgt_lang, tokenizer.lang_code_to_id[tgt_lang])
61
  translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
62
  return translation, self.message
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def paraphraseTranslateMethod(requestValue: str, model: str):
65
  nltk.download('punkt')
 
36
  def translationpipe(self):
37
  translation = pipeline('translation', model = self.model_name)
38
  return translation(self.input_text)[0]['translation_text'], self.message
39
+ def mbartlarge25(self):
40
  from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBartTokenizer
41
  src_lang = f"{self.sl}_XX"
42
  tgt_lang = f"{self.tl}_{self.tl.upper()}"
 
60
  print(src_lang, tgt_lang, tokenizer.lang_code_to_id[tgt_lang])
61
  translation = tokenizer.batch_decode(translated_tokens, skip_special_tokens=True)[0]
62
  return translation, self.message
63
+ def mbartlarge(self):
64
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
65
+
66
+ model_name = "facebook/mbart-large-cc25"
67
+
68
+ # load tokenizer and model
69
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
70
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
71
+
72
+ # tell tokenizer the source language
73
+ tokenizer.src_lang = "en_XX"
74
+
75
+ # find the id for the target language and force it at generation
76
+ forced_bos_token_id = tokenizer.lang_code_to_id["ro_RO"]
77
+
78
+ # create the pipeline (pass tokenizer and model explicitly)
79
+ pipe = pipeline("translation", model=model, tokenizer=tokenizer)
80
+
81
+ # call the pipeline; generation kwargs are forwarded to model.generate
82
+ src_text = "This is a test sentence."
83
+ result = pipe(
84
+ src_text,
85
+ num_beams=4,
86
+ max_length=512,
87
+ forced_bos_token_id=forced_bos_token_id
88
+ )
89
+
90
+ return result[0]["translation_text"], self.message
91
+
92
 
93
  def paraphraseTranslateMethod(requestValue: str, model: str):
94
  nltk.download('punkt')