commit d4925f7969f6d4fb4f680f886d8f21f43f3c9fcc Author: Kodjo Sossouvi Date: Fri Dec 12 11:31:44 2025 +0100 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..25475f0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +# Ignorer tous les fichiers .DS_Store quelle que soit leur profondeur +**/.DS_Store +prompts/spec/ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..1c2fda5 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# Default ignored files +/shelf/ +/workspace.xml +# Editor-based HTTP Client requests +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/MyObsidianAI.iml b/.idea/MyObsidianAI.iml new file mode 100644 index 0000000..5524bed --- /dev/null +++ b/.idea/MyObsidianAI.iml @@ -0,0 +1,18 @@ + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..4069cea --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..470dda8 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..4d2dc3e --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..dc68dba --- /dev/null +++ b/README.md @@ -0,0 +1,191 @@ +# Obsidian RAG Backend + +A local, semantic search backend for Obsidian markdown files. + +## Project Overview + +### Context + +- **Target vault size**: ~1900 files, 480 MB +- **Deployment**: 100% local (no external APIs) +- **Usage**: Command-line interface (CLI) +- **Language**: Python 3.12 + +### Phase 1 Scope (Current) + +Semantic search system that: + +- Indexes markdown files from an Obsidian vault +- Performs semantic search using local embeddings +- Returns relevant results with metadata + +**Phase 2 (Future)**: Add LLM integration for answer generation using Phase 1 search results. 
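+
+A minimal sketch of the intended Phase 1 flow, using the module-level functions introduced later in this commit (paths and the query string are placeholders):
+
+```python
+from obsidian_rag.indexer import index_vault
+from obsidian_rag.searcher import search_vault
+
+# Index the vault once, then query the resulting ChromaDB collection.
+stats = index_vault("/path/to/vault", "/path/to/chroma_db")
+results = search_vault("example query", "/path/to/chroma_db", limit=5)
+for result in results:
+    print(result.file_path, result.score, result.section_title)
+```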
+ +## Features + +### Indexation + +- Manual, on-demand indexing +- Processes all `.md` files in vault +- Extracts document structure (sections, line numbers) +- Hybrid chunking strategy: + - Short sections (≤200 tokens): indexed as-is + - Long sections: split with sliding window (200 tokens, 30 tokens overlap) +- Robust error handling: continues indexing even if individual files fail + +### Search Results + +Each search result includes: + +- File path (relative to vault root) +- Similarity score +- Relevant text excerpt +- Location in file (section and line number) + +## Architecture + +``` +obsidian_rag/ +├── obsidian_rag/ +│ ├── __init__.py +│ ├── markdown_parser.py # Parse .md files, extract structure +│ ├── indexer.py # Generate embeddings and vector index +│ ├── searcher.py # Perform semantic search +│ └── cli.py # Typer CLI interface +├── tests/ +│ ├── __init__.py +│ ├── test_markdown_parser.py +│ ├── test_indexer.py +│ └── test_searcher.py +├── pyproject.toml +└── README.md +``` + +## Technical Choices + +### Technology Stack + +| Component | Technology | Rationale | +|---------------|--------------------------------------------|----------------------------------------------| +| Embeddings | sentence-transformers (`all-MiniLM-L6-v2`) | Local, lightweight (~80MB), good performance | +| Vector Store | ChromaDB | Simple, persistent, good Python integration | +| CLI Framework | Typer | Modern, type-safe, excellent UX | +| Testing | pytest | Standard, powerful, good ecosystem | + +### Design Decisions + +1. **Modular architecture**: Separate concerns (parsing, indexing, searching, CLI) for maintainability and testability +2. **Local-only**: All processing happens on local machine, no data sent to external services +3. **Manual indexing**: User triggers re-indexing when needed (incremental updates deferred to future phases) +4. **Hybrid chunking**: Preserves small sections intact while handling large sections with sliding window +5. **Token-based chunking**: Uses model's tokenizer for precise chunk sizing (max 200 tokens, 30 tokens overlap) +6. **Robust error handling**: Indexing continues even if individual files fail, with detailed error reporting +7. 
**Extensible design**: Architecture prepared for future LLM integration + +### Chunking Strategy Details + +The indexer uses a hybrid approach: + +- **Short sections** (≤200 tokens): Indexed as a single chunk to preserve semantic coherence +- **Long sections** (>200 tokens): Split using sliding window with: + - Maximum chunk size: 200 tokens (safe margin under model's 256 token limit) + - Overlap: 30 tokens (~15% overlap to preserve context at boundaries) + - Token counting: Uses sentence-transformers' native tokenizer for accuracy + +### Metadata Structure + +Each chunk stored in ChromaDB includes: + +```python +{ + "file_path": str, # Relative path from vault root + "section_title": str, # Markdown section heading + "line_start": int, # Starting line number in file + "line_end": int # Ending line number in file +} +``` + +## Dependencies + +### Required + +```bash +sentence-transformers # Local embeddings model (includes tokenizer) +chromadb # Vector database +typer # CLI framework +rich # Terminal formatting (Typer dependency) +``` + +### Development + +```bash +pytest # Testing framework +pytest-cov # Test coverage +``` + +### Installation + +```bash +pip install sentence-transformers chromadb typer[all] pytest pytest-cov +``` + +## Usage (Planned) + +```bash +# Index vault +obsidian-rag index /path/to/vault + +# Search +obsidian-rag search "your query here" + +# Search with options +obsidian-rag search "query" --limit 10 --min-score 0.5 +``` + +## Development Standards + +### Code Style + +- Follow PEP 8 conventions +- Use snake_case for variables and functions +- Docstrings in Google or NumPy format +- All code, comments, and documentation in English + +### Testing Strategy + +- Unit tests with pytest +- Test function naming: `test_i_can_xxx` (passing tests) or `test_i_cannot_xxx` (error cases) +- Functions over classes unless inheritance required +- Test plan validation before implementation + +### File Management + +- All file modifications documented with full file path +- Clear separation of concerns across modules + +## Project Status + +- [x] Requirements gathering +- [x] Architecture design +- [x] Chunking strategy validation +- [ ] Implementation + - [x] `markdown_parser.py` + - [x] `indexer.py` + - [x] `searcher.py` + - [x] `cli.py` +- [ ] Unit tests + - [x] `test_markdown_parser.py` + - [x] `test_indexer.py` (tests written, debugging in progress) + - [x] `test_searcher.py` + - [ ] `test_cli.py` +- [ ] Integration testing +- [ ] Documentation +- [ ] Phase 2: LLM integration + +## License + +[To be determined] + +## Contributing + +[To be determined] \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..e093252 --- /dev/null +++ b/main.py @@ -0,0 +1,16 @@ +# This is a sample Python script. + +# Press Maj+F10 to execute it or replace it with your code. +# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings. + + +def print_hi(name): + # Use a breakpoint in the code line below to debug your script. + print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint. + + +# Press the green button in the gutter to run the script. 
+if __name__ == '__main__': + print_hi('PyCharm') + +# See PyCharm help at https://www.jetbrains.com/help/pycharm/ diff --git a/obsidian_rag/__init__.py b/obsidian_rag/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/obsidian_rag/cli.py b/obsidian_rag/cli.py new file mode 100644 index 0000000..11d7efd --- /dev/null +++ b/obsidian_rag/cli.py @@ -0,0 +1,403 @@ +""" +CLI module for Obsidian RAG Backend. + +Provides command-line interface for indexing and searching the Obsidian vault. +""" +import os +from pathlib import Path +from typing import Optional + +import typer +from rich.console import Console +from rich.panel import Panel +from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn + +from indexer import index_vault +from rag_chain import RAGChain +from searcher import search_vault, SearchResult + +app = typer.Typer( + name="obsidian-rag", + help="Local semantic search backend for Obsidian markdown files", + add_completion=False, +) +console = Console() + +# Default ChromaDB path +DEFAULT_CHROMA_PATH = Path.home() / ".obsidian_rag" / "chroma_db" + + +def _truncate_path(path: str, max_len: int = 60) -> str: + """Return a truncated version of the file path if too long.""" + if len(path) <= max_len: + return path + return "..." + path[-(max_len - 3):] + + +@app.command() +def index( + vault_path: str = typer.Argument( + ..., + help="Path to the Obsidian vault directory", + ), + chroma_path: Optional[str] = typer.Option( + None, + "--chroma-path", + "-c", + help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})", + ), + collection_name: str = typer.Option( + "obsidian_vault", + "--collection", + help="Name of the ChromaDB collection", + ), + max_chunk_tokens: int = typer.Option( + 200, + "--max-tokens", + help="Maximum tokens per chunk", + ), + overlap_tokens: int = typer.Option( + 30, + "--overlap", + help="Number of overlapping tokens between chunks", + ), +): + """ + Index all markdown files from the Obsidian vault into ChromaDB. 
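+
+    Example (illustrative invocation; the vault path is a placeholder):
+
+        obsidian-rag index ~/Documents/MyVault --max-tokens 200 --overlap 30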
+ """ + vault_path_obj = Path(vault_path) + chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH + + if not vault_path_obj.exists(): + console.print(f"[red]✗ Error:[/red] Vault path does not exist: {vault_path}") + raise typer.Exit(code=1) + + if not vault_path_obj.is_dir(): + console.print(f"[red]✗ Error:[/red] Vault path is not a directory: {vault_path}") + raise typer.Exit(code=1) + + chroma_path_obj.mkdir(parents=True, exist_ok=True) + + md_files = list(vault_path_obj.rglob("*.md")) + total_files = len(md_files) + + if total_files == 0: + console.print(f"[yellow]⚠ Warning:[/yellow] No markdown files found in {vault_path}") + raise typer.Exit(code=0) + + console.print(f"\n[cyan]Found {total_files} markdown files to index[/cyan]\n") + + # One single stable progress bar + with Progress( + SpinnerColumn(), + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + console=console, + ) as progress: + + main_task = progress.add_task("[cyan]Indexing vault...", total=total_files) + + # Create a separate status line below the progress bar + status_line = console.status("[dim]Preparing first file...") + + def progress_callback(current_file: str, files_processed: int, total: int): + """Update progress bar and status message.""" + progress.update(main_task, completed=files_processed) + + short_file = _truncate_path(current_file) + status_line.update(f"[dim]Processing: {short_file}") + + try: + with status_line: + stats = index_vault( + vault_path=str(vault_path_obj), + chroma_db_path=str(chroma_path_obj), + collection_name=collection_name, + max_chunk_tokens=max_chunk_tokens, + overlap_tokens=overlap_tokens, + progress_callback=progress_callback, + ) + + progress.update(main_task, completed=total_files) + status_line.update("[green]✓ Completed") + + except Exception as e: + console.print(f"\n[red]✗ Error during indexing:[/red] {str(e)}") + raise typer.Exit(code=1) + + console.print() + _display_index_results(stats) + + +@app.command() +def search( + query: str = typer.Argument( + ..., + help="Search query", + ), + chroma_path: Optional[str] = typer.Option( + None, + "--chroma-path", + "-c", + help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})", + ), + collection_name: str = typer.Option( + "obsidian_vault", + "--collection", + help="Name of the ChromaDB collection", + ), + limit: int = typer.Option( + 5, + "--limit", + "-l", + help="Maximum number of results to return", + ), + min_score: float = typer.Option( + 0.0, + "--min-score", + "-s", + help="Minimum similarity score (0.0 to 1.0)", + ), + format: str = typer.Option( + "compact", + "--format", + "-f", + help="Output format: compact (default), panel, table", + ), +): + """ + Search the indexed vault for semantically similar content. + + Returns relevant sections from your Obsidian notes based on + semantic similarity to the query. + """ + # Resolve paths + chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH + + # Validate chroma path exists + if not chroma_path_obj.exists(): + console.print( + f"[red]✗ Error:[/red] ChromaDB not found at {chroma_path_obj}\n" + f"Please run 'obsidian-rag index ' first to create the index." + ) + raise typer.Exit(code=1) + + # Validate format + valid_formats = ["compact", "panel", "table"] + if format not in valid_formats: + console.print(f"[red]✗ Error:[/red] Invalid format '{format}'. 
Valid options: {', '.join(valid_formats)}") + raise typer.Exit(code=1) + + # Perform search + try: + with console.status("[cyan]Searching...", spinner="dots"): + results = search_vault( + query=query, + chroma_db_path=str(chroma_path_obj), + collection_name=collection_name, + limit=limit, + min_score=min_score, + ) + except ValueError as e: + console.print(f"[red]✗ Error:[/red] {str(e)}") + raise typer.Exit(code=1) + except Exception as e: + console.print(f"[red]✗ Unexpected error:[/red] {str(e)}") + raise typer.Exit(code=1) + + # Display results + if not results: + console.print(f"\n[yellow]No results found for query:[/yellow] '{query}'") + if min_score > 0: + console.print(f"[dim]Try lowering --min-score (currently {min_score})[/dim]") + raise typer.Exit(code=0) + + console.print(f"\n[cyan]Found {len(results)} result(s) for:[/cyan] '{query}'\n") + + # Display with selected format + if format == "compact": + _display_results_compact(results) + elif format == "panel": + _display_results_panel(results) + elif format == "table": + _display_results_table(results) + + +@app.command() +def ask( + query: str = typer.Argument( + ..., + help="Question to ask the LLM based on your Obsidian notes." + ), + chroma_path: Optional[str] = typer.Option( + None, + "--chroma-path", + "-c", + help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})", + ), + collection_name: str = typer.Option( + "obsidian_vault", + "--collection", + help="Name of the ChromaDB collection", + ), + top_k: int = typer.Option( + 5, + "--top-k", + "-k", + help="Number of top chunks to use for context", + ), + min_score: float = typer.Option( + 0.0, + "--min-score", + "-s", + help="Minimum similarity score for chunks", + ), + api_key: Optional[str] = typer.Option( + None, + "--api-key", + help="Clovis API key (or set CLOVIS_API_KEY environment variable)", + ), + base_url: Optional[str] = typer.Option( + None, + "--base-url", + help="Clovis base URL (or set CLOVIS_BASE_URL environment variable)", + ), +): + """ + Ask a question to the LLM using RAG over your Obsidian vault. + """ + + # Resolve ChromaDB path + chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH + if not chroma_path_obj.exists(): + console.print( + f"[red]✗ Error:[/red] ChromaDB not found at {chroma_path_obj}\n" + f"Please run 'obsidian-rag index ' first to create the index." + ) + raise typer.Exit(code=1) + + # Resolve API key and base URL + api_key = api_key or os.getenv("CLOVIS_API_KEY") + base_url = base_url or os.getenv("CLOVIS_BASE_URL") + if not api_key or not base_url: + console.print( + "[red]✗ Error:[/red] API key or base URL not provided.\n" + "Set them via --api-key / --base-url or environment variables CLOVIS_API_KEY and CLOVIS_BASE_URL." 
+ ) + raise typer.Exit(code=1) + + # Instantiate RAGChain + rag = RAGChain( + chroma_db_path=str(chroma_path_obj), + collection_name=collection_name, + top_k=top_k, + min_score=min_score, + api_key=api_key, + base_url=base_url, + ) + + # Get answer from RAG + try: + with console.status("[cyan]Querying LLM...", spinner="dots"): + answer, used_chunks = rag.answer_query(query) + except Exception as e: + console.print(f"[red]✗ Error:[/red] {str(e)}") + raise typer.Exit(code=1) + + # Display answer + console.print("\n[bold green]Answer:[/bold green]\n") + console.print(answer + "\n") + + # Display sources used + if used_chunks: + sources = ", ".join(f"{c.file_path}#L{c.line_start}-L{c.line_end}" for c in used_chunks) + console.print(f"[bold cyan]Sources:[/bold cyan] {sources}\n") + else: + console.print("[bold cyan]Sources:[/bold cyan] None\n") + + +def _display_index_results(stats: dict): + """ + Display indexing results with rich formatting. + + Args: + stats: Statistics dictionary from index_vault + """ + files_processed = stats["files_processed"] + chunks_created = stats["chunks_created"] + errors = stats["errors"] + + # Success summary + console.print(Panel( + f"[green]✓[/green] Indexing completed\n\n" + f"Files processed: [cyan]{files_processed}[/cyan]\n" + f"Chunks created: [cyan]{chunks_created}[/cyan]\n" + f"Collection: [cyan]{stats['collection_name']}[/cyan]", + title="[bold]Indexing Results[/bold]", + border_style="green", + )) + + # Display errors if any + if errors: + console.print(f"\n[yellow]⚠ {len(errors)} file(s) skipped due to errors:[/yellow]\n") + for error in errors: + console.print(f" [red]•[/red] {error['file']}: [dim]{error['error']}[/dim]") + + +def _display_results_compact(results: list[SearchResult]): + """ + Display search results in compact format. + + Args: + results: List of SearchResult objects + """ + for idx, result in enumerate(results, 1): + # Format score as stars (0-5 scale) + stars = "⭐" * int(result.score * 5) + + console.print(f"[bold cyan]{idx}.[/bold cyan] {result.file_path} [dim](score: {result.score:.2f} {stars})[/dim]") + console.print( + f" Section: [yellow]{result.section_title}[/yellow] | Lines: [dim]{result.line_start}-{result.line_end}[/dim]") + + # Truncate text if too long + text = result.text + if len(text) > 200: + text = text[:200] + "..." + + console.print(f" {text}\n") + + +def _display_results_panel(results: list[SearchResult]): + """ + Display search results in panel format (rich boxes). + + Args: + results: List of SearchResult objects + """ + # TODO: Implement panel format in future + console.print("[yellow]Panel format not yet implemented. Using compact format.[/yellow]\n") + _display_results_compact(results) + + +def _display_results_table(results: list[SearchResult]): + """ + Display search results in table format. + + Args: + results: List of SearchResult objects + """ + # TODO: Implement table format in future + console.print("[yellow]Table format not yet implemented. Using compact format.[/yellow]\n") + _display_results_compact(results) + + +def main(): + """ + Entry point for the CLI application. + """ + app() + + +if __name__ == "__main__": + main() diff --git a/obsidian_rag/indexer.py b/obsidian_rag/indexer.py new file mode 100644 index 0000000..960436f --- /dev/null +++ b/obsidian_rag/indexer.py @@ -0,0 +1,284 @@ +""" +Indexer module for Obsidian RAG Backend. + +This module handles the indexing of markdown files into a ChromaDB vector store +using local embeddings from sentence-transformers. 
+""" +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Dict, List, Optional, Callable + +import chromadb +from chromadb.config import Settings +from sentence_transformers import SentenceTransformer + +from markdown_parser import ParsedDocument +from markdown_parser import parse_markdown_file + +# EMBEDDING_MODEL = "all-MiniLM-L6-v2" +EMBEDDING_MODEL = "all-MiniLM-L12-v2" + + +@dataclass +class ChunkMetadata: + file_path: str + section_title: str + line_start: int + line_end: int + + +@dataclass +class Chunk: + id: str + text: str + metadata: ChunkMetadata + + +def index_vault( + vault_path: str, + chroma_db_path: str, + collection_name: str = "obsidian_vault", + embedding_model: str = EMBEDDING_MODEL, + max_chunk_tokens: int = 200, + overlap_tokens: int = 30, + progress_callback: Optional[Callable[[str, int, int], None]] = None, +) -> Dict: + """ + Index all markdown files from vault into ChromaDB. + + Args: + vault_path: Path to the Obsidian vault directory + chroma_db_path: Path where ChromaDB will store its data + collection_name: Name of the ChromaDB collection + embedding_model: Name of the sentence-transformers model to use + max_chunk_tokens: Maximum tokens per chunk + overlap_tokens: Number of overlapping tokens between chunks + progress_callback: Optional callback function called for each file processed. + Signature: callback(current_file: str, files_processed: int, total_files: int) + + Returns: + Dictionary with indexing statistics: + - files_processed: Number of files successfully processed + - chunks_created: Total number of chunks created + - errors: List of errors encountered (file path and error message) + - collection_name: Name of the collection used + """ + + vault_path_obj = Path(vault_path) + if not vault_path_obj.exists(): + raise ValueError(f"Vault path does not exist: {vault_path}") + + # Initialize embedding model and tokenizer + model = SentenceTransformer(embedding_model) + tokenizer = model.tokenizer + + # Initialize ChromaDB client and collection + chroma_client = chromadb.PersistentClient( + path=chroma_db_path, + settings=Settings(anonymized_telemetry=False) + ) + collection = _get_or_create_collection(chroma_client, collection_name) + + # Find all markdown files + md_files = list(vault_path_obj.rglob("*.md")) + total_files = len(md_files) + + # Statistics tracking + stats = { + "files_processed": 0, + "chunks_created": 0, + "errors": [], + "collection_name": collection_name, + } + + # Process each file + for md_file in md_files: + # Get relative path for display + relative_path = md_file.relative_to(vault_path_obj) + + # Notify callback that we're starting this file + if progress_callback: + progress_callback(str(relative_path), stats["files_processed"], total_files) + + try: + # Parse markdown file + parsed_doc = parse_markdown_file(md_file) + + # Create chunks from document + chunks = _create_chunks_from_document( + parsed_doc, + tokenizer, + max_chunk_tokens, + overlap_tokens, + vault_path_obj, + ) + + if chunks: + # Extract data for ChromaDB + documents = [chunk.text for chunk in chunks] + metadatas = [asdict(chunk.metadata) for chunk in chunks] + ids = [chunk.id for chunk in chunks] + + # Generate embeddings and add to collection + embeddings = model.encode(documents, show_progress_bar=False) + collection.add( + documents=documents, + metadatas=metadatas, + ids=ids, + embeddings=embeddings.tolist(), + ) + + stats["chunks_created"] += len(chunks) + + stats["files_processed"] += 1 + + except Exception as e: + # Log 
error but continue processing + stats["errors"].append({ + "file": str(relative_path), + "error": str(e), + }) + + return stats + + +def _get_or_create_collection( + chroma_client: chromadb.PersistentClient, + collection_name: str, +) -> chromadb.Collection: + """ + Get or create a ChromaDB collection, resetting it if it already exists. + + Args: + chroma_client: ChromaDB client instance + collection_name: Name of the collection + + Returns: + ChromaDB collection instance + """ + try: + # Try to delete existing collection + chroma_client.delete_collection(name=collection_name) + except Exception: + # Collection doesn't exist, that's fine + pass + + # Create fresh collection + collection = chroma_client.create_collection( + name=collection_name, + metadata={"hnsw:space": "cosine"} # Use cosine similarity + ) + + return collection + + +def _create_chunks_from_document( + parsed_doc: ParsedDocument, + tokenizer, + max_chunk_tokens: int, + overlap_tokens: int, + vault_path: Path, +) -> List[Chunk]: + """ + Transform a parsed document into chunks with metadata. + + Implements hybrid chunking strategy: + - Short sections (≤max_chunk_tokens): one chunk per section + - Long sections (>max_chunk_tokens): split with sliding window + + Args: + parsed_doc: Parsed document from markdown_parser + tokenizer: Tokenizer from sentence-transformers model + max_chunk_tokens: Maximum tokens per chunk + overlap_tokens: Number of overlapping tokens between chunks + vault_path: Path to vault root (for relative path calculation) + + Returns: + List of chunk dictionaries with 'text', 'metadata', and 'id' keys + """ + chunks = [] + file_path = parsed_doc.file_path + relative_path = file_path.relative_to(vault_path) + + for section in parsed_doc.sections: + section_text = f"{parsed_doc.title} {section.title} {section.content}" + section_title = section.title + line_start = section.start_line + line_end = section.end_line + + # Tokenize section to check length + tokens = tokenizer.encode(section_text, add_special_tokens=False) + + if len(tokens) <= max_chunk_tokens: + # Short section: create single chunk + chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}" + chunks.append(Chunk(chunk_id, section_text, ChunkMetadata(str(relative_path), + section_title, + line_start, + line_start + ))) + else: + # Long section: split with sliding window + sub_chunks = _chunk_section( + section_text, + tokenizer, + max_chunk_tokens, + overlap_tokens, + ) + + # Create chunk for each sub-chunk + for idx, sub_chunk_text in enumerate(sub_chunks): + chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}::chunk{idx}" + chunks.append(Chunk(chunk_id, sub_chunk_text, ChunkMetadata(str(relative_path), + section_title, + line_start, + line_start + ))) + return chunks + + +def _chunk_section( + section_text: str, + tokenizer, + max_chunk_tokens: int, + overlap_tokens: int, +) -> List[str]: + """ + Split a section into overlapping chunks using sliding window. 
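+
+    For example, with max_chunk_tokens=200 and overlap_tokens=30 the window
+    advances by roughly 170 tokens per step, so consecutive chunks share about
+    30 tokens of context (a small safety margin is applied to the chunk size
+    in the implementation below).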
+ + Args: + section_text: Text content to chunk + tokenizer: Tokenizer from sentence-transformers model + max_chunk_tokens: Maximum tokens per chunk + overlap_tokens: Number of overlapping tokens between chunks + + Returns: + List of text chunks + """ + # Apply safety margin to prevent decode/encode inconsistencies + # from exceeding the max token limit + max_chunk_tokens_to_use = int(max_chunk_tokens * 0.98) + + # Tokenize the full text + tokens = tokenizer.encode(section_text, add_special_tokens=False) + + chunks = [] + start_idx = 0 + + while start_idx < len(tokens): + # Extract chunk tokens + end_idx = start_idx + max_chunk_tokens_to_use + chunk_tokens = tokens[start_idx:end_idx] + + # Decode back to text + chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True) + chunks.append(chunk_text) + + # Move window forward (with overlap) + start_idx += max_chunk_tokens_to_use - overlap_tokens + + # Avoid infinite loop if overlap >= max_chunk_tokens + if start_idx <= start_idx - (max_chunk_tokens_to_use - overlap_tokens): + break + + return chunks diff --git a/obsidian_rag/llm_client.py b/obsidian_rag/llm_client.py new file mode 100644 index 0000000..2689a11 --- /dev/null +++ b/obsidian_rag/llm_client.py @@ -0,0 +1,74 @@ +from typing import Dict + +import openai + + +class LLMClient: + """ + Minimalist client for interacting with Clovis LLM via OpenAI SDK. + + Attributes: + api_key (str): API key for Clovis. + base_url (str): Base URL for Clovis LLM gateway. + model (str): Model name to use. Defaults to 'ClovisLLM'. + """ + + def __init__(self, api_key: str, base_url: str, model: str = "ClovisLLM") -> None: + if not api_key: + raise ValueError("API key is required for LLMClient.") + if not base_url: + raise ValueError("Base URL is required for LLMClient.") + + self.api_key = api_key + self.base_url = base_url + self.model = model + self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url) + + def generate(self, system_prompt: str, user_prompt: str, context: str) -> Dict[str, object]: + """ + Generate a response from the LLM given a system prompt, user prompt, and context. + + Args: + system_prompt (str): Instructions for the assistant. + user_prompt (str): The user's query. + context (str): Concatenated chunks from RAG search. + + Returns: + Dict[str, object]: Contains: + - "answer" (str): Text generated by the LLM. + - "usage" (int): Total tokens used in the completion. 
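+
+        Example (illustrative; the key and URL are placeholders):
+
+            client = LLMClient(api_key="sk-...", base_url="https://llm.example/v1")
+            result = client.generate(
+                system_prompt="You are a helpful assistant.",
+                user_prompt="What do my notes say about chunking?",
+                context="Chunks are limited to 200 tokens. [notes/chunking.md#L1-L10]",
+            )
+            print(result["answer"], result["usage"])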
+ """ + # Construct user message with explicit CONTEXT / QUESTION separation + user_message_content = f"CONTEXT:\n{context}\n\nQUESTION:\n{user_prompt}" + + try: + response = self.client.chat.completions.create(model=self.model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_message_content} + ], + temperature=0.7, + max_tokens=2000, + top_p=1.0, + n=1, + # stream=False, + # presence_penalty=0.0, + # frequency_penalty=0.0, + # stop=None, + # logit_bias={}, + user="obsidian_rag", + ) + except Exception as e: + # For now, propagate exceptions (C1 minimal) + raise e + + # Extract text and usage + try: + answer_text = response.choices[0].message.content + total_tokens = response.usage.total_tokens + except AttributeError: + # Fallback if response structure is unexpected + answer_text = "" + total_tokens = 0 + + return {"answer": answer_text, "usage": total_tokens} diff --git a/obsidian_rag/markdown_parser.py b/obsidian_rag/markdown_parser.py new file mode 100644 index 0000000..e865fcc --- /dev/null +++ b/obsidian_rag/markdown_parser.py @@ -0,0 +1,213 @@ +"""Markdown parser for Obsidian vault files. + +This module provides functionality to parse markdown files and extract +their structure (sections, line numbers) for semantic search indexing. +""" + +import re +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + + +@dataclass +class MarkdownSection: + """Represents a section in a markdown document. + + Attributes: + level: Header level (0 for no header, 1 for #, 2 for ##, etc.) + title: Section title (empty string if level=0) + content: Text content without the header line + start_line: Line number where section starts (1-indexed) + end_line: Line number where section ends (1-indexed, inclusive) + """ + level: int + title: str + content: str + parents: list[str] + start_line: int + end_line: int + + +@dataclass +class ParsedDocument: + """Represents a parsed markdown document. + + Attributes: + file_path: Path to the markdown file + sections: List of sections extracted from the document + raw_content: Full file content as string + """ + file_path: Path + title: str + sections: List[MarkdownSection] + raw_content: str + + +def _compute_parents(current_parents, previous_level, previous_title, current_level): + """Computes the parents of `current_parents`.""" + return current_parents + + +def parse_markdown_file(file_path: Path, vault_path=None) -> ParsedDocument: + """Parse a markdown file and extract its structure. + + This function reads a markdown file, identifies all header sections, + and extracts their content with precise line number tracking. + Files without headers are treated as a single section with level 0. + + Args: + file_path: Path to the markdown file to parse + vault_path: Path to the vault file. 
+ + Returns: + ParsedDocument containing the file structure and content + + Raises: + FileNotFoundError: If the file does not exist + + Example: + >>> doc = parse_markdown_file(Path("notes/example.md")) + >>> print(f"Found {len(doc.sections)} sections") + >>> print(doc.sections[0].title) + """ + if not file_path.exists(): + raise FileNotFoundError(f"File not found: {file_path}") + + if vault_path: + title = str(file_path.relative_to(vault_path)).replace(".md", "") + title = title.replace("\\", " ").replace("/", " ") + else: + title = file_path.stem + + raw_content = file_path.read_text(encoding="utf-8") + lines = raw_content.splitlines() + + sections: List[MarkdownSection] = [] + current_section_start = 1 + current_level = 0 + current_title = "" + current_parents = [] + current_content_lines: List[str] = [] + + header_pattern = re.compile(r"^(#{1,6})\s+(.+)$") + + for line_num, line in enumerate(lines, start=1): + match = header_pattern.match(line) + + if match: + # Save the previous section only if it actually has content. + if current_content_lines: + content = "\n".join(current_content_lines) + sections.append( + MarkdownSection( + level=current_level, + title=current_title, + content=content, + parents=current_parents, + start_line=current_section_start, + end_line=line_num - 1, + ) + ) + + # Start a new section with the detected header. + previous_level = current_level + previous_title = current_title + current_level = len(match.group(1)) + current_title = match.group(2).strip() + current_section_start = line_num + current_parents = _compute_parents(current_parents, previous_level, previous_title, current_level) + current_content_lines = [] + else: + current_content_lines.append(line) + + # Handle the final section (or whole file if no headers were found). + if lines: + content = "\n".join(current_content_lines) + end_line = len(lines) + + # Case 1 – no header was ever found. + if not sections and current_level == 0: + sections.append( + MarkdownSection( + level=0, + title="", + content=content, + parents=current_parents, + start_line=1, + end_line=end_line, + ) + ) + # Case 2 – a single header was found (sections empty but we have a title). + elif not sections: + sections.append( + MarkdownSection( + level=current_level, + title=current_title, + content=content, + parents=current_parents, + start_line=current_section_start, + end_line=end_line, + ) + ) + # Case 3 – multiple headers were found (sections already contains earlier ones). + else: + sections.append( + MarkdownSection( + level=current_level, + title=current_title, + content=content, + parents=current_parents, + start_line=current_section_start, + end_line=end_line, + ) + ) + else: + # Empty file: create a single empty level‑0 section. + sections.append( + MarkdownSection( + level=0, + title="", + content="", + parents=[], + start_line=1, + end_line=1, + ) + ) + + return ParsedDocument( + file_path=file_path, + title=title, + sections=sections, + raw_content=raw_content, + ) + + +def find_section_at_line( + document: ParsedDocument, + line_number: int, +) -> Optional[MarkdownSection]: + """Find which section contains a given line number. + + This function searches through the document's sections to find + which section contains the specified line number. 
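+
+    Sections are scanned in order and the first one whose start_line/end_line
+    range contains the given line is returned; otherwise None is returned.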
+ + Args: + document: Parsed markdown document + line_number: Line number to search for (1-indexed) + + Returns: + MarkdownSection containing the line, or None if line number + is invalid or out of range + + Example: + >>> section = find_section_at_line(doc, 42) + >>> if section: + ... print(f"Line 42 is in section: {section.title}") + """ + if line_number < 1: + return None + + for section in document.sections: + if section.start_line <= line_number <= section.end_line: + return section diff --git a/obsidian_rag/rag_chain.py b/obsidian_rag/rag_chain.py new file mode 100644 index 0000000..6691ead --- /dev/null +++ b/obsidian_rag/rag_chain.py @@ -0,0 +1,96 @@ +# File: obsidian_rag/rag_chain.py + +from pathlib import Path +from typing import List, Tuple + +from indexer import EMBEDDING_MODEL +from llm_client import LLMClient +from searcher import search_vault, SearchResult + + +class RAGChain: + """ + Retrieval-Augmented Generation (RAG) chain for answering queries + using semantic search over an Obsidian vault and LLM. + + Attributes: + chroma_db_path (Path): Path to ChromaDB. + collection_name (str): Chroma collection name. + embedding_model (str): Embedding model name. + top_k (int): Number of chunks to send to the LLM. + min_score (float): Minimum similarity score for chunks. + system_prompt (str): System prompt to instruct the LLM. + llm_client (LLMClient): Internal LLM client instance. + """ + + DEFAULT_SYSTEM_PROMPT = ( + "You are an assistant specialized in analyzing Obsidian notes.\n\n" + "INSTRUCTIONS:\n" + "- Answer based ONLY on the provided context\n" + "- Cite the sources (files) you use\n" + "- If the information is not in the context, say \"I did not find this information in your notes\"\n" + "- Be concise but thorough\n" + "- Structure your answer with sections if necessary" + ) + + def __init__( + self, + chroma_db_path: str, + api_key: str, + base_url: str, + collection_name: str = "obsidian_vault", + embedding_model: str = EMBEDDING_MODEL, + top_k: int = 5, + min_score: float = 0.0, + system_prompt: str = None, + ) -> None: + self.chroma_db_path = Path(chroma_db_path) + self.collection_name = collection_name + self.embedding_model = embedding_model + self.top_k = top_k + self.min_score = min_score + self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT + + # Instantiate internal LLM client + self.llm_client = LLMClient(api_key=api_key, base_url=base_url) + + def answer_query(self, query: str) -> Tuple[str, List[SearchResult]]: + """ + Answer a user query using RAG: search vault, build context, call LLM. + + Args: + query (str): User query. + + Returns: + Tuple[str, List[SearchResult]]: + - LLM answer (str) + - List of used SearchResult chunks + """ + # 1. Perform semantic search + chunks: List[SearchResult] = search_vault( + query=query, + chroma_db_path=str(self.chroma_db_path), + collection_name=self.collection_name, + embedding_model=self.embedding_model, + limit=self.top_k, + min_score=self.min_score, + ) + + # 2. Build context string with citations + context_parts: List[str] = [] + for chunk in chunks: + chunk_text = chunk.text.strip() + citation = f"[{chunk.file_path}#L{chunk.line_start}-L{chunk.line_end}]" + context_parts.append(f"{chunk_text}\n{citation}") + + context_str = "\n\n".join(context_parts) if context_parts else "" + + # 3. 
Call LLM with context + question + llm_response = self.llm_client.generate( + system_prompt=self.system_prompt, + user_prompt=query, + context=context_str, + ) + + answer_text = llm_response.get("answer", "") + return answer_text, chunks diff --git a/obsidian_rag/searcher.py b/obsidian_rag/searcher.py new file mode 100644 index 0000000..e80d211 --- /dev/null +++ b/obsidian_rag/searcher.py @@ -0,0 +1,131 @@ +""" +Searcher module for Obsidian RAG Backend. + +This module handles semantic search operations on the indexed ChromaDB collection. +""" +from dataclasses import dataclass +from typing import List +import chromadb +from chromadb.config import Settings +from sentence_transformers import SentenceTransformer + +from indexer import EMBEDDING_MODEL + + +@dataclass +class SearchResult: + """ + Represents a single search result with metadata and relevance score. + """ + file_path: str + section_title: str + line_start: int + line_end: int + score: float + text: str + + +def search_vault( + query: str, + chroma_db_path: str, + collection_name: str = "obsidian_vault", + embedding_model: str = EMBEDDING_MODEL, + limit: int = 5, + min_score: float = 0.0, +) -> List[SearchResult]: + """ + Search the indexed vault for semantically similar content. + + Args: + query: Search query string + chroma_db_path: Path to ChromaDB data directory + collection_name: Name of the ChromaDB collection to search + embedding_model: Model used for embeddings (must match indexing model) + limit: Maximum number of results to return + min_score: Minimum similarity score threshold (0.0 to 1.0) + + Returns: + List of SearchResult objects, sorted by relevance (highest score first) + + Raises: + ValueError: If the collection does not exist or query is empty + """ + if not query or not query.strip(): + raise ValueError("Query cannot be empty") + + # Initialize ChromaDB client + chroma_client = chromadb.PersistentClient( + path=chroma_db_path, + settings=Settings(anonymized_telemetry=False) + ) + + # Get collection (will raise if it doesn't exist) + try: + collection = chroma_client.get_collection(name=collection_name) + except Exception as e: + raise ValueError( + f"Collection '{collection_name}' not found. " + f"Please index your vault first using the index command." + ) from e + + # Initialize embedding model (same as used during indexing) + model = SentenceTransformer(embedding_model) + + # Generate query embedding + query_embedding = model.encode(query, show_progress_bar=False) + + # Perform search + results = collection.query( + query_embeddings=[query_embedding.tolist()], + n_results=limit, + ) + + # Parse and format results + search_results = _parse_search_results(results, min_score) + + return search_results + + +def _parse_search_results( + raw_results: dict, + min_score: float, +) -> List[SearchResult]: + """ + Parse ChromaDB query results into SearchResult objects. + + ChromaDB returns distances (lower = more similar). 
We convert to + similarity scores (higher = more similar) using: score = 1 - distance + + Args: + raw_results: Raw results dictionary from ChromaDB query + min_score: Minimum similarity score to include + + Returns: + List of SearchResult objects filtered by min_score + """ + search_results = [] + + # ChromaDB returns results as lists of lists (one list per query) + # We only have one query, so we take the first element + documents = raw_results.get("documents", [[]])[0] + metadatas = raw_results.get("metadatas", [[]])[0] + distances = raw_results.get("distances", [[]])[0] + + for doc, metadata, distance in zip(documents, metadatas, distances): + # Convert distance to similarity score (cosine distance -> cosine similarity) + score = 1.0 - distance + + # Filter by minimum score + if score < min_score: + continue + + search_results.append(SearchResult( + file_path=metadata["file_path"], + section_title=metadata["section_title"], + line_start=metadata["line_start"], + line_end=metadata["line_end"], + score=score, + text=doc, + )) + + return search_results \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..47b1125 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,537 @@ +""" +Unit tests for the CLI module. +""" +import pytest +from pathlib import Path +from typer.testing import CliRunner +from obsidian_rag.cli import app, _display_index_results, _display_results_compact +from obsidian_rag.indexer import index_vault +from obsidian_rag.searcher import SearchResult + +runner = CliRunner() + + +@pytest.fixture +def temp_vault(tmp_path): + """ + Create a temporary vault with sample markdown files. + """ + vault_path = tmp_path / "test_vault" + vault_path.mkdir() + + # Create sample files + file1 = vault_path / "python.md" + file1.write_text("""# Python Programming + +Python is a high-level programming language. + +## Features + +Python has dynamic typing and automatic memory management. +""") + + file2 = vault_path / "javascript.md" + file2.write_text("""# JavaScript + +JavaScript is a scripting language for web development. + +## Usage + +JavaScript runs in web browsers and Node.js environments. +""") + + file3 = vault_path / "cooking.md" + file3.write_text("""# Cooking Tips + +Learn how to cook delicious meals. + +## Basics + +Start with simple recipes and basic techniques. +""") + + return vault_path + + +# Tests for 'index' command - Passing tests + + +def test_i_can_index_vault_successfully(temp_vault, tmp_path): + """ + Test that we can index a vault successfully. + """ + chroma_path = tmp_path / "chroma_db" + + result = runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + assert result.exit_code == 0 + assert "Found 3 markdown files to index" in result.stdout + assert "Indexing completed" in result.stdout + assert "Files processed:" in result.stdout + assert "Chunks created:" in result.stdout + + +def test_i_can_index_with_custom_chroma_path(temp_vault, tmp_path): + """ + Test that we can specify a custom ChromaDB path. 
+ """ + custom_chroma = tmp_path / "my_custom_db" + + result = runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(custom_chroma), + ]) + + assert result.exit_code == 0 + assert custom_chroma.exists() + assert (custom_chroma / "chroma.sqlite3").exists() + + +def test_i_can_index_with_custom_collection_name(temp_vault, tmp_path): + """ + Test that we can use a custom collection name. + """ + chroma_path = tmp_path / "chroma_db" + collection_name = "my_custom_collection" + + result = runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + "--collection", collection_name, + ]) + + assert result.exit_code == 0 + assert f"Collection: {collection_name}" in result.stdout + + +def test_i_can_see_errors_in_index_results(tmp_path): + """ + Test that errors during indexing are displayed. + """ + vault_path = tmp_path / "vault_with_errors" + vault_path.mkdir() + + # Create a valid file + valid_file = vault_path / "valid.md" + valid_file.write_text("# Valid File\n\nThis is valid content.") + + # Create an invalid file (will cause parsing error) + invalid_file = vault_path / "invalid.md" + invalid_file.write_bytes(b"\xff\xfe\x00\x00") # Invalid UTF-8 + + chroma_path = tmp_path / "chroma_db" + + result = runner.invoke(app, [ + "index", + str(vault_path), + "--chroma-path", str(chroma_path), + ]) + + # Should still complete (exit code 0) but show errors + assert result.exit_code == 0 + assert "Indexing completed" in result.stdout + # Note: Error display might vary, just check it completed + + +# Tests for 'index' command - Error tests + + +def test_i_cannot_index_nonexistent_vault(tmp_path): + """ + Test that indexing a nonexistent vault fails with clear error. + """ + nonexistent_path = tmp_path / "does_not_exist" + chroma_path = tmp_path / "chroma_db" + + result = runner.invoke(app, [ + "index", + str(nonexistent_path), + "--chroma-path", str(chroma_path), + ]) + + assert result.exit_code == 1 + assert "does not exist" in result.stdout + + +def test_i_cannot_index_file_instead_of_directory(tmp_path): + """ + Test that indexing a file (not directory) fails. + """ + file_path = tmp_path / "somefile.txt" + file_path.write_text("I am a file") + chroma_path = tmp_path / "chroma_db" + + result = runner.invoke(app, [ + "index", + str(file_path), + "--chroma-path", str(chroma_path), + ]) + + assert result.exit_code == 1 + assert "not a directory" in result.stdout + + +def test_i_can_handle_empty_vault_gracefully(tmp_path): + """ + Test that an empty vault (no .md files) is handled gracefully. + """ + empty_vault = tmp_path / "empty_vault" + empty_vault.mkdir() + + # Create a non-markdown file + (empty_vault / "readme.txt").write_text("Not a markdown file") + + chroma_path = tmp_path / "chroma_db" + + result = runner.invoke(app, [ + "index", + str(empty_vault), + "--chroma-path", str(chroma_path), + ]) + + assert result.exit_code == 0 + assert "No markdown files found" in result.stdout + + +# Tests for 'search' command - Passing tests + + +def test_i_can_search_indexed_vault(temp_vault, tmp_path): + """ + Test that we can search an indexed vault. 
+ """ + chroma_path = tmp_path / "chroma_db" + + # First, index the vault + index_result = runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + assert index_result.exit_code == 0 + + # Then search + search_result = runner.invoke(app, [ + "search", + "Python programming", + "--chroma-path", str(chroma_path), + ]) + + assert search_result.exit_code == 0 + assert "Found" in search_result.stdout + assert "result(s) for:" in search_result.stdout + assert "python.md" in search_result.stdout + + +def test_i_can_search_with_limit_option(temp_vault, tmp_path): + """ + Test that the --limit option works. + """ + chroma_path = tmp_path / "chroma_db" + + # Index + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + # Search with limit + result = runner.invoke(app, [ + "search", + "programming", + "--chroma-path", str(chroma_path), + "--limit", "2", + ]) + + assert result.exit_code == 0 + # Count result numbers (1., 2., etc.) + result_count = result.stdout.count("[bold cyan]") + assert result_count <= 2 + + +def test_i_can_search_with_min_score_option(temp_vault, tmp_path): + """ + Test that the --min-score option works. + """ + chroma_path = tmp_path / "chroma_db" + + # Index + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + # Search with high min-score + result = runner.invoke(app, [ + "search", + "Python", + "--chroma-path", str(chroma_path), + "--min-score", "0.5", + ]) + + assert result.exit_code == 0 + # Should have results (Python file should match well) + assert "Found" in result.stdout + + +def test_i_can_search_with_custom_collection(temp_vault, tmp_path): + """ + Test that we can search in a custom collection. + """ + chroma_path = tmp_path / "chroma_db" + collection_name = "test_collection" + + # Index with custom collection + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + "--collection", collection_name, + ]) + + # Search in same collection + result = runner.invoke(app, [ + "search", + "Python", + "--chroma-path", str(chroma_path), + "--collection", collection_name, + ]) + + assert result.exit_code == 0 + assert "Found" in result.stdout + + +def test_i_can_handle_no_results_gracefully(temp_vault, tmp_path): + """ + Test that no results scenario is handled gracefully. + """ + chroma_path = tmp_path / "chroma_db" + + # Index + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + # Search for something unlikely with high threshold + result = runner.invoke(app, [ + "search", + "quantum physics relativity", + "--chroma-path", str(chroma_path), + "--min-score", "0.95", + ]) + + assert result.exit_code == 0 + assert "No results found" in result.stdout + + +def test_i_can_use_compact_format(temp_vault, tmp_path): + """ + Test that compact format displays correctly. 
+ """ + chroma_path = tmp_path / "chroma_db" + + # Index + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + # Search with explicit compact format + result = runner.invoke(app, [ + "search", + "Python", + "--chroma-path", str(chroma_path), + "--format", "compact", + ]) + + assert result.exit_code == 0 + # Check for compact format elements + assert "Section:" in result.stdout + assert "Lines:" in result.stdout + assert "score:" in result.stdout + + +# Tests for 'search' command - Error tests + + +def test_i_cannot_search_without_index(tmp_path): + """ + Test that searching without indexing fails with clear message. + """ + chroma_path = tmp_path / "nonexistent_chroma" + + result = runner.invoke(app, [ + "search", + "test query", + "--chroma-path", str(chroma_path), + ]) + + assert result.exit_code == 1 + assert "not found" in result.stdout + assert "index" in result.stdout.lower() + + +def test_i_cannot_search_nonexistent_collection(temp_vault, tmp_path): + """ + Test that searching in a nonexistent collection fails. + """ + chroma_path = tmp_path / "chroma_db" + + # Index with default collection + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + # Search in different collection + result = runner.invoke(app, [ + "search", + "Python", + "--chroma-path", str(chroma_path), + "--collection", "nonexistent_collection", + ]) + + assert result.exit_code == 1 + assert "not found" in result.stdout + + +def test_i_cannot_use_invalid_format(temp_vault, tmp_path): + """ + Test that an invalid format is rejected. + """ + chroma_path = tmp_path / "chroma_db" + + # Index + runner.invoke(app, [ + "index", + str(temp_vault), + "--chroma-path", str(chroma_path), + ]) + + # Search with invalid format + result = runner.invoke(app, [ + "search", + "Python", + "--chroma-path", str(chroma_path), + "--format", "invalid_format", + ]) + + assert result.exit_code == 1 + assert "Invalid format" in result.stdout + assert "compact" in result.stdout + + +# Tests for helper functions + + +def test_i_can_display_index_results(capsys): + """ + Test that index results are displayed correctly. + """ + stats = { + "files_processed": 10, + "chunks_created": 50, + "collection_name": "test_collection", + "errors": [], + } + + _display_index_results(stats) + + captured = capsys.readouterr() + assert "Indexing completed" in captured.out + assert "10" in captured.out + assert "50" in captured.out + assert "test_collection" in captured.out + + +def test_i_can_display_index_results_with_errors(capsys): + """ + Test that index results with errors are displayed correctly. + """ + stats = { + "files_processed": 8, + "chunks_created": 40, + "collection_name": "test_collection", + "errors": [ + {"file": "broken.md", "error": "Invalid encoding"}, + {"file": "corrupt.md", "error": "Parse error"}, + ], + } + + _display_index_results(stats) + + captured = capsys.readouterr() + assert "Indexing completed" in captured.out + assert "2 file(s) skipped" in captured.out + assert "broken.md" in captured.out + assert "Invalid encoding" in captured.out + + +def test_i_can_display_results_compact(capsys): + """ + Test that compact results display correctly. 
+ """ + results = [ + SearchResult( + file_path="notes/python.md", + section_title="Introduction", + line_start=1, + line_end=5, + score=0.87, + text="Python is a high-level programming language.", + ), + SearchResult( + file_path="notes/javascript.md", + section_title="Overview", + line_start=10, + line_end=15, + score=0.65, + text="JavaScript is used for web development.", + ), + ] + + _display_results_compact(results) + + captured = capsys.readouterr() + assert "python.md" in captured.out + assert "javascript.md" in captured.out + assert "0.87" in captured.out + assert "0.65" in captured.out + assert "Introduction" in captured.out + assert "Overview" in captured.out + + +def test_i_can_display_results_compact_with_long_text(capsys): + """ + Test that long text is truncated in compact display. + """ + long_text = "A" * 300 # Text longer than 200 characters + + results = [ + SearchResult( + file_path="notes/long.md", + section_title="Long Section", + line_start=1, + line_end=10, + score=0.75, + text=long_text, + ), + ] + + _display_results_compact(results) + + captured = capsys.readouterr() + assert "..." in captured.out # Should be truncated + assert len([line for line in captured.out.split('\n') if 'A' * 200 in line]) == 0 # Full text not shown \ No newline at end of file diff --git a/tests/test_indexer.py b/tests/test_indexer.py new file mode 100644 index 0000000..bb5ae88 --- /dev/null +++ b/tests/test_indexer.py @@ -0,0 +1,381 @@ +""" +Unit tests for the indexer module. +""" + +import chromadb +import pytest +from chromadb.config import Settings +from sentence_transformers import SentenceTransformer + +from indexer import ( + index_vault, + _chunk_section, + _create_chunks_from_document, + _get_or_create_collection, EMBEDDING_MODEL, +) +from obsidian_rag.markdown_parser import ParsedDocument, MarkdownSection + + +# Fixtures + +@pytest.fixture +def tokenizer(): + """Provide sentence-transformers tokenizer.""" + model = SentenceTransformer(EMBEDDING_MODEL) + return model.tokenizer + + +@pytest.fixture +def embedding_model(): + """Provide sentence-transformers model.""" + return SentenceTransformer(EMBEDDING_MODEL) + + +@pytest.fixture +def chroma_client(tmp_path): + """Provide ChromaDB client with temporary storage.""" + client = chromadb.PersistentClient( + path=str(tmp_path / "chroma_test"), + settings=Settings(anonymized_telemetry=False) + ) + return client + + +@pytest.fixture +def test_vault(tmp_path): + """Create a temporary vault with test markdown files.""" + vault_path = tmp_path / "test_vault" + vault_path.mkdir() + return vault_path + + +# Tests for _chunk_section() + +def test_i_can_chunk_short_section_into_single_chunk(tokenizer): + """Test that a short section is not split.""" + # Create text with ~100 tokens + short_text = " ".join(["word"] * 100) + + chunks = _chunk_section( + section_text=short_text, + tokenizer=tokenizer, + max_chunk_tokens=200, + overlap_tokens=30, + ) + + assert len(chunks) == 1 + assert chunks[0] == short_text + + +def test_i_can_chunk_long_section_with_overlap(tokenizer): + """Test splitting long section with overlap.""" + # Create text with ~500 tokens + long_text = " ".join([f"word{i}" for i in range(500)]) + + chunks = _chunk_section( + section_text=long_text, + tokenizer=tokenizer, + max_chunk_tokens=200, + overlap_tokens=30, + ) + + # Should create multiple chunks + assert len(chunks) >= 2 + + # Verify no chunk exceeds max tokens + for chunk in chunks: + tokens = tokenizer.encode(chunk, add_special_tokens=False) + assert len(tokens) <= 200 + + 
# Verify overlap exists between consecutive chunks + for i in range(len(chunks) - 1): + # Check that some words from end of chunk[i] appear in start of chunk[i+1] + words_chunk1 = chunks[i].split()[-10:] # Last 10 words + words_chunk2 = chunks[i + 1].split()[:10] # First 10 words + + # At least some overlap should exist + overlap_found = any(word in words_chunk2 for word in words_chunk1) + assert overlap_found + + +def test_i_can_chunk_empty_section(tokenizer): + """Test chunking an empty section.""" + empty_text = "" + + chunks = _chunk_section( + section_text=empty_text, + tokenizer=tokenizer, + max_chunk_tokens=200, + overlap_tokens=30, + ) + + assert len(chunks) == 0 + + +# Tests for _create_chunks_from_document() + +def test_i_can_create_chunks_from_document_with_short_sections(tmp_path, tokenizer): + """Test creating chunks from document with only short sections.""" + vault_path = tmp_path / "vault" + vault_path.mkdir() + + parsed_doc = ParsedDocument( + file_path=vault_path / "test.md", + title="test.md", + sections=[ + MarkdownSection(1, "Section 1", "This is a short section with few words.", [], 1, 2), + MarkdownSection(2, "Section 2", "Another short section here.", ["Section 1"], 3, 4), + MarkdownSection(3, "Section 3", "Third short section.", ["Section 1", "Section 3"], 5, 6), + ], + raw_content="" # not used in this test + ) + + chunks = _create_chunks_from_document( + parsed_doc=parsed_doc, + tokenizer=tokenizer, + max_chunk_tokens=200, + overlap_tokens=30, + vault_path=vault_path, + ) + + # Should create 3 chunks (one per section) + assert len(chunks) == 3 + + # Verify metadata + for i, chunk in enumerate(chunks): + metadata = chunk.metadata + assert metadata.file_path == "test.md" + assert metadata.section_title == f"Section {i + 1}" + assert isinstance(metadata.line_start, int) + assert isinstance(metadata.line_end, int) + + # Verify ID format + assert "test.md" in chunk.id + assert f"Section {i + 1}" in chunk.id + + +def test_i_can_create_chunks_from_document_with_long_section(tmp_path, tokenizer): + """Test creating chunks from document with a long section that needs splitting.""" + vault_path = tmp_path / "vault" + vault_path.mkdir() + + # Create long content (~500 tokens) + long_content = " ".join([f"word{i}" for i in range(500)]) + + parsed_doc = ParsedDocument( + file_path=vault_path / "test.md", + title="test.md", + sections=[ + MarkdownSection(1, "Long Section", long_content, [], 1, 1) + ], + raw_content=long_content, + ) + + chunks = _create_chunks_from_document( + parsed_doc=parsed_doc, + tokenizer=tokenizer, + max_chunk_tokens=200, + overlap_tokens=30, + vault_path=vault_path, + ) + + # Should create multiple chunks + assert len(chunks) >= 2 + + # All chunks should have same section_title + for chunk in chunks: + assert chunk.metadata.section_title == "Long Section" + assert chunk.metadata.line_start == 1 + assert chunk.metadata.line_end == 1 + + # IDs should include chunk numbers + assert "::chunk0" in chunks[0].id + assert "::chunk1" in chunks[1].id + + +def test_i_can_create_chunks_with_correct_relative_paths(tmp_path, tokenizer): + """Test that relative paths are correctly computed.""" + vault_path = tmp_path / "vault" + vault_path.mkdir() + + # Create subdirectory + subdir = vault_path / "subfolder" + subdir.mkdir() + + parsed_doc = ParsedDocument( + file_path=subdir / "nested.md", + title=f"{subdir} nested.md", + sections=[ + MarkdownSection(1, "Section", "Some content here.", [], 1, 2), + ], + raw_content="", + ) + + chunks = _create_chunks_from_document( 
+ parsed_doc=parsed_doc, + tokenizer=tokenizer, + max_chunk_tokens=200, + overlap_tokens=30, + vault_path=vault_path, + ) + + assert len(chunks) == 1 + assert chunks[0].metadata.file_path == "subfolder/nested.md" + + +# Tests for _get_or_create_collection() + +def test_i_can_create_new_collection(chroma_client): + """Test creating a new collection that doesn't exist.""" + collection_name = "test_collection" + + collection = _get_or_create_collection(chroma_client, collection_name) + + assert collection.name == collection_name + assert collection.count() == 0 # Should be empty + + +def test_i_can_reset_existing_collection(chroma_client): + """Test that an existing collection is deleted and recreated.""" + collection_name = "test_collection" + + # Create collection and add data + first_collection = chroma_client.create_collection(collection_name) + first_collection.add( + documents=["test document"], + ids=["test_id"], + ) + assert first_collection.count() == 1 + + # Reset collection + new_collection = _get_or_create_collection(chroma_client, collection_name) + + assert new_collection.name == collection_name + assert new_collection.count() == 0 # Should be empty after reset + + +# Tests for index_vault() + +def test_i_can_index_single_markdown_file(test_vault, tmp_path, embedding_model): + """Test indexing a single markdown file.""" + # Create test markdown file + test_file = test_vault / "test.md" + test_file.write_text( + "# Title\n\nThis is a test document with some content.\n\n## Section\n\nMore content here." + ) + + chroma_path = tmp_path / "chroma_db" + + stats = index_vault( + vault_path=str(test_vault), + chroma_db_path=str(chroma_path), + collection_name="test_collection", + ) + + assert stats["files_processed"] == 1 + assert stats["chunks_created"] > 0 + assert stats["errors"] == [] + assert stats["collection_name"] == "test_collection" + + # Verify collection contains data + client = chromadb.PersistentClient( + path=str(chroma_path), + settings=Settings(anonymized_telemetry=False) + ) + collection = client.get_collection("test_collection") + assert collection.count() == stats["chunks_created"] + + +def test_i_can_index_multiple_markdown_files(test_vault, tmp_path): + """Test indexing multiple markdown files.""" + # Create multiple test files + for i in range(3): + test_file = test_vault / f"test{i}.md" + test_file.write_text(f"# Document {i}\n\nContent for document {i}.") + + chroma_path = tmp_path / "chroma_db" + + stats = index_vault( + vault_path=str(test_vault), + chroma_db_path=str(chroma_path), + ) + + assert stats["files_processed"] == 3 + assert stats["chunks_created"] >= 3 # At least one chunk per file + assert stats["errors"] == [] + + +def test_i_can_continue_indexing_after_file_error(test_vault, tmp_path, monkeypatch): + """Test that indexing continues after encountering an error.""" + # Create valid files + (test_vault / "valid1.md").write_text("# Valid 1\n\nContent here.") + (test_vault / "valid2.md").write_text("# Valid 2\n\nMore content.") + (test_vault / "problematic.md").write_text("# Problem\n\nThis will fail.") + + # Mock parse_markdown_file to fail for problematic.md + from obsidian_rag import markdown_parser + original_parse = markdown_parser.parse_markdown_file + + def mock_parse(file_path): + if "problematic.md" in str(file_path): + raise ValueError("Simulated parsing error") + return original_parse(file_path) + + monkeypatch.setattr("indexer.parse_markdown_file", mock_parse) + + chroma_path = tmp_path / "chroma_db" + + stats = index_vault( + 
vault_path=str(test_vault), + chroma_db_path=str(chroma_path), + ) + + # Should process 2 valid files + assert stats["files_processed"] == 2 + assert len(stats["errors"]) == 1 + assert "problematic.md" in stats["errors"][0]["file"] + assert "Simulated parsing error" in stats["errors"][0]["error"] + + +def test_i_cannot_index_nonexistent_vault(tmp_path): + """Test that indexing a nonexistent vault raises an error.""" + nonexistent_path = tmp_path / "nonexistent_vault" + chroma_path = tmp_path / "chroma_db" + + with pytest.raises(ValueError, match="Vault path does not exist"): + index_vault( + vault_path=str(nonexistent_path), + chroma_db_path=str(chroma_path), + ) + + +def test_i_can_verify_embeddings_are_generated(test_vault, tmp_path): + """Test that embeddings are properly generated and stored.""" + # Create test file + test_file = test_vault / "test.md" + test_file.write_text("# Test\n\nThis is test content for embedding generation.") + + chroma_path = tmp_path / "chroma_db" + + stats = index_vault( + vault_path=str(test_vault), + chroma_db_path=str(chroma_path), + ) + + # Verify embeddings in collection + client = chromadb.PersistentClient( + path=str(chroma_path), + settings=Settings(anonymized_telemetry=False) + ) + collection = client.get_collection("obsidian_vault") + + # Get all items + results = collection.get(include=["embeddings"]) + + assert len(results["ids"]) == stats["chunks_created"] + assert results["embeddings"] is not None + + # Verify embeddings are non-zero vectors of correct dimension + for embedding in results["embeddings"]: + assert len(embedding) == 384 # all-MiniLM-L6-v2 dimension + assert any(val != 0 for val in embedding) # Not all zeros diff --git a/tests/test_markdown_parser.py b/tests/test_markdown_parser.py new file mode 100644 index 0000000..f951de9 --- /dev/null +++ b/tests/test_markdown_parser.py @@ -0,0 +1,238 @@ +"""Unit tests for markdown_parser module.""" + +import pytest +from pathlib import Path +from markdown_parser import ( + parse_markdown_file, + find_section_at_line, + MarkdownSection, + ParsedDocument +) + + +@pytest.fixture +def tmp_markdown_file(tmp_path): + """Fixture to create temporary markdown files for testing. + + Args: + tmp_path: pytest temporary directory fixture + + Returns: + Function that creates a markdown file with given content + """ + + def _create_file(content: str, filename: str = "test.md") -> Path: + file_path = tmp_path / filename + file_path.write_text(content, encoding="utf-8") + return file_path + + return _create_file + + +# Tests for parse_markdown_file() + +def test_i_can_parse_file_with_single_section(tmp_markdown_file): + """Test parsing a file with a single header section.""" + content = """# Main Title +This is the content of the section. +It has multiple lines.""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + assert len(doc.sections) == 1 + assert doc.sections[0].level == 1 + assert doc.sections[0].title == "Main Title" + assert "This is the content" in doc.sections[0].content + assert doc.sections[0].start_line == 1 + assert doc.sections[0].end_line == 3 + + +def test_i_can_parse_file_with_multiple_sections(tmp_markdown_file): + """Test parsing a file with multiple sections at the same level.""" + content = """# Section One +Content of section one. + +# Section Two +Content of section two. 
+ +# Section Three +Content of section three.""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + assert len(doc.sections) == 3 + assert doc.sections[0].title == "Section One" + assert doc.sections[1].title == "Section Two" + assert doc.sections[2].title == "Section Three" + assert all(section.level == 1 for section in doc.sections) + + +def test_i_can_parse_file_with_nested_sections(tmp_markdown_file): + """Test parsing a file with nested headers (different levels).""" + content = """# Main Title +Introduction text. + +## Subsection A +Content A. + +## Subsection B +Content B. + +### Sub-subsection +Nested content.""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + assert len(doc.sections) == 4 + assert doc.sections[0].level == 1 + assert doc.sections[0].title == "Main Title" + assert doc.sections[1].level == 2 + assert doc.sections[1].title == "Subsection A" + assert doc.sections[2].level == 2 + assert doc.sections[2].title == "Subsection B" + assert doc.sections[3].level == 3 + assert doc.sections[3].title == "Sub-subsection" + + +def test_i_can_parse_file_without_headers(tmp_markdown_file): + """Test parsing a file with no headers (plain text).""" + content = """This is a plain text file. +It has no headers at all. +Just regular content.""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + assert len(doc.sections) == 1 + assert doc.sections[0].level == 0 + assert doc.sections[0].title == "" + assert doc.sections[0].content == content + assert doc.sections[0].start_line == 1 + assert doc.sections[0].end_line == 3 + + +def test_i_can_parse_empty_file(tmp_markdown_file): + """Test parsing an empty file.""" + content = "" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + assert len(doc.sections) == 1 + assert doc.sections[0].level == 0 + assert doc.sections[0].title == "" + assert doc.sections[0].content == "" + assert doc.sections[0].start_line == 1 + assert doc.sections[0].end_line == 1 + + +def test_i_can_track_correct_line_numbers(tmp_markdown_file): + """Test that line numbers are correctly tracked for each section.""" + content = """# First Section +Line 2 +Line 3 + +# Second Section +Line 6 +Line 7 +Line 8""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + assert doc.sections[0].start_line == 1 + assert doc.sections[0].end_line == 4 + assert doc.sections[1].start_line == 5 + assert doc.sections[1].end_line == 8 + + +def test_i_cannot_parse_nonexistent_file(): + """Test that parsing a non-existent file raises FileNotFoundError.""" + fake_path = Path("/nonexistent/path/to/file.md") + + with pytest.raises(FileNotFoundError): + parse_markdown_file(fake_path) + + +# Tests for find_section_at_line() + +def test_i_can_find_section_at_specific_line(tmp_markdown_file): + """Test finding a section at a line in the middle of content.""" + content = """# Section One +Line 2 +Line 3 + +# Section Two +Line 6 +Line 7""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + section = find_section_at_line(doc, 3) + + assert section is not None + assert section.title == "Section One" + + section = find_section_at_line(doc, 6) + + assert section is not None + assert section.title == "Section Two" + + +def test_i_can_find_section_at_first_line(tmp_markdown_file): + """Test finding a section at the header line itself.""" + content = """# Main Title +Content here.""" + + file_path = 
tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + section = find_section_at_line(doc, 1) + + assert section is not None + assert section.title == "Main Title" + + +def test_i_can_find_section_at_last_line(tmp_markdown_file): + """Test finding a section at its last line.""" + content = """# Section One +Line 2 +Line 3 + +# Section Two +Line 6""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + section = find_section_at_line(doc, 3) + + assert section is not None + assert section.title == "Section One" + + section = find_section_at_line(doc, 6) + + assert section is not None + assert section.title == "Section Two" + + +def test_i_cannot_find_section_for_invalid_line_number(tmp_markdown_file): + """Test that invalid line numbers return None.""" + content = """# Title +Content""" + + file_path = tmp_markdown_file(content) + doc = parse_markdown_file(file_path) + + # Negative line number + assert find_section_at_line(doc, -1) is None + + # Zero line number + assert find_section_at_line(doc, 0) is None + + # Line number beyond file length + assert find_section_at_line(doc, 1000) is None diff --git a/tests/test_searcher.py b/tests/test_searcher.py new file mode 100644 index 0000000..b35e9f8 --- /dev/null +++ b/tests/test_searcher.py @@ -0,0 +1,337 @@ +""" +Unit tests for the searcher module. +""" +import pytest +from pathlib import Path +from indexer import index_vault +from searcher import search_vault, _parse_search_results, SearchResult + + +@pytest.fixture +def temp_vault(tmp_path): + """ + Create a temporary vault with sample markdown files. + """ + vault_path = tmp_path / "test_vault" + vault_path.mkdir() + + # Create sample files + file1 = vault_path / "python_basics.md" + file1.write_text("""# Python Programming + +Python is a high-level programming language known for its simplicity and readability. + +## Variables and Data Types + +In Python, you can create variables without declaring their type explicitly. +Numbers, strings, and booleans are the basic data types. + +## Functions + +Functions in Python are defined using the def keyword. +They help organize code into reusable blocks. +""") + + file2 = vault_path / "machine_learning.md" + file2.write_text("""# Machine Learning + +Machine learning is a subset of artificial intelligence. + +## Supervised Learning + +Supervised learning uses labeled data to train models. +Common algorithms include linear regression and decision trees. + +## Deep Learning + +Deep learning uses neural networks with multiple layers. +It's particularly effective for image and speech recognition. +""") + + file3 = vault_path / "cooking.md" + file3.write_text("""# Italian Cuisine + +Italian cooking emphasizes fresh ingredients and simple preparation. + +## Pasta Dishes + +Pasta is a staple of Italian cuisine. +There are hundreds of pasta shapes and sauce combinations. + +## Pizza Making + +Traditional Italian pizza uses a thin crust and fresh mozzarella. +""") + + return vault_path + + +@pytest.fixture +def indexed_vault(temp_vault, tmp_path): + """ + Create and index a temporary vault. 
+ """ + chroma_path = tmp_path / "chroma_db" + chroma_path.mkdir() + + # Index the vault + stats = index_vault( + vault_path=str(temp_vault), + chroma_db_path=str(chroma_path), + collection_name="test_collection", + ) + + return { + "vault_path": temp_vault, + "chroma_path": chroma_path, + "collection_name": "test_collection", + "stats": stats, + } + + +# Passing tests + + +def test_i_can_search_vault_with_valid_query(indexed_vault): + """ + Test that a basic search returns valid results. + """ + results = search_vault( + query="Python programming language", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + ) + + # Should return results + assert len(results) > 0 + + # All results should be SearchResult instances + for result in results: + assert isinstance(result, SearchResult) + + # Check that all fields are present + assert isinstance(result.file_path, str) + assert isinstance(result.section_title, str) + assert isinstance(result.line_start, int) + assert isinstance(result.line_end, int) + assert isinstance(result.score, float) + assert isinstance(result.text, str) + + # Scores should be between 0 and 1 + assert 0.0 <= result.score <= 1.0 + + # Results should be sorted by score (descending) + scores = [r.score for r in results] + assert scores == sorted(scores, reverse=True) + + +def test_i_can_search_vault_with_limit_parameter(indexed_vault): + """ + Test that the limit parameter is respected. + """ + limit = 3 + results = search_vault( + query="learning", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + limit=limit, + ) + + # Should return at most 'limit' results + assert len(results) <= limit + + +def test_i_can_search_vault_with_min_score_filter(indexed_vault): + """ + Test that only results above min_score are returned. + """ + min_score = 0.5 + results = search_vault( + query="Python", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + min_score=min_score, + ) + + # All results should have score >= min_score + for result in results: + assert result.score >= min_score + + +def test_i_can_get_correct_metadata_in_results(indexed_vault): + """ + Test that metadata in results is correct. + """ + results = search_vault( + query="Python programming", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + limit=1, + ) + + assert len(results) > 0 + top_result = results[0] + + # Should find python_basics.md as most relevant + assert "python_basics.md" in top_result.file_path + + # Should have a section title + assert len(top_result.section_title) > 0 + + # Line numbers should be positive + assert top_result.line_start > 0 + assert top_result.line_end >= top_result.line_start + + # Text should not be empty + assert len(top_result.text) > 0 + + +def test_i_can_search_with_different_collection_name(temp_vault, tmp_path): + """ + Test that we can search in a collection with a custom name. 
+ """ + chroma_path = tmp_path / "chroma_custom" + chroma_path.mkdir() + custom_collection = "my_custom_collection" + + # Index with custom collection name + index_vault( + vault_path=str(temp_vault), + chroma_db_path=str(chroma_path), + collection_name=custom_collection, + ) + + # Search with the same custom collection name + results = search_vault( + query="Python", + chroma_db_path=str(chroma_path), + collection_name=custom_collection, + ) + + assert len(results) > 0 + + +def test_i_can_get_empty_results_when_no_match(indexed_vault): + """ + Test that a search with no matches returns an empty list. + """ + results = search_vault( + query="quantum physics relativity theory", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + min_score=0.9, # Very high threshold + ) + + # Should return empty list, not raise exception + assert isinstance(results, list) + assert len(results) == 0 + + +# Error tests + + +def test_i_cannot_search_with_empty_query(indexed_vault): + """ + Test that an empty query raises ValueError. + """ + with pytest.raises(ValueError, match="Query cannot be empty"): + search_vault( + query="", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + ) + + +def test_i_cannot_search_nonexistent_collection(tmp_path): + """ + Test that searching a nonexistent collection raises ValueError. + """ + chroma_path = tmp_path / "empty_chroma" + chroma_path.mkdir() + + with pytest.raises(ValueError, match="not found"): + search_vault( + query="test query", + chroma_db_path=str(chroma_path), + collection_name="nonexistent_collection", + ) + + +def test_i_cannot_search_with_whitespace_only_query(indexed_vault): + """ + Test that a query with only whitespace raises ValueError. + """ + with pytest.raises(ValueError, match="Query cannot be empty"): + search_vault( + query=" ", + chroma_db_path=str(indexed_vault["chroma_path"]), + collection_name=indexed_vault["collection_name"], + ) + + +# Helper function tests + + +def test_i_can_parse_search_results_correctly(): + """ + Test that ChromaDB results are parsed correctly. + """ + # Mock ChromaDB query results + raw_results = { + "documents": [[ + "Python is a programming language", + "Machine learning basics", + ]], + "metadatas": [[ + { + "file_path": "notes/python.md", + "section_title": "Introduction", + "line_start": 1, + "line_end": 5, + }, + { + "file_path": "notes/ml.md", + "section_title": "Overview", + "line_start": 10, + "line_end": 15, + }, + ]], + "distances": [[0.2, 0.4]], # ChromaDB distances (lower = more similar) + } + + results = _parse_search_results(raw_results, min_score=0.0) + + assert len(results) == 2 + + # Check first result + assert results[0].file_path == "notes/python.md" + assert results[0].section_title == "Introduction" + assert results[0].line_start == 1 + assert results[0].line_end == 5 + assert results[0].text == "Python is a programming language" + assert results[0].score == pytest.approx(0.8) # 1 - 0.2 + + # Check second result + assert results[1].score == pytest.approx(0.6) # 1 - 0.4 + + +def test_i_can_filter_results_by_min_score(): + """ + Test that results are filtered by min_score during parsing. 
+ """ + raw_results = { + "documents": [["text1", "text2", "text3"]], + "metadatas": [[ + {"file_path": "a.md", "section_title": "A", "line_start": 1, "line_end": 2}, + {"file_path": "b.md", "section_title": "B", "line_start": 1, "line_end": 2}, + {"file_path": "c.md", "section_title": "C", "line_start": 1, "line_end": 2}, + ]], + "distances": [[0.1, 0.5, 0.8]], # Scores will be: 0.9, 0.5, 0.2 + } + + results = _parse_search_results(raw_results, min_score=0.6) + + # Only first result should pass (score 0.9 >= 0.6) + assert len(results) == 1 + assert results[0].file_path == "a.md" + assert results[0].score == pytest.approx(0.9)