gaur3009 committed
Commit 709c9f6 · verified · 1 Parent(s): b5317df

Update app.py

Files changed (1)
  1. app.py +65 -24
app.py CHANGED
@@ -11,15 +11,12 @@ import torch
 file_path = "marketing-campaigns.csv"
 df = pd.read_csv(file_path)
 
-# Flexible column handling
-if "description" in df.columns:
-    df = df.dropna(subset=["campaign_name", "description"])
-    df["text"] = df["campaign_name"].astype(str) + ": " + df["description"].astype(str)
-elif "campaign_name" in df.columns:
-    df = df.dropna(subset=["campaign_name"])
-    df["text"] = df["campaign_name"].astype(str)
-else:
-    raise ValueError("CSV must contain at least a 'campaign_name' column")
+if df.empty:
+    raise ValueError("CSV is empty. Please provide a dataset with campaign info.")
+
+# Join all columns to form knowledge text
+df = df.dropna()
+df["text"] = df.astype(str).agg(" | ".join, axis=1)
 
 # -------------------------------
 # Embeddings + FAISS
@@ -35,7 +32,7 @@ index.add(embeddings_np)
 # -------------------------------
 # Load LLM (Phi-4-mini)
 # -------------------------------
-model_name = "microsoft/phi-2"
+model_name = "microsoft/phi-4-mini"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, device_map="auto")
 
@@ -43,37 +40,81 @@ model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float
 # RAG functions
 # -------------------------------
 def retrieve_context(query, k=3):
+    """Retrieve top-k similar rows from dataset"""
+    if not query.strip():
+        return []
     query_vec = embed_model.encode([query], convert_to_tensor=True).cpu().numpy()
-    D, I = index.search(query_vec, k)
+    D, I = index.search(query_vec, min(k, len(df)))
     results = [df.iloc[i]["text"] for i in I[0]]
     return results
 
-def generate_with_rag(prompt):
-    context = retrieve_context(prompt, k=3)
-    context_str = "\n".join(context)
+def generate_with_rag(prompt, k=3, temperature=0.7):
+    if not prompt.strip():
+        return "⚠️ Please enter a campaign idea or theme."
+
+    # Step 1: Retrieve supporting facts
+    context = retrieve_context(prompt, k)
+    if not context:
+        return "⚠️ No relevant context found in dataset."
+    context_str = "\n".join(context[:k])
 
+    # Step 2: Build grounded structured prompt
     rag_prompt = f"""
-You are an AI marketing assistant.
-Here are some past campaigns for reference:\n{context_str}\n
-Based on these, generate a new creative campaign idea for: {prompt}
+You are a top-tier creative marketing AI assistant.
+Use the following supporting dataset entries as context:
+{context_str}
+
+Task: Generate a **structured marketing campaign** for:
+{prompt}
+
+Format your answer clearly with:
+- 📌 Campaign Title
+- ✨ Tagline
+- 🧑‍🤝‍🧑 Target Audience
+- 🎯 Key Selling Points
+- 🎬 Creative Idea
 """
 
+    # Step 3: LLM Generation
     inputs = tokenizer(rag_prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_length=200, temperature=0.7, top_p=0.9)
+    outputs = model.generate(
+        **inputs,
+        max_length=300,
+        temperature=float(temperature),
+        top_p=0.9,
+        do_sample=True
+    )
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+def search_dataset(query, k=5):
+    """Search dataset and return top matching rows"""
+    context = retrieve_context(query, k)
+    if not context:
+        return "⚠️ No results found."
+    return "\n\n".join(context)
+
 # -------------------------------
 # Gradio UI
 # -------------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## 🤖 RAG-powered AI Marketing Campaign Generator")
+    gr.Markdown("# 🤖 RAG-powered Creative Campaign Assistant")
+    gr.Markdown("Generate **smart, creative, and data-grounded campaigns** with retrieval-augmented AI.")
 
-    with gr.Row():
-        query = gr.Textbox(label="Enter campaign idea or keyword")
-        output = gr.Textbox(label="Generated Campaign")
-    btn = gr.Button("Generate with RAG")
+    with gr.Tab("🔎 Explore Dataset"):
+        search_query = gr.Textbox(label="Search dataset by keyword / theme")
+        search_results = gr.Textbox(label="Top Matches", lines=10)
+        search_btn = gr.Button("Search")
+        search_btn.click(search_dataset, inputs=search_query, outputs=search_results)
 
-    btn.click(generate_with_rag, inputs=query, outputs=output)
+    with gr.Tab("✍️ Generate Campaign"):
+        with gr.Row():
+            prompt = gr.Textbox(label="Enter campaign idea / theme", lines=3)
+        with gr.Row():
+            k_slider = gr.Slider(1, 10, value=3, step=1, label="Number of supporting facts (k)")
+            temp_slider = gr.Slider(0.3, 1.2, value=0.7, step=0.1, label="Creativity (temperature)")
+        campaign_output = gr.Textbox(label="Generated Campaign", lines=15)
+        gen_btn = gr.Button("Generate with RAG")
+        gen_btn.click(generate_with_rag, inputs=[prompt, k_slider, temp_slider], outputs=campaign_output)
 
 if __name__ == "__main__":
     demo.launch()
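
The hunks above rely on an unchanged "Embeddings + FAISS" section of app.py that defines embed_model, embeddings_np, and index (see the hunk header "@@ -35,7 +32,7 @@ index.add(embeddings_np)" and the embed_model.encode call in retrieve_context). For readers following the diff, here is a minimal sketch of what that block plausibly looks like, assuming sentence-transformers with a flat L2 FAISS index; the identifier names come from the diff context, while the embedding model id "all-MiniLM-L6-v2" is an assumption, not taken from this commit.

# Hypothetical reconstruction of the unchanged "Embeddings + FAISS" block.
# embed_model, embeddings_np, and index appear in the diff context above;
# the specific SentenceTransformer model is an assumption.
import faiss
from sentence_transformers import SentenceTransformer

embed_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model id

# Encode every row's "text" column and build a flat L2 index over the vectors.
embeddings = embed_model.encode(df["text"].tolist(), convert_to_tensor=True)
embeddings_np = embeddings.cpu().numpy().astype("float32")

index = faiss.IndexFlatL2(embeddings_np.shape[1])
index.add(embeddings_np)

Under this setup, vectors are added to the index in the same order as the rows of df, which is what lets retrieve_context map the returned FAISS indices I[0] back to rows via df.iloc[i]["text"].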