|
|
import os |
|
|
from dotenv import load_dotenv |
|
|
from typing import List |
|
|
from langchain_openai import OpenAIEmbeddings, ChatOpenAI |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_community.vectorstores import Chroma |
|
|
from langchain_core.prompts import ChatPromptTemplate |
|
|
from langchain_core.output_parsers import StrOutputParser |
|
|
from langchain_core.runnables import RunnableParallel |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Populate os.environ from a local .env file (if present) before any client
# objects are created further down.
load_dotenv()

# Resolve rules.txt relative to this script, not the current working directory.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
RULES_PATH = os.path.join(SCRIPT_DIR, 'rules.txt')

# Read the full rules text. An explicit encoding keeps the result independent
# of the platform's default locale encoding.
try:
    with open(RULES_PATH, 'r', encoding='utf-8') as rules_file:
        golf_rules = rules_file.read()
except FileNotFoundError:
    print(f"Error: Could not find rules.txt at {RULES_PATH}")
    golf_rules = ""

# Fail fast: a missing, empty, or whitespace-only rules file would otherwise
# produce no chunks and fail later, less clearly, when building the index.
if not golf_rules.strip():
    raise RuntimeError("Failed to load golf rules. Please ensure rules.txt is present in the repository.")
|
|
|
|
|
|
|
|
# Maximum size (in characters) of a detail chunk. This is also the threshold
# above which a major chunk gets re-split. It is kept in one place so the two
# cannot drift apart.
DETAIL_CHUNK_SIZE = 1000

# First pass: split the rules text at the regex "\n***\nRule" (a "***" divider
# line followed by a line starting with "Rule"), with chunks of up to 10000
# characters and no overlap.
major_splitter = RecursiveCharacterTextSplitter(
    separators=[r"\n\*\*\*\nRule"],
    chunk_size=10000,
    chunk_overlap=0,
    length_function=len,
    is_separator_regex=True,
)

# Second pass: finer, overlapping chunks for oversized major chunks, using the
# splitter's default separators.
detail_splitter = RecursiveCharacterTextSplitter(
    chunk_size=DETAIL_CHUNK_SIZE,
    chunk_overlap=200,
    length_function=len,
)

major_chunks = major_splitter.split_text(golf_rules)
print(f"Created {len(major_chunks)} major rule chunks")

# Re-split only the major chunks that exceed the detail size; keep the rest as-is.
# NOTE(review): `chunks` is only counted and printed here. The vector store
# below indexes `major_chunks`, not this list.
chunks = []
for chunk in major_chunks:
    if len(chunk) > DETAIL_CHUNK_SIZE:
        chunks.extend(detail_splitter.split_text(chunk))
    else:
        chunks.append(chunk)

print(f"Created {len(chunks)} total chunks")
|
|
|
|
|
|
|
|
# OpenAI-backed embedding model and chat model. Credentials presumably come
# from the environment populated by load_dotenv() — confirm.
embeddings = OpenAIEmbeddings()
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Build a Chroma vector store over the major rule chunks.
# NOTE(review): this indexes `major_chunks`, not the finer `chunks` list.
# query_golf_rules takes the rule title from the first line of the retrieved
# text, which presumably relies on whole major chunks — confirm before switching.
vectorstore = Chroma.from_texts(
    embedding=embeddings,
    texts=major_chunks,
)
|
|
|
|
|
|
|
|
# Prompt for the golf-rules assistant. The {context} and {question}
# placeholders are filled from the "context" and "question" keys supplied in
# rag_chain_with_sources.
# NOTE(review): "you only are a golf rules assistant" reads awkwardly ("you are
# only ..."). It is left as-is because editing this text changes the prompt
# sent to the model.
template = """You are a helpful golf rules assistant. Use the following pieces of context to answer the question at the end.


If you don't know the answer, just say that you don't know, don't try to make up an answer.


You can only answer questions about the rules of golf. If a question is not about golf, kindly remind them that you only are a golf rules assistant.


Think step by step and remember to use emojis and cheer the golfer on!





Context: {context}





Question: {question}





Answer:"""

# Chat prompt template built from the template string above.
prompt = ChatPromptTemplate.from_template(template)
|
|
|
|
|
|
|
|
def format_docs(docs):
    """Render retrieved documents as numbered source blocks.

    Each document becomes ``"[Source N]: <page_content>"``, numbered from 1.
    Blocks are separated by a blank line. An empty input yields ``""``.
    """
    return "\n\n".join(
        f"[Source {position}]: {document.page_content}"
        for position, document in enumerate(docs, start=1)
    )
|
|
|
|
|
def format_response(response, doctitle):
    """Append a 50-character '=' divider and the source title to *response*."""
    divider = "=" * 50
    return f"{response}\n\n{divider}\nSource used: {doctitle}"
|
|
|
|
|
# Retrieve only the single closest chunk per query. query_golf_rules derives
# the "Source used" title from that one document.
retriever = vectorstore.as_retriever(search_kwargs=dict(k=1))
|
|
|
|
|
def rag_chain_with_sources(question):
    """Answer *question* with the RAG prompt and also return the retrieved docs.

    Args:
        question: The user's question text.

    Returns:
        A ``(response, docs)`` tuple. ``response`` is the model output parsed
        to a string by ``StrOutputParser``. ``docs`` is what ``retriever``
        returned for the question.
    """
    docs = retriever.invoke(question)

    # The prompt variables are computed up front, so the chain receives
    # them directly as its input mapping.
    answer_chain = prompt | llm | StrOutputParser()
    prompt_inputs = {"context": format_docs(docs), "question": question}
    response = answer_chain.invoke(prompt_inputs)

    return response, docs
|
|
|
|
|
|
|
|
def query_golf_rules(question: str) -> str:
    """Answer a golf-rules question and append the title of the source used.

    The title is the first non-blank line of the top retrieved document,
    skipping "***" divider lines. "Unknown Rule" is used when nothing was
    retrieved or the document has no such line.

    Args:
        question: The user's question text.

    Returns:
        The model's answer followed by a divider and "Source used: <title>".
    """
    response, docs = rag_chain_with_sources(question)

    # Guard against an empty retrieval result instead of failing on docs[0].
    doctitle = "Unknown Rule"
    if docs:
        content_lines = [
            line
            for line in docs[0].page_content.split("\n")
            if line.strip() and line.strip() != "***"
        ]
        if content_lines:
            doctitle = content_lines[0]

    return format_response(response, doctitle)
|
|
|
|
|
|
|
|
def gradio_interface(question):
    """Gradio callback: pass the textbox value to ``query_golf_rules``."""
    answer = query_golf_rules(question)
    return answer
|
|
|
|
|
# Input and output widgets for the Gradio UI.
_question_input = gr.Textbox(
    lines=2,
    placeholder="What would you like to know?",
    label="Your Question",
)
_answer_output = gr.Textbox(lines=10, label="GolfGPT Answer")

# Clickable example questions shown under the input box.
_example_questions = [
    "What are the rules for taking a drop?",
    "How do I handle a lost ball?",
    "Can I repair ball marks on the green?",
    "What are the rules for playing from a bunker?",
    "How do I handle an unplayable lie?",
]

demo = gr.Interface(
    fn=gradio_interface,
    inputs=_question_input,
    outputs=_answer_output,
    title="GolfGPT Rules Assistant",
    description="Ask questions about golf rules and get accurate answers based on the official rules of golf. The model can make mistakes",
    examples=_example_questions,
    theme=gr.themes.Soft(),
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()