Initial commit
obsidian_rag/rag_chain.py (new file, 96 lines)
@@ -0,0 +1,96 @@
# File: obsidian_rag/rag_chain.py

from pathlib import Path
from typing import List, Optional, Tuple

from indexer import EMBEDDING_MODEL
from llm_client import LLMClient
from searcher import search_vault, SearchResult


class RAGChain:
    """
    Retrieval-Augmented Generation (RAG) chain for answering queries
    using semantic search over an Obsidian vault and an LLM.

    Attributes:
        chroma_db_path (Path): Path to the ChromaDB database.
        collection_name (str): Chroma collection name.
        embedding_model (str): Embedding model name.
        top_k (int): Number of chunks to send to the LLM.
        min_score (float): Minimum similarity score for a chunk to be included.
        system_prompt (str): System prompt used to instruct the LLM.
        llm_client (LLMClient): Internal LLM client instance.
    """

    DEFAULT_SYSTEM_PROMPT = (
        "You are an assistant specialized in analyzing Obsidian notes.\n\n"
        "INSTRUCTIONS:\n"
        "- Answer based ONLY on the provided context\n"
        "- Cite the sources (files) you use\n"
        "- If the information is not in the context, say \"I did not find this information in your notes\"\n"
        "- Be concise but thorough\n"
        "- Structure your answer with sections if necessary"
    )

    def __init__(
        self,
        chroma_db_path: str,
        api_key: str,
        base_url: str,
        collection_name: str = "obsidian_vault",
        embedding_model: str = EMBEDDING_MODEL,
        top_k: int = 5,
        min_score: float = 0.0,
        system_prompt: Optional[str] = None,
    ) -> None:
        self.chroma_db_path = Path(chroma_db_path)
        self.collection_name = collection_name
        self.embedding_model = embedding_model
        self.top_k = top_k
        self.min_score = min_score
        self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT

        # Instantiate the internal LLM client
        self.llm_client = LLMClient(api_key=api_key, base_url=base_url)

    def answer_query(self, query: str) -> Tuple[str, List[SearchResult]]:
        """
        Answer a user query using RAG: search vault, build context, call LLM.

        Args:
            query (str): User query.

        Returns:
            Tuple[str, List[SearchResult]]:
                - LLM answer (str)
                - List of used SearchResult chunks
        """
        # 1. Perform semantic search
        chunks: List[SearchResult] = search_vault(
            query=query,
            chroma_db_path=str(self.chroma_db_path),
            collection_name=self.collection_name,
            embedding_model=self.embedding_model,
            limit=self.top_k,
            min_score=self.min_score,
        )

        # 2. Build context string with citations
        context_parts: List[str] = []
        for chunk in chunks:
            chunk_text = chunk.text.strip()
            citation = f"[{chunk.file_path}#L{chunk.line_start}-L{chunk.line_end}]"
            context_parts.append(f"{chunk_text}\n{citation}")

        context_str = "\n\n".join(context_parts) if context_parts else ""

        # 3. Call LLM with context + question
        llm_response = self.llm_client.generate(
            system_prompt=self.system_prompt,
            user_prompt=query,
            context=context_str,
        )

        answer_text = llm_response.get("answer", "")
        return answer_text, chunks
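
# ---------------------------------------------------------------------------
# Usage sketch (minimal, illustrative): wires the chain together end to end.
# The ChromaDB path, environment variable, and base URL below are placeholder
# assumptions; only RAGChain and the SearchResult fields come from this module.
if __name__ == "__main__":
    import os

    chain = RAGChain(
        chroma_db_path="./chroma_db",  # assumed location of the ChromaDB store
        api_key=os.environ.get("LLM_API_KEY", ""),  # hypothetical env variable
        base_url="https://api.example.com/v1",  # placeholder LLM endpoint
    )

    answer, sources = chain.answer_query("What did I write about project X?")
    print(answer)
    for result in sources:
        # Each SearchResult exposes the file path and line range used as a citation.
        print(f"- {result.file_path}#L{result.line_start}-L{result.line_end}")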