Spaces:
Build error
Build error
import functools
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import requests
import torch
from gradio.mix import Parallel
from PIL import Image
from transformers import (
    ViTConfig,
    ViTForImageClassification,
    ViTFeatureExtractor,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    StoppingCriteriaList,
    MaxLengthCriteria,
    ImageClassificationPipeline,
    PerceiverForImageClassificationConvProcessing,
    PerceiverFeatureExtractor,
    VisionEncoderDecoderModel,
    AutoTokenizer,
)

# get from local file spaces_info.py
from spaces_info import description, examples, initial_prompt_value
# --- Constants ---------------------------------------------------------------
# Hugging Face Inference API endpoint for BigScience BLOOM. It is hard-coded:
# the API_URL environment variable is not consulted (the previous lookup of it
# was always overwritten by this value anyway).
API_URL = "https://api-inference.huggingface.co/models/bigscience/bloom"
# Inference API access token read from the environment; None when unset.
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
print(API_URL)
# Never print the token itself: it is a credential and would end up in logs.
print("HF_API_TOKEN is set" if HF_API_TOKEN else "HF_API_TOKEN is not set")
def query(payload, timeout=120):
    """POST ``payload`` to the BLOOM Inference API and return the decoded JSON.

    Args:
        payload: JSON-serialisable request body (``inputs`` / ``parameters`` /
            ``options``).
        timeout: Seconds ``requests`` waits for the server before giving up.
            Without it a stalled API call would block the worker forever.

    Returns:
        The decoded JSON response. On a network failure or a body that is not
        valid UTF-8 JSON, a dict ``{"error": "<message>"}`` instead, so callers
        can handle every failure through the same ``"error" in data`` check.
    """
    print(payload)
    try:
        response = requests.post(
            API_URL, json=payload, headers=headers, timeout=timeout
        )
    except requests.exceptions.RequestException as exc:
        return {"error": f"request to the Inference API failed: {exc}"}
    print(response)
    try:
        return json.loads(response.content.decode("utf-8"))
    except ValueError:
        # JSONDecodeError and UnicodeDecodeError are both ValueErrors; this
        # happens e.g. when a gateway returns an HTML error page.
        return {"error": f"HTTP {response.status_code}: response was not valid JSON"}
def inference(input_sentence, max_length, sample_or_greedy, seed=42):
    """Ask BLOOM to continue ``input_sentence`` and return prompt + continuation.

    Args:
        input_sentence: Prompt text sent to the model.
        max_length: Maximum number of new tokens to generate (sent to the API
            as ``max_new_tokens``).
        sample_or_greedy: ``"Sample"`` enables sampling with ``top_p=0.9``;
            any other value requests greedy decoding (``do_sample=False``).
        seed: Seed forwarded in the request parameters.

    Returns:
        Always a single string, so it can be shown directly in a text output:
        the prompt followed by BLOOM's continuation, or an ``"ERROR: ..."``
        message when the API reports an error or returns an unexpected shape.
    """
    parameters = {
        "max_new_tokens": max_length,
        "do_sample": False,
        "seed": seed,
        "early_stopping": False,
        "length_penalty": 0.0,
        "eos_token_id": None,
    }
    if sample_or_greedy == "Sample":
        # Nucleus sampling; the greedy request carries no top_p at all.
        parameters.update({"do_sample": True, "top_p": 0.9})

    payload = {
        "inputs": input_sentence,
        "parameters": parameters,
        # Ask the Inference API not to serve a cached response for this prompt.
        "options": {"use_cache": False},
    }
    data = query(payload)
    if isinstance(data, dict) and "error" in data:
        return f"ERROR: {data['error']}"
    try:
        generated_text = data[0]["generated_text"]
    except (IndexError, KeyError, TypeError):
        return f"ERROR: unexpected response from the Inference API: {data!r}"

    # The code expects the API to echo the prompt in front of the continuation
    # and keeps only the text after its first occurrence. If the prompt is
    # empty or was not echoed verbatim, splitting on it would raise, so fall
    # back to the full generated text.
    if input_sentence and input_sentence in generated_text:
        generation = generated_text.split(input_sentence, 1)[1]
        print(generation)
        return input_sentence + generation
    print(generated_text)
    return generated_text
@functools.lru_cache(maxsize=None)
def _load_captioner(repo_name):
    """Load the captioning components for ``repo_name`` once and reuse them.

    Returns a ``(feature_extractor, tokenizer, model)`` tuple. Memoised so the
    checkpoint is fetched and deserialised on the first call only, not on
    every request.
    """
    feature_extractor = ViTFeatureExtractor.from_pretrained(repo_name)
    tokenizer = AutoTokenizer.from_pretrained(repo_name)
    model = VisionEncoderDecoderModel.from_pretrained(repo_name)
    return feature_extractor, tokenizer, model


def self_caption(image):
    """Caption ``image`` and let BLOOM turn the caption into a short story.

    Args:
        image: The uploaded image (the interface is configured to pass a PIL
            image).

    Returns:
        The string produced by ``inference``: the caption followed by BLOOM's
        sampled continuation of up to 64 new tokens (seed 42).
    """
    feature_extractor, tokenizer, model = _load_captioner("ydshieh/vit-gpt2-coco-en")
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values
    # Autoregressive beam search (4 beams); captions are capped at 16 tokens.
    outputs = model.generate(
        pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True
    )
    preds = [
        pred.strip()
        for pred in tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
    ]
    print("Predictions")
    print(preds)
    # The caption(s) become the prompt for the story.
    caption = " ".join(preds)
    return inference(caption, 64, "Sample", 42)
@functools.lru_cache(maxsize=None)
def _load_image_classifier(checkpoint="deepmind/vision-perceiver-conv"):
    """Build (once) an image-classification pipeline for a Perceiver checkpoint.

    Memoised so the checkpoint is downloaded and deserialised on the first
    call only instead of on every request.
    """
    feature_extractor = PerceiverFeatureExtractor.from_pretrained(checkpoint)
    model = PerceiverForImageClassificationConvProcessing.from_pretrained(checkpoint)
    return ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)


def classify_image(image):
    """Classify ``image`` with Perceiver IO (deepmind/vision-perceiver-conv).

    Args:
        image: The uploaded image (the interface is configured to pass a PIL
            image).

    Returns:
        A ``{label: score}`` dict built from the pipeline's predictions, the
        mapping format Gradio's Label output consumes.
    """
    image_pipe = _load_image_classifier()
    results = image_pipe(image)
    print("RESULTS")
    print(results)
    # Convert the pipeline's list of {"label": ..., "score": ...} dicts into
    # the format Gradio expects.
    output = {prediction["label"]: prediction["score"] for prediction in results}
    print("OUTPUT")
    print(output)
    return output
# --- Gradio UI configuration -------------------------------------------------
# Shared input component; callbacks receive the upload as a PIL image.
image = gr.inputs.Image(type="pil")
# Classification output showing the five highest-scoring labels.
label = gr.outputs.Label(num_top_classes=5)

# These values replace the `examples` and `description` imported from
# spaces_info.
examples = [
    ["cats.jpg"],
    ["batter.jpg"],
    ["drinkers.jpg"],
]
title = "Generate a Story from an Image using BLOOM"
description = (
    "Demo for classifying images with Perceiver IO. To use it, simply upload "
    "an image and click 'submit', a story is autogenerated as well, story "
    "generated using Bigscience/BLOOM"
)
# NOTE: `article` is not currently passed to any interface.
article = "<p style='text-align: center'></p>"
# Interface 1: Perceiver IO classification of the uploaded image.
img_info1 = gr.Interface(fn=classify_image, inputs=image, outputs=label)

# Interface 2: caption the image, then have BLOOM continue it as a story.
img_info2 = gr.Interface(
    fn=self_caption,
    inputs=image,
    outputs=[gr.outputs.Textbox(label="Story")],
)

# Run both interfaces in parallel on the same input image and start the app.
Parallel(
    img_info1,
    img_info2,
    inputs=image,
    title=title,
    description=description,
    examples=examples,
    enable_queue=True,
).launch(debug=True)