# CIAgent / app.py
# Author: provin
# Last commit: c4e848c "Update app.py" (verified)
import os
import re
import requests
import gradio as gr
from bs4 import BeautifulSoup
import torch
# Hugging Face Transformers
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
Trainer,
TrainingArguments,
pipeline
)
from datasets import Dataset
# -----------------------------
# 1) SCRAPING (OPTIONAL)
# -----------------------------
# Site root of the CIA website.
BASE_URL = "https://www.cia.gov"
# "Studies in Intelligence" operations subject index page; get_article_links()
# downloads this page and collects article links from it.
ARCHIVE_URL = "https://www.cia.gov/resources/csi/studies-in-intelligence/archives/operations-subject-index/"
def get_article_links(timeout=30):
    """
    Fetch the archive page and extract article links pointing to the CIA Studies in Intelligence.

    Args:
        timeout: Seconds to wait for the archive page. Without a timeout,
            ``requests`` can block forever on an unresponsive server.

    Returns:
        A list of absolute URLs whose href contains
        "resources/csi/studies-in-intelligence". The list is de-duplicated and
        keeps the order in which the links first appear on the page.

    Raises:
        requests.RequestException: if the page cannot be fetched or the server
            answers with an error status.
    """
    # Local import keeps this block self-contained; only needed here.
    from urllib.parse import urljoin

    response = requests.get(ARCHIVE_URL, timeout=timeout)
    response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')

    links = []
    for a_tag in soup.find_all('a', href=True):
        href = a_tag['href']
        if "resources/csi/studies-in-intelligence" not in href.lower():
            continue
        # urljoin resolves every relative form ("/x", "x", "../x") against the
        # archive page URL and leaves already-absolute URLs unchanged.
        links.append(urljoin(ARCHIVE_URL, href))

    # dict.fromkeys removes duplicates while preserving first-seen order, so
    # repeated runs scrape articles in the same order (set() order is arbitrary).
    return list(dict.fromkeys(links))
def scrape_article_text(url, timeout=30):
    """
    Fetch the article text from the URL if it's HTML.
    (Skipping PDFs for demo.)

    Args:
        url: Absolute URL of the article.
        timeout: Seconds to wait for the server. Without a timeout,
            ``requests`` can block forever on an unresponsive host.

    Returns:
        The text of every ``<p>`` element, stripped and joined with newlines.
        This is an empty string when the page has no paragraphs. Returns
        ``None`` when the URL or the response's Content-Type indicates a PDF.

    Raises:
        requests.RequestException: if the request fails or the server answers
            with an error status.
    """
    # Skip obvious PDF links before downloading them at all.
    if url.lower().split('?', 1)[0].endswith('.pdf'):
        return None

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    content_type = response.headers.get('Content-Type', '').lower()
    if 'application/pdf' in content_type:
        # Skip PDFs in this simple demo.
        return None

    soup = BeautifulSoup(response.text, 'html.parser')
    return "\n".join(p.get_text(strip=True) for p in soup.find_all('p'))
def scrape_all_articles(article_links):
    """
    Iterate through all links and gather text into a dict {url: text}.

    Args:
        article_links: Iterable of article URLs, e.g. from get_article_links().

    Returns:
        A dict mapping each URL to its scraped text.

        - Links whose scrape returns no text (PDFs, pages without paragraphs)
          are omitted.
        - Links whose request fails are also omitted; a warning is logged for
          each, so one dead link does not discard everything scraped so far.
    """
    import logging

    logger = logging.getLogger(__name__)
    corpus_data = {}
    for link in article_links:
        try:
            text = scrape_article_text(link)
        except requests.RequestException as err:
            logger.warning("Skipping %s: %s", link, err)
            continue
        if text:
            corpus_data[link] = text
    return corpus_data
# -----------------------------
# 2) DATA CLEANING
# -----------------------------
import re
def clean_text(text):
    """
    Normalize whitespace in a scraped article.

    Every run of whitespace characters (spaces, tabs, newlines, ...) becomes a
    single space, and leading/trailing whitespace is removed.
    """
    whitespace_run = re.compile(r'\s+')
    collapsed = whitespace_run.sub(' ', text)
    return collapsed.strip()
def prepare_dataset(corpus_data):
    """
    Clean every scraped article and return the texts as a list.

    Args:
        corpus_data: Mapping of URL -> raw article text, as produced by
            scrape_all_articles(). Only the values are used.

    Returns:
        The cleaned texts, in the mapping's iteration order. Texts that are
        empty after cleaning (e.g. whitespace-only pages) are dropped, because
        they would become all-padding training examples.
    """
    cleaned = (clean_text(text) for text in corpus_data.values())
    return [text for text in cleaned if text]
# -----------------------------
# 3) FINE-TUNING (OPTIONAL)
# -----------------------------
def fine_tune_model(cleaned_texts, model_name="gpt2", output_dir="cia_agent_model"):
    """
    Fine-tunes GPT-2 (or another causal LM given by ``model_name``) on your CIA corpus.
    Warning: resource-intensive! The free Hugging Face Spaces might time out.

    Args:
        cleaned_texts: Iterable of training texts, e.g. from prepare_dataset().
            Each text is truncated/padded to 128 tokens.
        model_name: Hub id or local path of the base model and tokenizer.
        output_dir: Directory where checkpoints, the final model and the
            tokenizer are written.

    Returns:
        ``(model, tokenizer)`` after one training epoch.

    Raises:
        ValueError: if ``cleaned_texts`` is empty.
    """
    texts = list(cleaned_texts)
    if not texts:
        raise ValueError("fine_tune_model() needs at least one training text")

    # Local import: only needed for training; keeps module imports unchanged.
    from transformers import DataCollatorForLanguageModeling

    ds = Dataset.from_dict({"text": texts})
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 doesn't have a pad token

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=128
        )

    # Drop the raw text column; only token ids / attention masks are batched.
    tokenized_ds = ds.map(tokenize_function, batched=True, remove_columns=["text"])
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # A causal LM only returns a loss when it receives `labels`. With mlm=False
    # this collator builds labels from input_ids and masks padding with -100.
    # Without it, Trainer stops because the model returns no loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=1,  # demonstration only
        per_device_train_batch_size=1,
        save_steps=100,
        save_total_limit=1,
        logging_steps=10,
        # No evaluation_strategy / eval_strategy argument: the default is
        # already "no", and `evaluation_strategy` was renamed and later
        # removed in newer transformers releases.
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=data_collator,
    )
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return model, tokenizer
# -----------------------------
# 4) CIAgent INFERENCE
# -----------------------------
class CIAgent:
    """
    Text-generation wrapper around a Hugging Face ``text-generation`` pipeline.

    Attributes:
        generator: The ``transformers`` pipeline used for generation.
        max_length: Default generation length (128); kept for callers that read it.
    """

    def __init__(self, model_path="cia_agent_model", fallback_model="gpt2"):
        """
        Initialize a pipeline from a local fine-tuned model folder or fallback to GPT-2.

        Args:
            model_path: Local directory with a saved model and tokenizer, e.g.
                the ``output_dir`` written by fine_tune_model().
            fallback_model: Model id or path loaded instead when ``model_path``
                does not exist on the local filesystem.

        NOTE: the existence check uses the local filesystem. A Hub model id
        passed as ``model_path`` is not a local path, so it is also replaced
        by ``fallback_model``.
        """
        if not os.path.exists(model_path):
            model_path = fallback_model
        self.generator = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=model_path
        )
        self.max_length = 128

    def query(self, prompt, max_length=128, temperature=0.7, top_p=0.9, *, max_new_tokens=None):
        """
        Generate text from the model.

        Args:
            prompt: Text to continue.
            max_length: Upper bound on the total token count (prompt plus
                generated tokens), forwarded as ``max_length``. Ignored when
                ``max_new_tokens`` is given.
            temperature: Sampling temperature. A value <= 0 or None switches to
                greedy decoding, since sampling needs a strictly positive temperature.
            top_p: Nucleus-sampling probability mass; only used when sampling.
            max_new_tokens: Optional cap on the number of generated tokens,
                not counting the prompt.

        Returns:
            The ``generated_text`` of the single returned sequence. With the
            pipeline's defaults this is the prompt followed by the continuation.
        """
        tokenizer = self.generator.tokenizer
        pad_id = tokenizer.pad_token_id
        gen_kwargs = {
            "num_return_sequences": 1,
            # GPT-2 has no pad token; pass one explicitly (falling back to EOS)
            # so generate() does not have to pick one on every call.
            "pad_token_id": pad_id if pad_id is not None else tokenizer.eos_token_id,
        }
        if max_new_tokens is not None:
            gen_kwargs["max_new_tokens"] = max_new_tokens
        else:
            gen_kwargs["max_length"] = max_length

        if temperature is not None and temperature > 0:
            # temperature/top_p only take effect when sampling is enabled.
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
        else:
            gen_kwargs["do_sample"] = False

        outputs = self.generator(prompt, **gen_kwargs)
        return outputs[0]["generated_text"]
# -----------------------------
# 5) GRADIO CHAT INTERFACE
# -----------------------------
# Create (or load) your CIAgent. In a real workflow, you might have already
# fine-tuned locally and just upload the "cia_agent_model" folder to your Space.
# If that folder does not exist, CIAgent.__init__ falls back to the stock
# "gpt2" model, so no separate setup is needed for the untrained case.
# This runs at import time, so the model is loaded once; `respond` reads this
# module-level `agent`.
agent = CIAgent(model_path="cia_agent_model")
def _history_turns(history):
    """
    Normalize Gradio chat history into a list of (user_text, assistant_text) pairs.

    Accepts both history shapes Gradio's ChatInterface can pass:
    - "tuples": a list of [user_text, assistant_text] pairs (either may be None);
    - "messages": a list of {"role": "user" | "assistant", "content": ...} dicts.
    Message dicts whose content is not a string (e.g. uploaded files) are skipped.
    """
    turns = []
    for item in history or []:
        if not isinstance(item, dict):
            user_text, assistant_text = item
            turns.append((user_text, assistant_text))
            continue
        content = item.get("content")
        if not isinstance(content, str):
            continue
        role = item.get("role")
        if role == "user":
            turns.append((content, None))
        elif role == "assistant":
            # Attach the reply to the preceding user message if it has none yet.
            if turns and turns[-1][1] is None:
                turns[-1] = (turns[-1][0], content)
            else:
                turns.append((None, content))
    return turns


def _build_prompt(system_message, turns, message):
    """Render the conversation as a plain-text transcript ending in an open "Assistant: " turn."""
    parts = [f"System: {system_message}\n\n"]
    for user_text, assistant_text in turns:
        if user_text:
            parts.append(f"User: {user_text}\n")
        if assistant_text:
            parts.append(f"Assistant: {assistant_text}\n")
    # Add the new user message
    parts.append(f"User: {message}\nAssistant: ")
    return "".join(parts)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p
):
    """
    Generate one assistant reply. Gradio's ChatInterface calls this for each user message.

    Args:
        message: The current user message.
        history: Previous turns, either as (user_text, assistant_text) pairs or
            as Gradio "messages" dicts (see _history_turns).
        system_message: The "system" instruction placed at the top of the prompt.
        max_tokens: Maximum number of NEW tokens to generate for the reply.
        temperature, top_p: Sampling parameters forwarded to the agent.

    Yields:
        The reply text, once. The agent does not stream tokens, so the whole
        reply is yielded as a single chunk.

    The conversation is rendered as a naive "System:/User:/Assistant:"
    transcript. The oldest turns are dropped while the transcript plus the
    reply budget would not fit in the model's context window.
    """
    new_tokens = max(1, int(max_tokens))
    tokenizer = agent.generator.tokenizer
    # Context window of the loaded model (1024 tokens for GPT-2). Tokenizers
    # without a known limit may report a very large sentinel value here. In
    # that case no turns are trimmed and min() below changes nothing.
    context_size = tokenizer.model_max_length

    turns = _history_turns(history)
    while True:
        prompt = _build_prompt(system_message, turns, message)
        prompt_tokens = len(tokenizer(prompt)["input_ids"])
        if not turns or prompt_tokens + new_tokens <= context_size:
            break
        turns = turns[1:]  # drop the oldest exchange to make room

    # The agent's `max_length` counts prompt tokens too, so add the prompt
    # length to make `max_tokens` a budget for the reply alone.
    # NOTE: a single message longer than the context window is not truncated.
    response_text = agent.query(
        prompt,
        max_length=min(prompt_tokens + new_tokens, context_size),
        temperature=temperature,
        top_p=top_p
    )

    # The text-generation pipeline returns the prompt followed by the
    # continuation; only the continuation belongs in the chat window.
    if response_text.startswith(prompt):
        response_text = response_text[len(prompt):]
    # The model tends to keep writing the transcript; stop at the next user turn.
    response_text = response_text.split("\nUser:", 1)[0].strip()
    yield response_text
# Create the ChatInterface
demo = gr.ChatInterface(
fn=respond,
additional_inputs=[
gr.Textbox(
value="You are a friendly Chatbot that knows about CIA Studies in Intelligence.",
label="System message"
),
gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.9,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
)
if __name__ == "__main__":
    # WARNING: Running scraping & fine-tuning on a free HF Space
    # might exceed time/memory limits. To (re)train before launching, uncomment:
    #
    # article_links = get_article_links()
    # corpus_data = scrape_all_articles(article_links)
    # cleaned_texts = prepare_dataset(corpus_data)
    # model, tokenizer = fine_tune_model(cleaned_texts)
    #
    # and then reload the freshly saved model. This code runs at module scope,
    # so the assignment rebinds the global `agent` that `respond` uses:
    # agent = CIAgent("cia_agent_model")
    #
    # Otherwise, just launch the Gradio chat with the existing fine-tuned model,
    # or with the fallback GPT-2 model when "cia_agent_model" does not exist.
    demo.launch()