Update app.py
Browse files
app.py
CHANGED
|
@@ -1,23 +1,24 @@
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
|
|
|
| 4 |
|
| 5 |
tokenizer = AutoTokenizer.from_pretrained('TrLOX/gpt2-tdk')
|
| 6 |
model = AutoModelForCausalLM.from_pretrained('TrLOX/gpt2-tdk')
|
| 7 |
|
| 8 |
-
def text_generation(keywords, domain
|
| 9 |
input_ids = tokenizer('keyword ' + keywords + ' domain ' + domain + ' title', return_tensors="pt").input_ids
|
| 10 |
-
torch.manual_seed(seed)
|
| 11 |
outputs = model.generate(input_ids, do_sample=True, min_length=50, max_length=250)
|
| 12 |
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 13 |
-
return generated_text
|
| 14 |
|
| 15 |
title = "TDK GPT2"
|
| 16 |
description = "Title and description generation by keywords"
|
| 17 |
|
| 18 |
gr.Interface(
|
| 19 |
text_generation,
|
| 20 |
-
[gr.inputs.Textbox(default='test 1,test 2',lines=2, label="Enter keywords"), gr.inputs.Textbox(lines=2, default='test.com',label="Enter domain")
|
| 21 |
[gr.outputs.Textbox(type="auto", label="Text Generated")],
|
| 22 |
title=title,
|
| 23 |
description=description,
|
|
|
|
| 1 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
+
import random
|
| 5 |
|
| 6 |
tokenizer = AutoTokenizer.from_pretrained('TrLOX/gpt2-tdk')
|
| 7 |
model = AutoModelForCausalLM.from_pretrained('TrLOX/gpt2-tdk')
|
| 8 |
|
| 9 |
+
def text_generation(keywords, domain):
|
| 10 |
input_ids = tokenizer('keyword ' + keywords + ' domain ' + domain + ' title', return_tensors="pt").input_ids
|
| 11 |
+
torch.manual_seed(random.seed(18446744073709551615))
|
| 12 |
outputs = model.generate(input_ids, do_sample=True, min_length=50, max_length=250)
|
| 13 |
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
| 14 |
+
return generated_text[0]
|
| 15 |
|
| 16 |
title = "TDK GPT2"
|
| 17 |
description = "Title and description generation by keywords"
|
| 18 |
|
| 19 |
gr.Interface(
|
| 20 |
text_generation,
|
| 21 |
+
[gr.inputs.Textbox(default='test 1,test 2',lines=2, label="Enter keywords"), gr.inputs.Textbox(lines=2, default='test.com',label="Enter domain")],
|
| 22 |
[gr.outputs.Textbox(type="auto", label="Text Generated")],
|
| 23 |
title=title,
|
| 24 |
description=description,
|