TiberiuCristianLeon commited on
Commit
2e5ad82
·
verified ·
1 Parent(s): 891bd81

Update src/Translate.py

Browse files
Files changed (1) hide show
  1. src/Translate.py +6 -2
src/Translate.py CHANGED
@@ -37,12 +37,16 @@ class Translators:
37
  translation = pipeline('translation', model = self.model_name)
38
  return translation(self.input_text)[0]['translation_text'], self.message
39
  def mbartlarge(self):
40
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
41
  src_lang = f"{self.sl}_XX"
42
  tgt_lang = f"{self.tl}_{self.tl.upper()}"
43
  # Load model and tokenizer
 
 
 
44
  model = MBartForConditionalGeneration.from_pretrained(self.model_name)
45
- tokenizer = MBart50TokenizerFast.from_pretrained(self.model_name, src_lang=src_lang)
 
46
  # Tokenize and translate
47
  inputs = tokenizer(self.input_text, return_tensors="pt")
48
  translated_tokens = model.generate(
 
37
  translation = pipeline('translation', model = self.model_name)
38
  return translation(self.input_text)[0]['translation_text'], self.message
39
  def mbartlarge(self):
40
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast, MBartTokenizer
41
  src_lang = f"{self.sl}_XX"
42
  tgt_lang = f"{self.tl}_{self.tl.upper()}"
43
  # Load model and tokenizer
44
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
45
+ # tokenizer = AutoTokenizer.from_pretrained(self.model_name)
46
+ # model = AutoModelForSeq2SeqLM.from_pretrained(self.model_name)
47
  model = MBartForConditionalGeneration.from_pretrained(self.model_name)
48
+ tokenizer = MBartTokenizer.from_pretrained(self.model_name, src_lang=src_lang)
49
+ # pipe = pipeline("translation", model="facebook/mbart-large-cc25")
50
  # Tokenize and translate
51
  inputs = tokenizer(self.input_text, return_tensors="pt")
52
  translated_tokens = model.generate(