import os
import re
import requests
import gradio as gr
from bs4 import BeautifulSoup
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    pipeline
)
from datasets import Dataset

BASE_URL = "https://www.cia.gov"
ARCHIVE_URL = "https://www.cia.gov/resources/csi/studies-in-intelligence/archives/operations-subject-index/"


def get_article_links():
    """
    Fetch the archive page and extract links to CIA Studies in Intelligence articles.
    """
    response = requests.get(ARCHIVE_URL, timeout=30)
    response.raise_for_status()
    soup = BeautifulSoup(response.text, 'html.parser')

    links = []
    for a_tag in soup.find_all('a', href=True):
        href = a_tag['href']
        if "resources/csi/studies-in-intelligence" in href.lower():
            # Resolve site-relative links against the site root.
            if href.startswith("/"):
                href = BASE_URL + href
            links.append(href)

    # Deduplicate; link order is not significant here.
    return list(set(links))


def scrape_article_text(url):
    """
    Fetch the article text from the URL if it's HTML.
    (PDFs are skipped for this demo.)
    """
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    content_type = response.headers.get('Content-Type', '')
    if 'application/pdf' in content_type.lower():
        return None

    soup = BeautifulSoup(response.text, 'html.parser')
    paragraphs = soup.find_all('p')
    article_text = "\n".join(p.get_text(strip=True) for p in paragraphs)
    return article_text


def scrape_all_articles(article_links):
    """
    Iterate through all links and gather text into a dict {url: text}.
    """
    corpus_data = {}
    for link in article_links:
        text = scrape_article_text(link)
        if text:
            corpus_data[link] = text
    return corpus_data
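

# Illustrative usage (hedged; not run at import time). Assumes network access
# to www.cia.gov; PDF-only articles come back as None and are skipped.
#
#     links = get_article_links()
#     corpus_data = scrape_all_articles(links)
#     print(f"Scraped {len(corpus_data)} HTML articles")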


def clean_text(text):
    # Collapse runs of whitespace (newlines, tabs, repeated spaces) to one space.
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


def prepare_dataset(corpus_data):
    cleaned_texts = []
    for text in corpus_data.values():
        cleaned_texts.append(clean_text(text))
    return cleaned_texts
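

# Quick sanity check (illustrative): whitespace runs collapse to single spaces,
# e.g. clean_text("a\n\n  b\tc ") == "a b c".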


def fine_tune_model(cleaned_texts, model_name="gpt2", output_dir="cia_agent_model"):
    """
    Fine-tunes GPT-2 on your CIA corpus.
    Warning: resource-intensive! The free Hugging Face Spaces might time out.
    """
    ds = Dataset.from_dict({"text": cleaned_texts})

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # GPT-2 has no pad token; reuse EOS so padded batches tokenize cleanly.
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(
            examples["text"],
            truncation=True,
            padding="max_length",
            max_length=128
        )

    tokenized_ds = ds.map(tokenize_function, batched=True, remove_columns=["text"])

    model = AutoModelForCausalLM.from_pretrained(model_name)

    # For causal-LM training the collator copies input_ids into labels
    # (mlm=False); without it the Trainer has no loss to optimize.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        save_steps=100,
        save_total_limit=1,
        logging_steps=10,
        evaluation_strategy="no",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_ds,
        data_collator=data_collator,
    )
    trainer.train()

    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)

    return model, tokenizer
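

# Illustrative end-to-end training call (hedged): assumes `corpus_data` from
# the scraping step above; expect this to be slow on CPU-only hardware.
#
#     texts = prepare_dataset(corpus_data)
#     model, tokenizer = fine_tune_model(texts, output_dir="cia_agent_model")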


class CIAgent:
    def __init__(self, model_path="cia_agent_model"):
        """
        Initialize a pipeline from a local fine-tuned model folder, or fall back to GPT-2.
        """
        if not os.path.exists(model_path):
            model_path = "gpt2"
        self.generator = pipeline(
            "text-generation",
            model=model_path,
            tokenizer=model_path
        )

    def query(self, prompt, max_new_tokens=128, temperature=0.7, top_p=0.9):
        """
        Generate a continuation of the prompt. Returns only the newly
        generated text (return_full_text=False), not the prompt itself.
        """
        outputs = self.generator(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            num_return_sequences=1,
            return_full_text=False
        )
        return outputs[0]["generated_text"]
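

# Illustrative ad-hoc use outside the Gradio UI:
#
#     agent = CIAgent()
#     print(agent.query("Intelligence analysis is", max_new_tokens=64))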


agent = CIAgent(model_path="cia_agent_model")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p
):
    """
    Called by Gradio's ChatInterface. Receives:
    - message: the current user message
    - history: list of (user_text, assistant_text) pairs
    - system_message: the "system" instruction that guides the model
    - max_tokens, temperature, top_p: generation parameters

    Builds a prompt from the system message plus every conversation turn,
    then queries the CIAgent for one completion. Since CIAgent doesn't
    stream tokens, we yield once at the end.
    """
    prompt = f"System: {system_message}\n\n"
    for user_text, assistant_text in history:
        if user_text:
            prompt += f"User: {user_text}\n"
        if assistant_text:
            prompt += f"Assistant: {assistant_text}\n"

    prompt += f"User: {message}\nAssistant: "

    response_text = agent.query(
        prompt,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p
    )

    yield response_text
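

# Illustrative prompt shape: with history [("hi", "hello")] and message
# "Who publishes this journal?", respond() builds:
#
#   System: <system_message>
#
#   User: hi
#   Assistant: hello
#   User: Who publishes this journal?
#   Assistant: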


demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="You are a friendly Chatbot that knows about CIA Studies in Intelligence.",
            label="System message"
        ),
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.9,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
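    # Optional offline pipeline (hedged sketch): scrape, clean, and fine-tune
    # before serving. Left commented out because scraping cia.gov and training
    # GPT-2 can easily exceed the limits of free Hugging Face Spaces hardware.
    #
    #     links = get_article_links()
    #     corpus_data = scrape_all_articles(links)
    #     fine_tune_model(prepare_dataset(corpus_data))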
    demo.launch()