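# NYC ad generator: a DreamBooth-fine-tuned Stable Diffusion model produces
# poster images from retrieval-augmented prompts (FAISS + sentence embeddings),
# and an SFT-fine-tuned causal LM writes a matching slogan.
# The two triple-quoted blocks below are earlier iterations kept for reference.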
'''
# v1: plain DreamBooth generation, no retrieval
from diffusers import StableDiffusionPipeline
import torch

# Load the fine-tuned DreamBooth model
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16,
).to("cuda")  # use "cpu" (and torch.float32) if no GPU

prompt = "brand name: xyc, fried chicken advertisement poster: a fried chicken in brooklyn street"

# Note: 500 inference steps is far beyond the usual 25-50 and mostly adds runtime
image = pipe(prompt, num_inference_steps=500, guidance_scale=7.5).images[0]

# Display or save the image
image.save("output_nyc_ad.png")
image.show()
'''
'''
# v2: FAISS retrieval added to augment the image prompt
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline

texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")

def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"

prompt = rag_prompt("fried chicken advertisement poster")
img = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
img.save("rag_output.png")
'''
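# v3 (active): v2's RAG image generation, plus an SFT-tuned causal LM for slogans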
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load RAG corpus and index (texts and index vectors must stay in the same order)
texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
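# Sketch (an assumption, not part of this script): prompt.index is presumably
# built from the same prompt.txt captions with the same embedding model, e.g.
# a flat inner-product index over normalized vectors (cosine similarity):
#
#   vecs = emb.encode(texts, normalize_embeddings=True).astype("float32")
#   idx = faiss.IndexFlatIP(vecs.shape[1])
#   idx.add(vecs)
#   faiss.write_index(idx, "prompt.index")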
# Load image generation pipeline (fp16 weights assume a CUDA device)
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16
).to("cuda")
# Load your own fine-tuned SFT model
text_model_path = "./sft-model"  # Path to your SFT-finetuned model
tokenizer = AutoTokenizer.from_pretrained(text_model_path)
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_path,
    torch_dtype=torch.float16,
    device_map="auto"  # let accelerate place the weights on available devices
)
# Build retrieval-augmented prompt: prepend the k nearest stored captions
def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"
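# e.g. rag_prompt("fried chicken advertisement poster") returns roughly
# "<caption 1> <caption 2> <caption 3>. fried chicken advertisement poster",
# so the diffusion model sees in-domain context alongside the user query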
# Prompt for generation
user_prompt = "fried chicken advertisement poster"
full_prompt = rag_prompt(user_prompt)

# Generate image
image = pipe(full_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("rag_output.png")
# Construct input prompt compatible with the SFT instruction format
copy_prompt = f"""### Instruction:
Generate a catchy advertisement slogan for: {user_prompt}
### Response:"""

inputs = tokenizer(copy_prompt, return_tensors="pt").to(text_model.device)
output_ids = text_model.generate(
    **inputs,
    max_new_tokens=30,
    do_sample=True,
    top_p=0.95
)
# Decode only the newly generated tokens, not the echoed prompt
prompt_len = inputs["input_ids"].shape[1]
response = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

# Output result
print("Image saved to rag_output.png")
print("Generated slogan:")
print(response.strip())
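# Optional sketch (not in the original script): overlay the slogan on the
# poster with Pillow, which diffusers already installs; the coordinates and
# fill colour below are arbitrary placeholders.
#
#   from PIL import ImageDraw
#   draw = ImageDraw.Draw(image)
#   draw.text((20, 20), response.strip(), fill="white")
#   image.save("rag_output_with_slogan.png")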