# Untitled
#
# avatar
# unknown
# plain_text
# a month ago
# 2.1 kB
# 7
# Indexable
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import fitz
import openai
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import RetrievalQA
import os

app = FastAPI()

# Initialize OpenAI and models.
# The API key is taken from the OPENAI_API_KEY environment variable so that no
# secret has to be committed to source. When the variable is unset this falls
# back to the empty string the module used before.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
embeddings_model = OpenAIEmbeddings(openai_api_key=openai.api_key)

def extract_text_from_pdf(pdf_path) -> str:
    """Return the text of every page of the PDF at *pdf_path*, concatenated.

    The document is opened with ``fitz.open`` as a context manager, so it is
    closed when extraction finishes. Page texts are joined in iteration order
    with no separator added between pages.
    """
    with fitz.open(pdf_path) as pdf:
        # One join instead of repeated ``+=`` avoids re-copying the growing
        # string on every page. The generator is fully consumed here, while
        # the document is still open.
        return "".join(page.get_text() for page in pdf)

def chunk_text(text, chunk_size=500):
    """Split *text* into chunks of at most *chunk_size* words.

    Words are produced by ``str.split()``, so any run of whitespace (spaces,
    tabs, newlines) acts as a separator, and each chunk is re-joined with
    single spaces.

    Args:
        text: The text to split.
        chunk_size: Maximum number of words per chunk. Must be greater than 0.

    Returns:
        A list of chunk strings in their original order. Empty or
        whitespace-only text yields an empty list.

    Raises:
        ValueError: If *chunk_size* is zero or negative.
    """
    # A negative step would make range() yield nothing and silently drop the
    # whole text; zero would fail with an unrelated-looking range() error.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be greater than 0, got {chunk_size!r}")

    words = text.split()
    return [
        " ".join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

def get_chat_model():
    """Build the ChatOpenAI client used by the QA chain.

    The client is configured for the "gpt-4" model and uses whatever key is
    set on ``openai.api_key`` at the time of the call.
    """
    chat_settings = {
        "model_name": "gpt-4",
        "openai_api_key": openai.api_key,
    }
    return ChatOpenAI(**chat_settings)

# Vector store bootstrap: build the FAISS index from the PDF on first run,
# otherwise reuse the copy previously saved under VECTOR_STORE_PATH.
VECTOR_STORE_PATH = "faiss_store"

if not os.path.exists(VECTOR_STORE_PATH):
    print("Creating new vector store...")
    # Extract the PDF text, split it into 200-word chunks and embed them.
    pdf_text = extract_text_from_pdf("combined_laws.pdf")
    chunks = chunk_text(pdf_text, chunk_size=200)
    vector_store = FAISS.from_texts(chunks, embeddings_model)
    # Persist the index so later starts can skip the embedding step.
    vector_store.save_local(VECTOR_STORE_PATH)
    print("Vector store saved successfully!")
else:
    print("Loading existing vector store...")
    # NOTE(review): allow_dangerous_deserialization=True presumably opts in to
    # unpickling the saved store - only acceptable if the contents of
    # VECTOR_STORE_PATH are trusted (written by this app); confirm.
    vector_store = FAISS.load_local(
        VECTOR_STORE_PATH,
        embeddings_model,
        allow_dangerous_deserialization=True,
    )

# Wire the store into a retrieval-backed question-answering chain.
retriever = vector_store.as_retriever()
qa_chain = RetrievalQA.from_chain_type(llm=get_chat_model(), retriever=retriever)

# Request body accepted by the POST /ask endpoint.
class Question(BaseModel):
    # The question text; passed to the QA chain as its "query" input.
    question: str

@app.post("/ask")
async def ask_question(question: Question):
    """Answer a question using the indexed PDF content."""
    # Use the chain's async entry point: calling the synchronous invoke() from
    # an ``async def`` handler would block the event loop for the full
    # retrieval + LLM round-trip and stall every other request on this worker.
    try:
        answer = await qa_chain.ainvoke({"query": question.question})
    except Exception as e:
        # Top-level request boundary: report any chain failure as HTTP 500 and
        # keep the original exception attached as the cause for tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # NOTE(review): the chain output is returned as-is under "answer", exactly
    # as before; it is presumably a dict rather than plain text - confirm what
    # clients expect before narrowing it.
    return {"answer": answer}

@app.get("/")
async def root():
    # Status endpoint: always responds with the same fixed message.
    status = {"message": "PDF QA API is running"}
    return status

if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on host 0.0.0.0, port 8000.
    # uvicorn is imported here so it is only needed when run as a script.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=8000)
# Leave a Comment