File size: 5,154 Bytes
e8746a9
 
9cca2b0
69de931
 
 
da3ef81
 
69de931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8746a9
da3ef81
 
 
 
 
 
 
 
 
9cca2b0
 
 
 
e8746a9
69de931
 
e8746a9
69de931
 
 
 
 
 
 
5c93b6c
 
 
da3ef81
 
5c93b6c
 
 
 
 
da3ef81
 
 
5c93b6c
69de931
e8746a9
69de931
9cca2b0
 
69de931
 
 
6200397
69de931
 
 
 
 
 
 
 
 
da3ef81
69de931
6200397
 
 
 
 
 
 
69de931
 
 
 
 
 
 
e8746a9
69de931
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from fastai.vision.all import *
import gradio as gr
import fal_client
from PIL import Image
import io
import base64
import random
import requests

# Class label -> English Wikipedia article for that plant. The labels are
# also what gets substituted into the FLUX prompt templates.
# NOTE(review): "delphiniumr" looks like a typo for "delphinium"; it is kept
# verbatim because it may need to match a label in the exported learner's vocab.
search_terms_wikipedia = {
    term: f"https://en.wikipedia.org/wiki/{article}"
    for term, article in (
        ("blazing star", "Mentzelia"),
        ("bristlecone pine", "Pinus_longaeva"),
        ("california bluebell", "Phacelia_minor"),
        ("california buckeye", "Aesculus_californica"),
        ("california buckwheat", "Eriogonum_fasciculatum"),
        ("california fuchsia", "Epilobium_canum"),
        ("california checkerbloom", "Sidalcea_malviflora"),
        ("california lilac", "Ceanothus"),
        ("california poppy", "Eschscholzia_californica"),
        ("california sagebrush", "Artemisia_californica"),
        ("california wild grape", "Vitis_californica"),
        ("california wild rose", "Rosa_californica"),
        ("coyote mint", "Monardella"),
        ("elegant clarkia", "Clarkia_unguiculata"),
        ("baby blue eyes", "Nemophila_menziesii"),
        ("hummingbird sage", "Salvia_spathacea"),
        ("delphiniumr", "Delphinium"),
        ("matilija poppy", "Romneya_coulteri"),
        ("blue-eyed grass", "Sisyrinchium_bellum"),
        ("penstemon spectabilis", "Penstemon_spectabilis"),
        ("seaside daisy", "Erigeron_glaucus"),
        ("sticky monkeyflower", "Diplacus_aurantiacus"),
        ("tidy tips", "Layia_platyglossa"),
        ("wild cucumber", "Marah_(plant)"),
        ("douglas iris", "Iris_douglasiana"),
        ("goldfields coreopsis", "Coreopsis"),
    )
}

# Prompt templates for the FLUX image generator. For each request one
# template is picked at random and its "{flower}" placeholder is filled with
# the predicted class label.
prompt_templates = [
    (
        "A cosmic {flower} blooming in space, with petals made of swirling "
        "galaxies and nebulae, glowing softly against a backdrop of distant stars."
    ),
    (
        "An enchanted garden filled with a bioluminescent {flower}, each petal "
        "radiating vibrant, otherworldly colors, illuminating the dark, mystical "
        "forest around them."
    ),
    (
        "A mechanical {flower} with petals made of polished metal and intricate "
        "gears, unfolding in a steampunk-inspired futuristic landscape."
    ),
    (
        "A surreal pot of a {flower} where each bloom is a miniature landscape, "
        "showing tiny mountains, rivers, and clouds nestled within the petals."
    ),
    (
        "An abstract explosion of a {flower}, blending vibrant colors and fluid "
        "shapes in a chaotic, dreamlike composition, evoking movement and emotion."
    ),
]

def on_queue_update(update):
    """Print log lines reported by fal while a generation request is running.

    Passed as ``on_queue_update`` to ``fal_client.subscribe``. Status updates
    that are not ``fal_client.InProgress`` are ignored.

    Args:
        update: A fal queue status object.
    """
    if not isinstance(update, fal_client.InProgress):
        return
    # fal_client declares ``InProgress.logs`` as optional; treat a missing
    # log list as "nothing new to print" instead of raising TypeError.
    for log in update.logs or ():
        print(log["message"])

def _generate_flux_image(flower):
    """Render a random prompt for *flower* with FLUX schnell and return a PIL image.

    Raises:
        requests.HTTPError: if downloading the generated image fails.
        requests.Timeout: if the download does not complete in time.
    """
    result = fal_client.subscribe(
        "fal-ai/flux/schnell",
        arguments={
            "prompt": random.choice(prompt_templates).format(flower=flower),
            "image_size": "square",
        },
        with_logs=True,
        on_queue_update=on_queue_update,
    )
    image_url = result["images"][0]["url"]
    # Bound the download and surface HTTP errors, rather than hanging forever
    # or handing an error page's bytes to PIL (a confusing decode failure).
    # NOTE(review): 60 s is a chosen default; tune if generations are slower.
    response = requests.get(image_url, timeout=60)
    response.raise_for_status()
    return Image.open(io.BytesIO(response.content))


def process_image(img):
    """Classify *img*, then generate an AI interpretation of the predicted flower.

    Args:
        img: PIL image from the Gradio input component, or ``None`` when the
            input has been cleared.

    Returns:
        A ``(classification_results, generated_image, wiki_url)`` tuple for the
        Label, Image and Textbox outputs:
        - ``classification_results``: dict mapping each class label to its
          probability;
        - ``generated_image``: the generated PIL image;
        - ``wiki_url``: the Wikipedia URL for the predicted class, or a
          fallback message if there is none.
        Returns ``(None, None, "")`` when *img* is ``None``.
    """
    # The ``.change`` event also fires when the user clears the image; there
    # is nothing to classify then, so blank every output instead of crashing.
    if img is None:
        return None, None, ""

    _, _, probs = learn.predict(img)
    # Label the probabilities with the learner's own vocab: the probability
    # vector follows that order. fastai sorts category vocabularies by
    # default, while ``search_terms_wikipedia`` is not sorted, so zipping its
    # keys (as before) attached probabilities to the wrong classes.
    classification_results = {
        str(label): float(p) for label, p in zip(learn.dls.vocab, probs)
    }

    predicted_class = max(classification_results, key=classification_results.get)
    wiki_url = search_terms_wikipedia.get(predicted_class, "No Wikipedia entry found.")

    generated_image = _generate_flux_image(predicted_class)
    return classification_results, generated_image, wiki_url

# Load the exported fastai learner used by process_image(). It must be bound
# at module level before any Gradio event runs, because process_image()
# reads the global ``learn``.
# NOTE: load_learner unpickles the file, so only load an export.pkl you trust.
learn = load_learner('export.pkl')

# Create Gradio interface: an image input on top; classification results and
# the Wikipedia link (left column) next to the generated image underneath.
with gr.Blocks() as demo:
    with gr.Row():
        input_image = gr.Image(height=192, width=192, label="Upload Image for Classification", type="pil")
    with gr.Row():
        with gr.Column():
            label_output = gr.Label(label="Classification Results")
            wiki_output = gr.Textbox(label="Wikipedia Article Link", lines=1)
        generated_image = gr.Image(label="AI Generated Interpretation")

    # Example images that can be loaded into ``input_image``.
    # NOTE(review): brittlebush is not one of the classes in
    # search_terms_wikipedia, so the classifier cannot label it correctly.
    examples = [
        'https://www.deserthorizonnursery.com/wp-content/uploads/2024/03/Brittlebush-Encelia-Farinosa-desert-horizon-nursery.jpg',
        'https://cdn.mos.cms.futurecdn.net/VJE7gSuQ9KWbkqEsWgX5zS.jpg'
    ]
    gr.Examples(
        examples=examples,
        inputs=input_image,
        examples_per_page=5,
        fn=process_image,
        outputs=[label_output, generated_image, wiki_output]
    )

    # Run the pipeline whenever the input image changes. The outputs list must
    # stay in the same order as process_image's (results, image, wiki_url) tuple.
    input_image.change(
        fn=process_image,
        inputs=input_image,
        outputs=[label_output, generated_image, wiki_output]
    )

# Launch only when executed as a script, so importing this module (tests,
# another app, Gradio reload tooling that reads ``demo``) does not start a server.
if __name__ == "__main__":
    demo.launch(inline=False)