# File: obsidian_rag/rag_chain.py

from pathlib import Path
from typing import List, Optional, Tuple

from indexer import EMBEDDING_MODEL
from llm_client import LLMClient
from searcher import search_vault, SearchResult


class RAGChain:
    """
    Retrieval-Augmented Generation (RAG) chain for answering queries
    using semantic search over an Obsidian vault and an LLM.

    Attributes:
        chroma_db_path (Path): Path to the ChromaDB directory.
        collection_name (str): Chroma collection name.
        embedding_model (str): Embedding model name.
        top_k (int): Number of chunks to send to the LLM.
        min_score (float): Minimum similarity score for chunks.
        system_prompt (str): System prompt to instruct the LLM.
        llm_client (LLMClient): Internal LLM client instance.
    """

    DEFAULT_SYSTEM_PROMPT = (
        "You are an assistant specialized in analyzing Obsidian notes.\n\n"
        "INSTRUCTIONS:\n"
        "- Answer based ONLY on the provided context\n"
        "- Cite the sources (files) you use\n"
        "- If the information is not in the context, say \"I did not find this information in your notes\"\n"
        "- Be concise but thorough\n"
        "- Structure your answer with sections if necessary"
    )

    def __init__(
        self,
        chroma_db_path: str,
        api_key: str,
        base_url: str,
        collection_name: str = "obsidian_vault",
        embedding_model: str = EMBEDDING_MODEL,
        top_k: int = 5,
        min_score: float = 0.0,
        system_prompt: Optional[str] = None,
    ) -> None:
        self.chroma_db_path = Path(chroma_db_path)
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.top_k = top_k
        self.min_score = min_score
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT

        # Instantiate the internal LLM client
        self.llm_client = LLMClient(api_key=api_key, base_url=base_url)

    def answer_query(self, query: str) -> Tuple[str, List[SearchResult]]:
        """
        Answer a user query using RAG: search vault, build context, call LLM.

        Args:
            query (str): User query.

        Returns:
            Tuple[str, List[SearchResult]]:
                - LLM answer (str)
                - List of SearchResult chunks used as context
        """
        # 1. Perform semantic search over the vault
        chunks: List[SearchResult] = search_vault(
            query=query,
            chroma_db_path=str(self.chroma_db_path),
            collection_name=self.collection_name,
            embedding_model=self.embedding_model,
            limit=self.top_k,
            min_score=self.min_score,
        )

        # 2. Build the context string, appending a citation to each chunk
        context_parts: List[str] = []
        for chunk in chunks:
            chunk_text = chunk.text.strip()
            citation = f"[{chunk.file_path}#L{chunk.line_start}-L{chunk.line_end}]"
            context_parts.append(f"{chunk_text}\n{citation}")

        context_str = "\n\n".join(context_parts) if context_parts else ""

        # 3. Call the LLM with the context and the question
        llm_response = self.llm_client.generate(
            system_prompt=self.system_prompt,
            user_prompt=query,
            context=context_str,
        )

        answer_text = llm_response.get("answer", "")
        return answer_text, chunks
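

# Usage sketch (illustrative only): assumes an index already built by
# indexer.py under ".chroma" and an OpenAI-compatible endpoint reachable by
# LLMClient. The path, environment variable names, base URL, and query below
# are hypothetical placeholders, not values defined elsewhere in this repo.
if __name__ == "__main__":
    import os

    chain = RAGChain(
        chroma_db_path=".chroma",
        api_key=os.environ.get("LLM_API_KEY", ""),
        base_url=os.environ.get("LLM_BASE_URL", "http://localhost:11434/v1"),
        top_k=5,
    )
    answer, sources = chain.answer_query("What did I write about RAG pipelines?")
    print(answer)
    for chunk in sources:
        # Each SearchResult carries the file path and line span used as context
        print(f"- {chunk.file_path}#L{chunk.line_start}-L{chunk.line_end}")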