Spaces:
Running
on
L4
Running
on
L4
Update app.py
Browse files
app.py
CHANGED
|
@@ -1028,18 +1028,36 @@ networks = {
|
|
| 1028 |
"conditional": True
|
| 1029 |
}
|
| 1030 |
}
|
|
|
|
| 1031 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 1032 |
|
| 1033 |
-
#
|
| 1034 |
-
|
| 1035 |
-
|
| 1036 |
-
|
| 1037 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1038 |
|
| 1039 |
@spaces.GPU
|
| 1040 |
def generate_random_images(network_choice, class_label=None, num_images=2):
|
| 1041 |
# Get the selected generator
|
| 1042 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1043 |
|
| 1044 |
# Generate random seeds
|
| 1045 |
seeds = np.random.randint(0, 100000, num_images)
|
|
|
|
| 1028 |
"conditional": True
|
| 1029 |
}
|
| 1030 |
}
|
| 1031 |
+
NETWORK_CHOICES = ["FFHQ (Faces) 256px", "ImageNet 64px"]
|
| 1032 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 1033 |
|
| 1034 |
+
# Create separate global variables for each network
|
| 1035 |
+
G_ffhq = None
|
| 1036 |
+
G_imagenet = None
|
| 1037 |
+
|
| 1038 |
+
# Load the FFHQ network
|
| 1039 |
+
print(f'Loading network "{NETWORK_CHOICES[0]}" from "{path_ffhq}"...')
|
| 1040 |
+
with dnnlib.util.open_url(path_ffhq) as f:
|
| 1041 |
+
G_ffhq = legacy.load_network_pkl(f)['G_ema'].to(device)
|
| 1042 |
+
|
| 1043 |
+
# Load the ImageNet network
|
| 1044 |
+
print(f'Loading network "{NETWORK_CHOICES[1]}" from "{path_imagenet}"...')
|
| 1045 |
+
with dnnlib.util.open_url(path_imagenet) as f:
|
| 1046 |
+
G_imagenet = legacy.load_network_pkl(f)['G_ema'].to(device)
|
| 1047 |
+
|
| 1048 |
|
| 1049 |
@spaces.GPU
|
| 1050 |
def generate_random_images(network_choice, class_label=None, num_images=2):
|
| 1051 |
# Get the selected generator
|
| 1052 |
+
if network_choice == NETWORK_CHOICES[0]: # FFHQ
|
| 1053 |
+
G = G_ffhq
|
| 1054 |
+
is_conditional = False
|
| 1055 |
+
elif network_choice == NETWORK_CHOICES[1]: # ImageNet
|
| 1056 |
+
G = G_imagenet
|
| 1057 |
+
is_conditional = True
|
| 1058 |
+
else:
|
| 1059 |
+
# This case should not be reached with the Gradio Radio component
|
| 1060 |
+
raise ValueError("Invalid network choice selected.")
|
| 1061 |
|
| 1062 |
# Generate random seeds
|
| 1063 |
seeds = np.random.randint(0, 100000, num_images)
|