multimodalart HF Staff committed on
Commit
dce996d
·
verified ·
1 Parent(s): a31d3a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -71
app.py CHANGED
@@ -14,7 +14,7 @@ if not GOOGLE_API_KEY:
14
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
15
 
16
  client = genai.Client(
17
- api_key=os.environ.get("GEMINI_API_KEY"),
18
  )
19
 
20
  GEMINI_MODEL_NAME = 'gemini-2.5-flash-image-preview'
@@ -45,13 +45,29 @@ def _extract_image_data_from_response(response) -> Optional[bytes]:
45
  return part.inline_data.data
46
  return None
47
 
48
- def run_single_image_logic(prompt: str, image_path: Optional[str] = None) -> str:
49
- """Handles text-to-image or single image-to-image using Google Gemini."""
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
- contents = [prompt]
52
- if image_path:
53
- input_image = Image.open(image_path)
54
- contents.append(input_image)
 
 
 
 
 
55
 
56
  response = client.models.generate_content(
57
  model=GEMINI_MODEL_NAME,
@@ -73,36 +89,6 @@ def run_single_image_logic(prompt: str, image_path: Optional[str] = None) -> str
73
  raise gr.Error(f"Image generation failed: {e}")
74
 
75
 
76
- def run_multi_image_logic(prompt: str, images: List[str]) -> str:
77
- """
78
- Handles multi-image editing by sending a list of images and a prompt.
79
- """
80
- if not images:
81
- raise gr.Error("Please upload at least one image in the 'Multiple Images' tab.")
82
-
83
- try:
84
- contents = [Image.open(image_path[0]) for image_path in images]
85
- contents.append(prompt)
86
-
87
- response = client.models.generate_content(
88
- model=GEMINI_MODEL_NAME,
89
- contents=contents,
90
- )
91
-
92
- image_data = _extract_image_data_from_response(response)
93
-
94
- if not image_data:
95
- raise ValueError("No image data found in the model response.")
96
-
97
- pil_image = Image.open(BytesIO(image_data))
98
- with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmpfile:
99
- pil_image.save(tmpfile.name)
100
- return tmpfile.name
101
-
102
- except Exception as e:
103
- raise gr.Error(f"Image generation failed: {e}")
104
-
105
-
106
  # --- Gradio App UI ---
107
  css = '''
108
  #sub_title{margin-top: -35px !important}
@@ -129,17 +115,12 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
129
  with main_interface:
130
  with gr.Row():
131
  with gr.Column(scale=1):
132
- active_tab_state = gr.State(value="single")
133
- with gr.Tabs() as tabs:
134
- with gr.TabItem("Single Image", id="single") as single_tab:
135
- image_input = gr.Image(
136
- type="filepath",
137
- label="Input Image (Leave blank for text-to-image)"
138
- )
139
- with gr.TabItem("Multiple Images", id="multiple") as multi_tab:
140
- gallery_input = gr.Gallery(
141
- label="Input Images (drop all images here)", file_types=["image"]
142
- )
143
 
144
  prompt_input = gr.Textbox(
145
  label="Prompt",
@@ -154,37 +135,20 @@ with gr.Blocks(theme=gr.themes.Citrus(), css=css) as demo:
154
 
155
  login_button = gr.LoginButton()
156
 
157
- # --- Event Handlers ---
158
- def unified_generator(
159
- prompt: str,
160
- single_image: Optional[str],
161
- multi_images: Optional[List[str]],
162
- active_tab: str,
163
- oauth_token: Optional[gr.OAuthToken] = None,
164
- ) -> str:
165
- if not verify_pro_status(oauth_token):
166
- raise gr.Error("Access Denied. This service is for PRO users only.")
167
- if active_tab == "multiple" and multi_images:
168
- return run_multi_image_logic(prompt, multi_images)
169
- else:
170
- return run_single_image_logic(prompt, single_image)
171
-
172
- single_tab.select(lambda: "single", None, active_tab_state)
173
- multi_tab.select(lambda: "multiple", None, active_tab_state)
174
-
175
  generate_button.click(
176
- unified_generator,
177
- inputs=[prompt_input, image_input, gallery_input, active_tab_state],
178
  outputs=[output_image],
179
  )
180
 
181
  use_image_button.click(
182
- lambda img: img,
183
  inputs=[output_image],
184
- outputs=[image_input]
185
  )
186
 
187
- # --- Access Control Logic ---
188
  def control_access(
189
  profile: Optional[gr.OAuthProfile] = None,
190
  oauth_token: Optional[gr.OAuthToken] = None
 
14
  raise ValueError("GOOGLE_API_KEY environment variable not set.")
15
 
16
  client = genai.Client(
17
+ api_key=os.environ.get("GOOGLE_API_KEY"),
18
  )
19
 
20
  GEMINI_MODEL_NAME = 'gemini-2.5-flash-image-preview'
 
45
  return part.inline_data.data
46
  return None
47
 
48
+ def unified_image_generator(
49
+ prompt: str,
50
+ images: Optional[List[str]] = None,
51
+ oauth_token: Optional[gr.OAuthToken] = None
52
+ ) -> str:
53
+ """
54
+ Handles all image generation tasks based on the number of input images.
55
+ - 0 images: Text-to-image
56
+ - 1+ images: Image-to-image (single or multi-modal)
57
+ """
58
+ if not verify_pro_status(oauth_token):
59
+ raise gr.Error("Access Denied. This service is for PRO users only.")
60
+
61
  try:
62
+ # Dynamically build the 'contents' list for the API
63
+ contents = []
64
+ if images:
65
+ # If there are images, open them and add to contents
66
+ for image_path in images:
67
+ contents.append(Image.open(image_path))
68
+
69
+ # Always add the prompt to the contents
70
+ contents.append(prompt)
71
 
72
  response = client.models.generate_content(
73
  model=GEMINI_MODEL_NAME,
 
89
  raise gr.Error(f"Image generation failed: {e}")
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  # --- Gradio App UI ---
93
  css = '''
94
  #sub_title{margin-top: -35px !important}
 
115
  with main_interface:
116
  with gr.Row():
117
  with gr.Column(scale=1):
118
+ image_input_gallery = gr.Gallery(
119
+ label="Input Image(s)",
120
+ info="Upload one or more images here. Leave empty for text-to-image",
121
+ file_types=["image"],
122
+ height="auto"
123
+ )
 
 
 
 
 
124
 
125
  prompt_input = gr.Textbox(
126
  label="Prompt",
 
135
 
136
  login_button = gr.LoginButton()
137
 
138
+ # --- Event Handlers (SIMPLIFIED) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  generate_button.click(
140
+ unified_image_generator,
141
+ inputs=[prompt_input, image_input_gallery], # Inputs are now just the prompt and the single gallery
142
  outputs=[output_image],
143
  )
144
 
145
  use_image_button.click(
146
+ lambda img_path: [img_path] if img_path else None,
147
  inputs=[output_image],
148
+ outputs=[image_input_gallery]
149
  )
150
 
151
+ # --- Access Control Logic (UNCHANGED) ---
152
  def control_access(
153
  profile: Optional[gr.OAuthProfile] = None,
154
  oauth_token: Optional[gr.OAuthToken] = None