File size: 6,662 Bytes
26557da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from transformers import AutoTokenizer
from models.model_manager import ModelManager
import torch
from .omost import OmostPromter


class BeautifulPrompt(torch.nn.Module):
    """Prompt refiner backed by a BeautifulPrompt causal language model.

    ``template`` is filled with the user's prompt and fed to the model; the
    sampled continuation is appended to the original prompt.
    """

    def __init__(self, tokenizer_path=None, model=None, template=""):
        """
        Args:
            tokenizer_path: Path/id passed to ``AutoTokenizer.from_pretrained``.
            model: Generation model exposing ``.device`` and ``.generate(...)``.
            template: Format string containing a ``{raw_prompt}`` placeholder.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.template = template

    @staticmethod
    def from_model_manager(model_manager: ModelManager):
        """Build a ``BeautifulPrompt`` from the ``"beautiful_prompt"`` model.

        The tokenizer is loaded from the same path as the model. Checkpoints
        whose path ends in ``v2`` get the v2 instruction template; all others
        get the original "Instruction/Input/Output" template.
        """
        model, model_path = model_manager.fetch_model(
            "beautiful_prompt", require_model_path=True
        )
        template = "Instruction: Give a simple description of the image to generate a drawing prompt.\nInput: {raw_prompt}\nOutput:"
        # Ignore trailing path separators so ".../xxx-v2/" is detected like ".../xxx-v2".
        if model_path.rstrip("/\\").endswith("v2"):
            template = """Converts a simple image description into a prompt. \
Prompts are formatted as multiple related tags separated by commas, plus you can use () to increase the weight, [] to decrease the weight, \
or use a number to specify the weight. You should add appropriate words to make the images described in the prompt more aesthetically pleasing, \
but make sure there is a correlation between the input and output.\n\
### Input: {raw_prompt}\n### Output:"""
        beautiful_prompt = BeautifulPrompt(
            tokenizer_path=model_path, model=model, template=template
        )
        return beautiful_prompt

    def __call__(self, raw_prompt, positive=True, **kwargs):
        """Refine a positive prompt; negative prompts pass through unchanged.

        Args:
            raw_prompt: User prompt substituted into ``self.template``.
            positive: When False, ``raw_prompt`` is returned as-is.
            **kwargs: Ignored by this refiner.

        Returns:
            ``raw_prompt + ", " + <generated text>``, or ``raw_prompt``
            unchanged when the model produced no text.
        """
        if not positive:
            return raw_prompt

        model_input = self.template.format(raw_prompt=raw_prompt)
        input_ids = self.tokenizer.encode(model_input, return_tensors="pt").to(
            self.model.device
        )
        outputs = self.model.generate(
            input_ids,
            max_new_tokens=384,
            do_sample=True,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1,
        )
        # The returned sequences start with the prompt tokens; keep only the new ones.
        continuation = self.tokenizer.batch_decode(
            outputs[:, input_ids.size(1) :], skip_special_tokens=True
        )[0].strip()
        # Avoid a dangling ", " if the model emitted nothing (e.g. EOS immediately).
        prompt = raw_prompt + ", " + continuation if continuation else raw_prompt
        print(f"Your prompt is refined by BeautifulPrompt: {prompt}")
        return prompt


class QwenPrompt(torch.nn.Module):
    """Translate Chinese prompts into enriched English prompts with a Qwen chat model.

    ``system_prompt`` instructs the model to translate the user's description
    and add stylistic detail; the assistant reply becomes the new prompt.
    """

    def __init__(self, tokenizer_path=None, model=None, system_prompt=""):
        """
        Args:
            tokenizer_path: Path/id passed to ``AutoTokenizer.from_pretrained``.
            model: Chat model exposing ``.device`` and ``.generate(...)``.
            system_prompt: System message prepended to every request.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.model = model
        self.system_prompt = system_prompt

    @staticmethod
    def from_model_manager(model_nameger: ModelManager):
        """Build a ``QwenPrompt`` from the ``"qwen_prompt"`` model in the manager.

        The misspelled parameter name ``model_nameger`` is kept so existing
        keyword callers are not broken.
        """
        model, model_path = model_nameger.fetch_model(
            "qwen_prompt", require_model_path=True
        )
        system_prompt = """You are an English image describer. Here are some example image styles:\n\n1. Extreme close-up: Clear focus on a single object with a blurred background, highlighted under natural sunlight.\n2. Vintage: A photograph of a historical scene, using techniques such as Daguerreotype or cyanotype.\n3. Anime: A stylized cartoon image, emphasizing hyper-realistic portraits and luminous brushwork.\n4. Candid: A natural, unposed shot capturing spontaneous moments, often with cinematic qualities.\n5. Landscape: A photorealistic image of natural scenery, such as a sunrise over the sea.\n6. Design: Colorful and detailed illustrations, often in the style of 2D game art or botanical illustrations.\n7. Urban: An ultrarealistic scene in a modern setting, possibly a cityscape viewed from indoors.\n\nYour task is to translate a given Chinese image description into a concise and precise English description. Ensure that the imagery is vivid and descriptive, and include stylistic elements to enrich the description.\nPlease note the following points:\n\n1. Capture the essence and mood of the Chinese description without including direct phrases or words from the examples provided.\n2. You should add appropriate words to make the images described in the prompt more aesthetically pleasing. If the Chinese description does not specify a style, you need to add some stylistic descriptions based on the essence of the Chinese text.\n3. The generated English description should not exceed 200 words.\n\n"""
        qwen_prompt = QwenPrompt(
            tokenizer_path=model_path, model=model, system_prompt=system_prompt
        )
        return qwen_prompt

    def __call__(self, raw_prompt, positive=True, **kwargs):
        """Rewrite a positive prompt with the chat model; negative prompts pass through.

        Args:
            raw_prompt: User prompt sent as the ``user`` chat message.
            positive: When False, ``raw_prompt`` is returned as-is.
            **kwargs: Ignored by this refiner.

        Returns:
            The decoded assistant reply with surrounding whitespace stripped.
        """
        if not positive:
            return raw_prompt

        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": raw_prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(
            self.model.device
        )

        # Pass the attention mask explicitly: otherwise transformers must infer
        # it and warns that it cannot when the pad token coincides with EOS.
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            attention_mask=model_inputs.get("attention_mask"),
            max_new_tokens=512,
        )
        # Each output sequence echoes its prompt tokens; drop them.
        generated_ids = [
            output_ids[len(input_ids) :]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        prompt = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        print(f"Your prompt is refined by Qwen: {prompt}")
        return prompt


class Translator(torch.nn.Module):
    """Prompt translator: runs the prompt through a translation model's ``generate``.

    The whole generated sequence is decoded (no prompt-token slicing), which
    presumes a model whose output does not echo its input — e.g. an
    encoder-decoder translation model (NOTE(review): confirm).
    """

    def __init__(self, tokenizer_path=None, model=None):
        """
        Args:
            tokenizer_path: Path/id passed to ``AutoTokenizer.from_pretrained``.
            model: Model exposing ``.device`` and ``.generate(...)``.
        """
        super().__init__()
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.tokenizer = tokenizer
        self.model = model

    @staticmethod
    def from_model_manager(model_manager: ModelManager):
        """Create a ``Translator`` from the ``"translator"`` model in the manager."""
        fetched_model, fetched_path = model_manager.fetch_model(
            "translator", require_model_path=True
        )
        return Translator(tokenizer_path=fetched_path, model=fetched_model)

    def __call__(self, prompt, **kwargs):
        """Translate *prompt* and return the decoded text.

        Extra keyword arguments (e.g. ``positive``) are accepted and ignored,
        so both positive and negative prompts are translated.
        """
        encoded = self.tokenizer.encode(prompt, return_tensors="pt")
        encoded = encoded.to(self.model.device)
        translated_ids = self.model.generate(encoded)
        decoded = self.tokenizer.batch_decode(translated_ids, skip_special_tokens=True)
        translated = decoded[0]
        print(f"Your prompt is translated: {translated}")
        return translated