Initial commit
.gitignore (vendored, new file, 3 lines)
@@ -0,0 +1,3 @@
# Ignore all .DS_Store files regardless of depth
**/.DS_Store
prompts/spec/
.idea/.gitignore (generated, vendored, new file, 8 lines)
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
.idea/MyObsidianAI.iml (generated, new file, 18 lines)
@@ -0,0 +1,18 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$">
      <sourceFolder url="file://$MODULE_DIR$/obsidian_rag" isTestSource="false" />
      <sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
    </content>
    <orderEntry type="jdk" jdkName="Python 3.12.3 WSL (Ubuntu-24.04): (/home/kodjo/.virtualenvs/MyObsidianAI/bin/python)" jdkType="Python SDK" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
  <component name="TestRunnerService">
    <option name="PROJECT_TEST_RUNNER" value="py.test" />
  </component>
</module>
.idea/inspectionProfiles/Project_Default.xml (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="PyInitNewSignatureInspection" enabled="false" level="WARNING" enabled_by_default="false" />
  </profile>
</component>
.idea/inspectionProfiles/profiles_settings.xml (generated, new file, 6 lines)
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
.idea/misc.xml (generated, new file, 7 lines)
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="Black">
    <option name="sdkName" value="Python 3.12 (MyObsidianAI)" />
  </component>
  <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12.3 WSL (Ubuntu-24.04): (/home/kodjo/.virtualenvs/MyObsidianAI/bin/python)" project-jdk-type="Python SDK" />
</project>
.idea/modules.xml (generated, new file, 8 lines)
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/MyObsidianAI.iml" filepath="$PROJECT_DIR$/.idea/MyObsidianAI.iml" />
    </modules>
  </component>
</project>
README.md (new file, 191 lines)
@@ -0,0 +1,191 @@
# Obsidian RAG Backend

A local, semantic search backend for Obsidian markdown files.

## Project Overview

### Context

- **Target vault size**: ~1900 files, 480 MB
- **Deployment**: 100% local (no external APIs)
- **Usage**: Command-line interface (CLI)
- **Language**: Python 3.12

### Phase 1 Scope (Current)

Semantic search system that:

- Indexes markdown files from an Obsidian vault
- Performs semantic search using local embeddings
- Returns relevant results with metadata

**Phase 2 (Future)**: Add LLM integration for answer generation using Phase 1 search results.

## Features

### Indexing

- Manual, on-demand indexing
- Processes all `.md` files in the vault
- Extracts document structure (sections, line numbers)
- Hybrid chunking strategy:
  - Short sections (≤200 tokens): indexed as-is
  - Long sections: split with a sliding window (200 tokens, 30-token overlap)
- Robust error handling: indexing continues even if individual files fail (see the sketch below)
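
A minimal programmatic sketch of an indexing run, based on the `index_vault` signature in `obsidian_rag/indexer.py` (the paths are placeholders, and the package is assumed to be importable as `obsidian_rag`):

```python
from obsidian_rag.indexer import index_vault

stats = index_vault(
    vault_path="/path/to/vault",          # placeholder
    chroma_db_path="/path/to/chroma_db",  # placeholder
    collection_name="obsidian_vault",
    max_chunk_tokens=200,
    overlap_tokens=30,
)

print(f"Indexed {stats['files_processed']} files into {stats['chunks_created']} chunks")
for error in stats["errors"]:
    # Failed files are reported in the stats instead of aborting the run.
    print(f"Skipped {error['file']}: {error['error']}")
```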

### Search Results

Each search result includes the following fields (see the `SearchResult` dataclass below):

- File path (relative to vault root)
- Similarity score
- Relevant text excerpt
- Location in file (section and line number)
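
These fields map directly onto the `SearchResult` dataclass defined in `obsidian_rag/searcher.py` (abridged here, with comments added):

```python
from dataclasses import dataclass

@dataclass
class SearchResult:
    file_path: str      # Relative path from the vault root
    section_title: str  # Markdown section heading
    line_start: int     # First line of the matched section
    line_end: int       # Last line of the matched section
    score: float        # Similarity score (1 - cosine distance), higher is better
    text: str           # Matched text excerpt
```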

## Architecture

```
obsidian_rag/
├── obsidian_rag/
│   ├── __init__.py
│   ├── markdown_parser.py    # Parse .md files, extract structure
│   ├── indexer.py            # Generate embeddings and vector index
│   ├── searcher.py           # Perform semantic search
│   └── cli.py                # Typer CLI interface
├── tests/
│   ├── __init__.py
│   ├── test_markdown_parser.py
│   ├── test_indexer.py
│   └── test_searcher.py
├── pyproject.toml
└── README.md
```

## Technical Choices

### Technology Stack

| Component     | Technology                                 | Rationale                                    |
|---------------|--------------------------------------------|----------------------------------------------|
| Embeddings    | sentence-transformers (`all-MiniLM-L6-v2`) | Local, lightweight (~80MB), good performance |
| Vector Store  | ChromaDB                                   | Simple, persistent, good Python integration  |
| CLI Framework | Typer                                      | Modern, type-safe, excellent UX              |
| Testing       | pytest                                     | Standard, powerful, good ecosystem           |

### Design Decisions

1. **Modular architecture**: Separate concerns (parsing, indexing, searching, CLI) for maintainability and testability
2. **Local-only**: All processing happens on the local machine; no data is sent to external services
3. **Manual indexing**: The user triggers re-indexing when needed (incremental updates deferred to future phases)
4. **Hybrid chunking**: Preserves small sections intact while handling large sections with a sliding window
5. **Token-based chunking**: Uses the model's tokenizer for precise chunk sizing (max 200 tokens, 30-token overlap)
6. **Robust error handling**: Indexing continues even if individual files fail, with detailed error reporting
7. **Extensible design**: Architecture prepared for future LLM integration

### Chunking Strategy Details

The indexer uses a hybrid approach (sketched in code below):

- **Short sections** (≤200 tokens): Indexed as a single chunk to preserve semantic coherence
- **Long sections** (>200 tokens): Split using a sliding window with:
  - Maximum chunk size: 200 tokens (safe margin under the model's 256-token limit)
  - Overlap: 30 tokens (~15% overlap to preserve context at boundaries)
- Token counting: Uses sentence-transformers' native tokenizer for accuracy
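
A condensed sketch of the sliding-window split used for long sections, following `_chunk_section` in `obsidian_rag/indexer.py` (the real implementation works on text and decodes each window back with the model tokenizer; this version operates on raw token IDs purely for illustration):

```python
def sliding_window_chunks(tokens: list[int], max_chunk_tokens: int = 200,
                          overlap_tokens: int = 30) -> list[list[int]]:
    """Split a token sequence into overlapping windows (illustrative sketch)."""
    step = max_chunk_tokens - overlap_tokens
    if step <= 0:
        raise ValueError("overlap_tokens must be smaller than max_chunk_tokens")
    chunks = []
    start = 0
    while start < len(tokens):
        chunks.append(tokens[start:start + max_chunk_tokens])
        start += step
    return chunks

# A 500-token section yields 3 windows: tokens 0-199, 170-369, and 340-499.
assert len(sliding_window_chunks(list(range(500)))) == 3
```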

### Metadata Structure

Each chunk stored in ChromaDB includes:

```python
{
    "file_path": str,       # Relative path from vault root
    "section_title": str,   # Markdown section heading
    "line_start": int,      # Starting line number in file
    "line_end": int,        # Ending line number in file
}
```
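
Chunk IDs are derived from the same metadata in `obsidian_rag/indexer.py`; long sections that are split get an extra `::chunk{idx}` suffix per window:

```python
# Hypothetical example: "projects/rag.md::Chunking::12-48"; split sections append "::chunk0", "::chunk1", ...
chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}"
```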

## Dependencies

### Required

```bash
sentence-transformers   # Local embeddings model (includes tokenizer)
chromadb                # Vector database
typer                   # CLI framework
rich                    # Terminal formatting (Typer dependency)
```

### Development

```bash
pytest                  # Testing framework
pytest-cov              # Test coverage
```

### Installation

```bash
pip install sentence-transformers chromadb "typer[all]" pytest pytest-cov
```

## Usage (Planned)

```bash
# Index vault
obsidian-rag index /path/to/vault

# Search
obsidian-rag search "your query here"

# Search with options
obsidian-rag search "query" --limit 10 --min-score 0.5
```

## Development Standards

### Code Style

- Follow PEP 8 conventions
- Use snake_case for variables and functions
- Docstrings in Google or NumPy format
- All code, comments, and documentation in English

### Testing Strategy

- Unit tests with pytest
- Test function naming: `test_i_can_xxx` (passing tests) or `test_i_cannot_xxx` (error cases); see the example below
- Functions over classes unless inheritance is required
- Test plan validation before implementation
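
An abridged example of this naming convention, taken from `tests/test_cli.py` (`temp_vault` is a pytest fixture defined in that module):

```python
from typer.testing import CliRunner
from obsidian_rag.cli import app

runner = CliRunner()

def test_i_can_index_vault_successfully(temp_vault, tmp_path):
    """A vault of valid markdown files can be indexed without errors."""
    chroma_path = tmp_path / "chroma_db"
    result = runner.invoke(app, ["index", str(temp_vault), "--chroma-path", str(chroma_path)])
    assert result.exit_code == 0
    assert "Indexing completed" in result.stdout
```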

### File Management

- All file modifications documented with full file path
- Clear separation of concerns across modules

## Project Status

- [x] Requirements gathering
- [x] Architecture design
- [x] Chunking strategy validation
- [ ] Implementation
  - [x] `markdown_parser.py`
  - [x] `indexer.py`
  - [x] `searcher.py`
  - [x] `cli.py`
- [ ] Unit tests
  - [x] `test_markdown_parser.py`
  - [x] `test_indexer.py` (tests written, debugging in progress)
  - [x] `test_searcher.py`
  - [ ] `test_cli.py`
- [ ] Integration testing
- [ ] Documentation
- [ ] Phase 2: LLM integration

## License

[To be determined]

## Contributing

[To be determined]
main.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# This is a sample Python script.

# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.


def print_hi(name):
    # Use a breakpoint in the code line below to debug your script.
    print(f'Hi, {name}')  # Press Ctrl+F8 to toggle the breakpoint.


# Press the green button in the gutter to run the script.
if __name__ == '__main__':
    print_hi('PyCharm')

# See PyCharm help at https://www.jetbrains.com/help/pycharm/
obsidian_rag/__init__.py (new file, empty)
obsidian_rag/cli.py (new file, 403 lines)
@@ -0,0 +1,403 @@
"""
CLI module for Obsidian RAG Backend.

Provides a command-line interface for indexing and searching the Obsidian vault.
"""
import os
from pathlib import Path
from typing import Optional

import typer
from rich.console import Console
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn

# Relative imports keep the package importable as `obsidian_rag` (the tests import it that way).
from .indexer import index_vault
from .rag_chain import RAGChain
from .searcher import search_vault, SearchResult

app = typer.Typer(
|
||||||
|
name="obsidian-rag",
|
||||||
|
help="Local semantic search backend for Obsidian markdown files",
|
||||||
|
add_completion=False,
|
||||||
|
)
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
# Default ChromaDB path
|
||||||
|
DEFAULT_CHROMA_PATH = Path.home() / ".obsidian_rag" / "chroma_db"
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate_path(path: str, max_len: int = 60) -> str:
|
||||||
|
"""Return a truncated version of the file path if too long."""
|
||||||
|
if len(path) <= max_len:
|
||||||
|
return path
|
||||||
|
return "..." + path[-(max_len - 3):]
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def index(
|
||||||
|
vault_path: str = typer.Argument(
|
||||||
|
...,
|
||||||
|
help="Path to the Obsidian vault directory",
|
||||||
|
),
|
||||||
|
chroma_path: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--chroma-path",
|
||||||
|
"-c",
|
||||||
|
help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})",
|
||||||
|
),
|
||||||
|
collection_name: str = typer.Option(
|
||||||
|
"obsidian_vault",
|
||||||
|
"--collection",
|
||||||
|
help="Name of the ChromaDB collection",
|
||||||
|
),
|
||||||
|
max_chunk_tokens: int = typer.Option(
|
||||||
|
200,
|
||||||
|
"--max-tokens",
|
||||||
|
help="Maximum tokens per chunk",
|
||||||
|
),
|
||||||
|
overlap_tokens: int = typer.Option(
|
||||||
|
30,
|
||||||
|
"--overlap",
|
||||||
|
help="Number of overlapping tokens between chunks",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Index all markdown files from the Obsidian vault into ChromaDB.
|
||||||
|
"""
|
||||||
|
vault_path_obj = Path(vault_path)
|
||||||
|
chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH
|
||||||
|
|
||||||
|
if not vault_path_obj.exists():
|
||||||
|
console.print(f"[red]✗ Error:[/red] Vault path does not exist: {vault_path}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
if not vault_path_obj.is_dir():
|
||||||
|
console.print(f"[red]✗ Error:[/red] Vault path is not a directory: {vault_path}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
chroma_path_obj.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
md_files = list(vault_path_obj.rglob("*.md"))
|
||||||
|
total_files = len(md_files)
|
||||||
|
|
||||||
|
if total_files == 0:
|
||||||
|
console.print(f"[yellow]⚠ Warning:[/yellow] No markdown files found in {vault_path}")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
console.print(f"\n[cyan]Found {total_files} markdown files to index[/cyan]\n")
|
||||||
|
|
||||||
|
# One single stable progress bar
|
||||||
|
with Progress(
|
||||||
|
SpinnerColumn(),
|
||||||
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
console=console,
|
||||||
|
) as progress:
|
||||||
|
|
||||||
|
main_task = progress.add_task("[cyan]Indexing vault...", total=total_files)
|
||||||
|
|
||||||
|
# Create a separate status line below the progress bar
|
||||||
|
status_line = console.status("[dim]Preparing first file...")
|
||||||
|
|
||||||
|
def progress_callback(current_file: str, files_processed: int, total: int):
|
||||||
|
"""Update progress bar and status message."""
|
||||||
|
progress.update(main_task, completed=files_processed)
|
||||||
|
|
||||||
|
short_file = _truncate_path(current_file)
|
||||||
|
status_line.update(f"[dim]Processing: {short_file}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with status_line:
|
||||||
|
stats = index_vault(
|
||||||
|
vault_path=str(vault_path_obj),
|
||||||
|
chroma_db_path=str(chroma_path_obj),
|
||||||
|
collection_name=collection_name,
|
||||||
|
max_chunk_tokens=max_chunk_tokens,
|
||||||
|
overlap_tokens=overlap_tokens,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
progress.update(main_task, completed=total_files)
|
||||||
|
status_line.update("[green]✓ Completed")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"\n[red]✗ Error during indexing:[/red] {str(e)}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
console.print()
|
||||||
|
_display_index_results(stats)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def search(
|
||||||
|
query: str = typer.Argument(
|
||||||
|
...,
|
||||||
|
help="Search query",
|
||||||
|
),
|
||||||
|
chroma_path: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--chroma-path",
|
||||||
|
"-c",
|
||||||
|
help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})",
|
||||||
|
),
|
||||||
|
collection_name: str = typer.Option(
|
||||||
|
"obsidian_vault",
|
||||||
|
"--collection",
|
||||||
|
help="Name of the ChromaDB collection",
|
||||||
|
),
|
||||||
|
limit: int = typer.Option(
|
||||||
|
5,
|
||||||
|
"--limit",
|
||||||
|
"-l",
|
||||||
|
help="Maximum number of results to return",
|
||||||
|
),
|
||||||
|
min_score: float = typer.Option(
|
||||||
|
0.0,
|
||||||
|
"--min-score",
|
||||||
|
"-s",
|
||||||
|
help="Minimum similarity score (0.0 to 1.0)",
|
||||||
|
),
|
||||||
|
format: str = typer.Option(
|
||||||
|
"compact",
|
||||||
|
"--format",
|
||||||
|
"-f",
|
||||||
|
help="Output format: compact (default), panel, table",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Search the indexed vault for semantically similar content.
|
||||||
|
|
||||||
|
Returns relevant sections from your Obsidian notes based on
|
||||||
|
semantic similarity to the query.
|
||||||
|
"""
|
||||||
|
# Resolve paths
|
||||||
|
chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH
|
||||||
|
|
||||||
|
# Validate chroma path exists
|
||||||
|
if not chroma_path_obj.exists():
|
||||||
|
console.print(
|
||||||
|
f"[red]✗ Error:[/red] ChromaDB not found at {chroma_path_obj}\n"
|
||||||
|
f"Please run 'obsidian-rag index <vault_path>' first to create the index."
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Validate format
|
||||||
|
valid_formats = ["compact", "panel", "table"]
|
||||||
|
if format not in valid_formats:
|
||||||
|
console.print(f"[red]✗ Error:[/red] Invalid format '{format}'. Valid options: {', '.join(valid_formats)}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Perform search
|
||||||
|
try:
|
||||||
|
with console.status("[cyan]Searching...", spinner="dots"):
|
||||||
|
results = search_vault(
|
||||||
|
query=query,
|
||||||
|
chroma_db_path=str(chroma_path_obj),
|
||||||
|
collection_name=collection_name,
|
||||||
|
limit=limit,
|
||||||
|
min_score=min_score,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
console.print(f"[red]✗ Error:[/red] {str(e)}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Unexpected error:[/red] {str(e)}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Display results
|
||||||
|
if not results:
|
||||||
|
console.print(f"\n[yellow]No results found for query:[/yellow] '{query}'")
|
||||||
|
if min_score > 0:
|
||||||
|
console.print(f"[dim]Try lowering --min-score (currently {min_score})[/dim]")
|
||||||
|
raise typer.Exit(code=0)
|
||||||
|
|
||||||
|
console.print(f"\n[cyan]Found {len(results)} result(s) for:[/cyan] '{query}'\n")
|
||||||
|
|
||||||
|
# Display with selected format
|
||||||
|
if format == "compact":
|
||||||
|
_display_results_compact(results)
|
||||||
|
elif format == "panel":
|
||||||
|
_display_results_panel(results)
|
||||||
|
elif format == "table":
|
||||||
|
_display_results_table(results)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def ask(
|
||||||
|
query: str = typer.Argument(
|
||||||
|
...,
|
||||||
|
help="Question to ask the LLM based on your Obsidian notes."
|
||||||
|
),
|
||||||
|
chroma_path: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--chroma-path",
|
||||||
|
"-c",
|
||||||
|
help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})",
|
||||||
|
),
|
||||||
|
collection_name: str = typer.Option(
|
||||||
|
"obsidian_vault",
|
||||||
|
"--collection",
|
||||||
|
help="Name of the ChromaDB collection",
|
||||||
|
),
|
||||||
|
top_k: int = typer.Option(
|
||||||
|
5,
|
||||||
|
"--top-k",
|
||||||
|
"-k",
|
||||||
|
help="Number of top chunks to use for context",
|
||||||
|
),
|
||||||
|
min_score: float = typer.Option(
|
||||||
|
0.0,
|
||||||
|
"--min-score",
|
||||||
|
"-s",
|
||||||
|
help="Minimum similarity score for chunks",
|
||||||
|
),
|
||||||
|
api_key: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--api-key",
|
||||||
|
help="Clovis API key (or set CLOVIS_API_KEY environment variable)",
|
||||||
|
),
|
||||||
|
base_url: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--base-url",
|
||||||
|
help="Clovis base URL (or set CLOVIS_BASE_URL environment variable)",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Ask a question to the LLM using RAG over your Obsidian vault.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Resolve ChromaDB path
|
||||||
|
chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH
|
||||||
|
if not chroma_path_obj.exists():
|
||||||
|
console.print(
|
||||||
|
f"[red]✗ Error:[/red] ChromaDB not found at {chroma_path_obj}\n"
|
||||||
|
f"Please run 'obsidian-rag index <vault_path>' first to create the index."
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Resolve API key and base URL
|
||||||
|
api_key = api_key or os.getenv("CLOVIS_API_KEY")
|
||||||
|
base_url = base_url or os.getenv("CLOVIS_BASE_URL")
|
||||||
|
if not api_key or not base_url:
|
||||||
|
console.print(
|
||||||
|
"[red]✗ Error:[/red] API key or base URL not provided.\n"
|
||||||
|
"Set them via --api-key / --base-url or environment variables CLOVIS_API_KEY and CLOVIS_BASE_URL."
|
||||||
|
)
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Instantiate RAGChain
|
||||||
|
rag = RAGChain(
|
||||||
|
chroma_db_path=str(chroma_path_obj),
|
||||||
|
collection_name=collection_name,
|
||||||
|
top_k=top_k,
|
||||||
|
min_score=min_score,
|
||||||
|
api_key=api_key,
|
||||||
|
base_url=base_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get answer from RAG
|
||||||
|
try:
|
||||||
|
with console.status("[cyan]Querying LLM...", spinner="dots"):
|
||||||
|
answer, used_chunks = rag.answer_query(query)
|
||||||
|
except Exception as e:
|
||||||
|
console.print(f"[red]✗ Error:[/red] {str(e)}")
|
||||||
|
raise typer.Exit(code=1)
|
||||||
|
|
||||||
|
# Display answer
|
||||||
|
console.print("\n[bold green]Answer:[/bold green]\n")
|
||||||
|
console.print(answer + "\n")
|
||||||
|
|
||||||
|
# Display sources used
|
||||||
|
if used_chunks:
|
||||||
|
sources = ", ".join(f"{c.file_path}#L{c.line_start}-L{c.line_end}" for c in used_chunks)
|
||||||
|
console.print(f"[bold cyan]Sources:[/bold cyan] {sources}\n")
|
||||||
|
else:
|
||||||
|
console.print("[bold cyan]Sources:[/bold cyan] None\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _display_index_results(stats: dict):
|
||||||
|
"""
|
||||||
|
Display indexing results with rich formatting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
stats: Statistics dictionary from index_vault
|
||||||
|
"""
|
||||||
|
files_processed = stats["files_processed"]
|
||||||
|
chunks_created = stats["chunks_created"]
|
||||||
|
errors = stats["errors"]
|
||||||
|
|
||||||
|
# Success summary
|
||||||
|
console.print(Panel(
|
||||||
|
f"[green]✓[/green] Indexing completed\n\n"
|
||||||
|
f"Files processed: [cyan]{files_processed}[/cyan]\n"
|
||||||
|
f"Chunks created: [cyan]{chunks_created}[/cyan]\n"
|
||||||
|
f"Collection: [cyan]{stats['collection_name']}[/cyan]",
|
||||||
|
title="[bold]Indexing Results[/bold]",
|
||||||
|
border_style="green",
|
||||||
|
))
|
||||||
|
|
||||||
|
# Display errors if any
|
||||||
|
if errors:
|
||||||
|
console.print(f"\n[yellow]⚠ {len(errors)} file(s) skipped due to errors:[/yellow]\n")
|
||||||
|
for error in errors:
|
||||||
|
console.print(f" [red]•[/red] {error['file']}: [dim]{error['error']}[/dim]")
|
||||||
|
|
||||||
|
|
||||||
|
def _display_results_compact(results: list[SearchResult]):
|
||||||
|
"""
|
||||||
|
Display search results in compact format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of SearchResult objects
|
||||||
|
"""
|
||||||
|
for idx, result in enumerate(results, 1):
|
||||||
|
# Format score as stars (0-5 scale)
|
||||||
|
stars = "⭐" * int(result.score * 5)
|
||||||
|
|
||||||
|
console.print(f"[bold cyan]{idx}.[/bold cyan] {result.file_path} [dim](score: {result.score:.2f} {stars})[/dim]")
|
||||||
|
console.print(
|
||||||
|
f" Section: [yellow]{result.section_title}[/yellow] | Lines: [dim]{result.line_start}-{result.line_end}[/dim]")
|
||||||
|
|
||||||
|
# Truncate text if too long
|
||||||
|
text = result.text
|
||||||
|
if len(text) > 200:
|
||||||
|
text = text[:200] + "..."
|
||||||
|
|
||||||
|
console.print(f" {text}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _display_results_panel(results: list[SearchResult]):
|
||||||
|
"""
|
||||||
|
Display search results in panel format (rich boxes).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of SearchResult objects
|
||||||
|
"""
|
||||||
|
# TODO: Implement panel format in future
|
||||||
|
console.print("[yellow]Panel format not yet implemented. Using compact format.[/yellow]\n")
|
||||||
|
_display_results_compact(results)
|
||||||
|
|
||||||
|
|
||||||
|
def _display_results_table(results: list[SearchResult]):
|
||||||
|
"""
|
||||||
|
Display search results in table format.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
results: List of SearchResult objects
|
||||||
|
"""
|
||||||
|
# TODO: Implement table format in future
|
||||||
|
console.print("[yellow]Table format not yet implemented. Using compact format.[/yellow]\n")
|
||||||
|
_display_results_compact(results)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""
|
||||||
|
Entry point for the CLI application.
|
||||||
|
"""
|
||||||
|
app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
obsidian_rag/indexer.py (new file, 284 lines)
@@ -0,0 +1,284 @@
"""
Indexer module for Obsidian RAG Backend.

This module handles the indexing of markdown files into a ChromaDB vector store
using local embeddings from sentence-transformers.
"""
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Callable

import chromadb
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer

from .markdown_parser import ParsedDocument, parse_markdown_file

# EMBEDDING_MODEL = "all-MiniLM-L6-v2"
EMBEDDING_MODEL = "all-MiniLM-L12-v2"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChunkMetadata:
|
||||||
|
file_path: str
|
||||||
|
section_title: str
|
||||||
|
line_start: int
|
||||||
|
line_end: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Chunk:
|
||||||
|
id: str
|
||||||
|
text: str
|
||||||
|
metadata: ChunkMetadata
|
||||||
|
|
||||||
|
|
||||||
|
def index_vault(
|
||||||
|
vault_path: str,
|
||||||
|
chroma_db_path: str,
|
||||||
|
collection_name: str = "obsidian_vault",
|
||||||
|
embedding_model: str = EMBEDDING_MODEL,
|
||||||
|
max_chunk_tokens: int = 200,
|
||||||
|
overlap_tokens: int = 30,
|
||||||
|
progress_callback: Optional[Callable[[str, int, int], None]] = None,
|
||||||
|
) -> Dict:
|
||||||
|
"""
|
||||||
|
Index all markdown files from vault into ChromaDB.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vault_path: Path to the Obsidian vault directory
|
||||||
|
chroma_db_path: Path where ChromaDB will store its data
|
||||||
|
collection_name: Name of the ChromaDB collection
|
||||||
|
embedding_model: Name of the sentence-transformers model to use
|
||||||
|
max_chunk_tokens: Maximum tokens per chunk
|
||||||
|
overlap_tokens: Number of overlapping tokens between chunks
|
||||||
|
progress_callback: Optional callback function called for each file processed.
|
||||||
|
Signature: callback(current_file: str, files_processed: int, total_files: int)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with indexing statistics:
|
||||||
|
- files_processed: Number of files successfully processed
|
||||||
|
- chunks_created: Total number of chunks created
|
||||||
|
- errors: List of errors encountered (file path and error message)
|
||||||
|
- collection_name: Name of the collection used
|
||||||
|
"""
|
||||||
|
|
||||||
|
vault_path_obj = Path(vault_path)
|
||||||
|
if not vault_path_obj.exists():
|
||||||
|
raise ValueError(f"Vault path does not exist: {vault_path}")
|
||||||
|
|
||||||
|
# Initialize embedding model and tokenizer
|
||||||
|
model = SentenceTransformer(embedding_model)
|
||||||
|
tokenizer = model.tokenizer
|
||||||
|
|
||||||
|
# Initialize ChromaDB client and collection
|
||||||
|
chroma_client = chromadb.PersistentClient(
|
||||||
|
path=chroma_db_path,
|
||||||
|
settings=Settings(anonymized_telemetry=False)
|
||||||
|
)
|
||||||
|
collection = _get_or_create_collection(chroma_client, collection_name)
|
||||||
|
|
||||||
|
# Find all markdown files
|
||||||
|
md_files = list(vault_path_obj.rglob("*.md"))
|
||||||
|
total_files = len(md_files)
|
||||||
|
|
||||||
|
# Statistics tracking
|
||||||
|
stats = {
|
||||||
|
"files_processed": 0,
|
||||||
|
"chunks_created": 0,
|
||||||
|
"errors": [],
|
||||||
|
"collection_name": collection_name,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Process each file
|
||||||
|
for md_file in md_files:
|
||||||
|
# Get relative path for display
|
||||||
|
relative_path = md_file.relative_to(vault_path_obj)
|
||||||
|
|
||||||
|
# Notify callback that we're starting this file
|
||||||
|
if progress_callback:
|
||||||
|
progress_callback(str(relative_path), stats["files_processed"], total_files)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse markdown file
|
||||||
|
parsed_doc = parse_markdown_file(md_file)
|
||||||
|
|
||||||
|
# Create chunks from document
|
||||||
|
chunks = _create_chunks_from_document(
|
||||||
|
parsed_doc,
|
||||||
|
tokenizer,
|
||||||
|
max_chunk_tokens,
|
||||||
|
overlap_tokens,
|
||||||
|
vault_path_obj,
|
||||||
|
)
|
||||||
|
|
||||||
|
if chunks:
|
||||||
|
# Extract data for ChromaDB
|
||||||
|
documents = [chunk.text for chunk in chunks]
|
||||||
|
metadatas = [asdict(chunk.metadata) for chunk in chunks]
|
||||||
|
ids = [chunk.id for chunk in chunks]
|
||||||
|
|
||||||
|
# Generate embeddings and add to collection
|
||||||
|
embeddings = model.encode(documents, show_progress_bar=False)
|
||||||
|
collection.add(
|
||||||
|
documents=documents,
|
||||||
|
metadatas=metadatas,
|
||||||
|
ids=ids,
|
||||||
|
embeddings=embeddings.tolist(),
|
||||||
|
)
|
||||||
|
|
||||||
|
stats["chunks_created"] += len(chunks)
|
||||||
|
|
||||||
|
stats["files_processed"] += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log error but continue processing
|
||||||
|
stats["errors"].append({
|
||||||
|
"file": str(relative_path),
|
||||||
|
"error": str(e),
|
||||||
|
})
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
def _get_or_create_collection(
|
||||||
|
chroma_client: chromadb.PersistentClient,
|
||||||
|
collection_name: str,
|
||||||
|
) -> chromadb.Collection:
|
||||||
|
"""
|
||||||
|
Get or create a ChromaDB collection, resetting it if it already exists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chroma_client: ChromaDB client instance
|
||||||
|
collection_name: Name of the collection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChromaDB collection instance
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to delete existing collection
|
||||||
|
chroma_client.delete_collection(name=collection_name)
|
||||||
|
except Exception:
|
||||||
|
# Collection doesn't exist, that's fine
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Create fresh collection
|
||||||
|
collection = chroma_client.create_collection(
|
||||||
|
name=collection_name,
|
||||||
|
metadata={"hnsw:space": "cosine"} # Use cosine similarity
|
||||||
|
)
|
||||||
|
|
||||||
|
return collection
|
||||||
|
|
||||||
|
|
||||||
|
def _create_chunks_from_document(
|
||||||
|
parsed_doc: ParsedDocument,
|
||||||
|
tokenizer,
|
||||||
|
max_chunk_tokens: int,
|
||||||
|
overlap_tokens: int,
|
||||||
|
vault_path: Path,
|
||||||
|
) -> List[Chunk]:
|
||||||
|
"""
|
||||||
|
Transform a parsed document into chunks with metadata.
|
||||||
|
|
||||||
|
Implements hybrid chunking strategy:
|
||||||
|
- Short sections (≤max_chunk_tokens): one chunk per section
|
||||||
|
- Long sections (>max_chunk_tokens): split with sliding window
|
||||||
|
|
||||||
|
Args:
|
||||||
|
parsed_doc: Parsed document from markdown_parser
|
||||||
|
tokenizer: Tokenizer from sentence-transformers model
|
||||||
|
max_chunk_tokens: Maximum tokens per chunk
|
||||||
|
overlap_tokens: Number of overlapping tokens between chunks
|
||||||
|
vault_path: Path to vault root (for relative path calculation)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of chunk dictionaries with 'text', 'metadata', and 'id' keys
|
||||||
|
"""
|
||||||
|
chunks = []
|
||||||
|
file_path = parsed_doc.file_path
|
||||||
|
relative_path = file_path.relative_to(vault_path)
|
||||||
|
|
||||||
|
for section in parsed_doc.sections:
|
||||||
|
section_text = f"{parsed_doc.title} {section.title} {section.content}"
|
||||||
|
section_title = section.title
|
||||||
|
line_start = section.start_line
|
||||||
|
line_end = section.end_line
|
||||||
|
|
||||||
|
# Tokenize section to check length
|
||||||
|
tokens = tokenizer.encode(section_text, add_special_tokens=False)
|
||||||
|
|
||||||
|
if len(tokens) <= max_chunk_tokens:
|
||||||
|
# Short section: create single chunk
|
||||||
|
chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}"
|
||||||
|
            chunks.append(Chunk(chunk_id, section_text,
                                ChunkMetadata(str(relative_path), section_title,
                                              line_start, line_end)))
|
||||||
|
else:
|
||||||
|
# Long section: split with sliding window
|
||||||
|
sub_chunks = _chunk_section(
|
||||||
|
section_text,
|
||||||
|
tokenizer,
|
||||||
|
max_chunk_tokens,
|
||||||
|
overlap_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create chunk for each sub-chunk
|
||||||
|
for idx, sub_chunk_text in enumerate(sub_chunks):
|
||||||
|
chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}::chunk{idx}"
|
||||||
|
                chunks.append(Chunk(chunk_id, sub_chunk_text,
                                    ChunkMetadata(str(relative_path), section_title,
                                                  line_start, line_end)))
|
||||||
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
|
def _chunk_section(
|
||||||
|
section_text: str,
|
||||||
|
tokenizer,
|
||||||
|
max_chunk_tokens: int,
|
||||||
|
overlap_tokens: int,
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
Split a section into overlapping chunks using sliding window.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
section_text: Text content to chunk
|
||||||
|
tokenizer: Tokenizer from sentence-transformers model
|
||||||
|
max_chunk_tokens: Maximum tokens per chunk
|
||||||
|
overlap_tokens: Number of overlapping tokens between chunks
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of text chunks
|
||||||
|
"""
|
||||||
|
# Apply safety margin to prevent decode/encode inconsistencies
|
||||||
|
# from exceeding the max token limit
|
||||||
|
max_chunk_tokens_to_use = int(max_chunk_tokens * 0.98)
|
||||||
|
|
||||||
|
# Tokenize the full text
|
||||||
|
tokens = tokenizer.encode(section_text, add_special_tokens=False)
|
||||||
|
|
||||||
|
chunks = []
|
||||||
|
start_idx = 0
|
||||||
|
|
||||||
|
while start_idx < len(tokens):
|
||||||
|
# Extract chunk tokens
|
||||||
|
end_idx = start_idx + max_chunk_tokens_to_use
|
||||||
|
chunk_tokens = tokens[start_idx:end_idx]
|
||||||
|
|
||||||
|
# Decode back to text
|
||||||
|
chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
|
||||||
|
chunks.append(chunk_text)
|
||||||
|
|
||||||
|
# Move window forward (with overlap)
|
||||||
|
start_idx += max_chunk_tokens_to_use - overlap_tokens
|
||||||
|
|
||||||
|
        # Stop if the window cannot advance (overlap >= effective max chunk size),
        # otherwise the loop would never terminate.
        if max_chunk_tokens_to_use - overlap_tokens <= 0:
            break
|
||||||
|
|
||||||
|
return chunks
|
||||||
obsidian_rag/llm_client.py (new file, 74 lines)
@@ -0,0 +1,74 @@
|
|||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
import openai
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
"""
|
||||||
|
Minimalist client for interacting with Clovis LLM via OpenAI SDK.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
api_key (str): API key for Clovis.
|
||||||
|
base_url (str): Base URL for Clovis LLM gateway.
|
||||||
|
model (str): Model name to use. Defaults to 'ClovisLLM'.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str, model: str = "ClovisLLM") -> None:
|
||||||
|
if not api_key:
|
||||||
|
raise ValueError("API key is required for LLMClient.")
|
||||||
|
if not base_url:
|
||||||
|
raise ValueError("Base URL is required for LLMClient.")
|
||||||
|
|
||||||
|
self.api_key = api_key
|
||||||
|
self.base_url = base_url
|
||||||
|
self.model = model
|
||||||
|
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||||
|
|
||||||
|
def generate(self, system_prompt: str, user_prompt: str, context: str) -> Dict[str, object]:
|
||||||
|
"""
|
||||||
|
Generate a response from the LLM given a system prompt, user prompt, and context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
system_prompt (str): Instructions for the assistant.
|
||||||
|
user_prompt (str): The user's query.
|
||||||
|
context (str): Concatenated chunks from RAG search.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, object]: Contains:
|
||||||
|
- "answer" (str): Text generated by the LLM.
|
||||||
|
- "usage" (int): Total tokens used in the completion.
|
||||||
|
"""
|
||||||
|
# Construct user message with explicit CONTEXT / QUESTION separation
|
||||||
|
user_message_content = f"CONTEXT:\n{context}\n\nQUESTION:\n{user_prompt}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.chat.completions.create(model=self.model,
|
||||||
|
messages=[
|
||||||
|
{"role": "system", "content": system_prompt},
|
||||||
|
{"role": "user", "content": user_message_content}
|
||||||
|
],
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=2000,
|
||||||
|
top_p=1.0,
|
||||||
|
n=1,
|
||||||
|
# stream=False,
|
||||||
|
# presence_penalty=0.0,
|
||||||
|
# frequency_penalty=0.0,
|
||||||
|
# stop=None,
|
||||||
|
# logit_bias={},
|
||||||
|
user="obsidian_rag",
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
            # For now, propagate exceptions unchanged (C1: minimal version).
            raise
|
||||||
|
|
||||||
|
# Extract text and usage
|
||||||
|
try:
|
||||||
|
answer_text = response.choices[0].message.content
|
||||||
|
total_tokens = response.usage.total_tokens
|
||||||
|
except AttributeError:
|
||||||
|
# Fallback if response structure is unexpected
|
||||||
|
answer_text = ""
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
return {"answer": answer_text, "usage": total_tokens}
|
||||||
obsidian_rag/markdown_parser.py (new file, 213 lines)
@@ -0,0 +1,213 @@
|
|||||||
|
"""Markdown parser for Obsidian vault files.
|
||||||
|
|
||||||
|
This module provides functionality to parse markdown files and extract
|
||||||
|
their structure (sections, line numbers) for semantic search indexing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarkdownSection:
|
||||||
|
"""Represents a section in a markdown document.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
level: Header level (0 for no header, 1 for #, 2 for ##, etc.)
|
||||||
|
title: Section title (empty string if level=0)
|
||||||
|
content: Text content without the header line
|
||||||
|
start_line: Line number where section starts (1-indexed)
|
||||||
|
end_line: Line number where section ends (1-indexed, inclusive)
|
||||||
|
"""
|
||||||
|
level: int
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
parents: list[str]
|
||||||
|
start_line: int
|
||||||
|
end_line: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ParsedDocument:
|
||||||
|
"""Represents a parsed markdown document.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
file_path: Path to the markdown file
|
||||||
|
sections: List of sections extracted from the document
|
||||||
|
raw_content: Full file content as string
|
||||||
|
"""
|
||||||
|
file_path: Path
|
||||||
|
title: str
|
||||||
|
sections: List[MarkdownSection]
|
||||||
|
raw_content: str
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_parents(current_parents, previous_level, previous_title, current_level):
|
||||||
|
"""Computes the parents of `current_parents`."""
|
||||||
|
return current_parents
|
||||||
|
|
||||||
|
|
||||||
|
def parse_markdown_file(file_path: Path, vault_path=None) -> ParsedDocument:
|
||||||
|
"""Parse a markdown file and extract its structure.
|
||||||
|
|
||||||
|
This function reads a markdown file, identifies all header sections,
|
||||||
|
and extracts their content with precise line number tracking.
|
||||||
|
Files without headers are treated as a single section with level 0.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Path to the markdown file to parse
|
||||||
|
vault_path: Path to the vault root directory, used to build a human-readable title.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ParsedDocument containing the file structure and content
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
FileNotFoundError: If the file does not exist
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> doc = parse_markdown_file(Path("notes/example.md"))
|
||||||
|
>>> print(f"Found {len(doc.sections)} sections")
|
||||||
|
>>> print(doc.sections[0].title)
|
||||||
|
"""
|
||||||
|
if not file_path.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
if vault_path:
|
||||||
|
title = str(file_path.relative_to(vault_path)).replace(".md", "")
|
||||||
|
title = title.replace("\\", " ").replace("/", " ")
|
||||||
|
else:
|
||||||
|
title = file_path.stem
|
||||||
|
|
||||||
|
raw_content = file_path.read_text(encoding="utf-8")
|
||||||
|
lines = raw_content.splitlines()
|
||||||
|
|
||||||
|
sections: List[MarkdownSection] = []
|
||||||
|
current_section_start = 1
|
||||||
|
current_level = 0
|
||||||
|
current_title = ""
|
||||||
|
current_parents = []
|
||||||
|
current_content_lines: List[str] = []
|
||||||
|
|
||||||
|
header_pattern = re.compile(r"^(#{1,6})\s+(.+)$")
|
||||||
|
|
||||||
|
for line_num, line in enumerate(lines, start=1):
|
||||||
|
match = header_pattern.match(line)
|
||||||
|
|
||||||
|
if match:
|
||||||
|
# Save the previous section only if it actually has content.
|
||||||
|
if current_content_lines:
|
||||||
|
content = "\n".join(current_content_lines)
|
||||||
|
sections.append(
|
||||||
|
MarkdownSection(
|
||||||
|
level=current_level,
|
||||||
|
title=current_title,
|
||||||
|
content=content,
|
||||||
|
parents=current_parents,
|
||||||
|
start_line=current_section_start,
|
||||||
|
end_line=line_num - 1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start a new section with the detected header.
|
||||||
|
previous_level = current_level
|
||||||
|
previous_title = current_title
|
||||||
|
current_level = len(match.group(1))
|
||||||
|
current_title = match.group(2).strip()
|
||||||
|
current_section_start = line_num
|
||||||
|
current_parents = _compute_parents(current_parents, previous_level, previous_title, current_level)
|
||||||
|
current_content_lines = []
|
||||||
|
else:
|
||||||
|
current_content_lines.append(line)
|
||||||
|
|
||||||
|
# Handle the final section (or whole file if no headers were found).
|
||||||
|
if lines:
|
||||||
|
content = "\n".join(current_content_lines)
|
||||||
|
end_line = len(lines)
|
||||||
|
|
||||||
|
# Case 1 – no header was ever found.
|
||||||
|
if not sections and current_level == 0:
|
||||||
|
sections.append(
|
||||||
|
MarkdownSection(
|
||||||
|
level=0,
|
||||||
|
title="",
|
||||||
|
content=content,
|
||||||
|
parents=current_parents,
|
||||||
|
start_line=1,
|
||||||
|
end_line=end_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Case 2 – a single header was found (sections empty but we have a title).
|
||||||
|
elif not sections:
|
||||||
|
sections.append(
|
||||||
|
MarkdownSection(
|
||||||
|
level=current_level,
|
||||||
|
title=current_title,
|
||||||
|
content=content,
|
||||||
|
parents=current_parents,
|
||||||
|
start_line=current_section_start,
|
||||||
|
end_line=end_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Case 3 – multiple headers were found (sections already contains earlier ones).
|
||||||
|
else:
|
||||||
|
sections.append(
|
||||||
|
MarkdownSection(
|
||||||
|
level=current_level,
|
||||||
|
title=current_title,
|
||||||
|
content=content,
|
||||||
|
parents=current_parents,
|
||||||
|
start_line=current_section_start,
|
||||||
|
end_line=end_line,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Empty file: create a single empty level‑0 section.
|
||||||
|
sections.append(
|
||||||
|
MarkdownSection(
|
||||||
|
level=0,
|
||||||
|
title="",
|
||||||
|
content="",
|
||||||
|
parents=[],
|
||||||
|
start_line=1,
|
||||||
|
end_line=1,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return ParsedDocument(
|
||||||
|
file_path=file_path,
|
||||||
|
title=title,
|
||||||
|
sections=sections,
|
||||||
|
raw_content=raw_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def find_section_at_line(
|
||||||
|
document: ParsedDocument,
|
||||||
|
line_number: int,
|
||||||
|
) -> Optional[MarkdownSection]:
|
||||||
|
"""Find which section contains a given line number.
|
||||||
|
|
||||||
|
This function searches through the document's sections to find
|
||||||
|
which section contains the specified line number.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
document: Parsed markdown document
|
||||||
|
line_number: Line number to search for (1-indexed)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MarkdownSection containing the line, or None if line number
|
||||||
|
is invalid or out of range
|
||||||
|
|
||||||
|
Example:
|
||||||
|
>>> section = find_section_at_line(doc, 42)
|
||||||
|
>>> if section:
|
||||||
|
... print(f"Line 42 is in section: {section.title}")
|
||||||
|
"""
|
||||||
|
if line_number < 1:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for section in document.sections:
|
||||||
|
if section.start_line <= line_number <= section.end_line:
|
||||||
|
return section
|
||||||
obsidian_rag/rag_chain.py (new file, 96 lines)
@@ -0,0 +1,96 @@
|
|||||||
# File: obsidian_rag/rag_chain.py

from pathlib import Path
from typing import List, Optional, Tuple

from .indexer import EMBEDDING_MODEL
from .llm_client import LLMClient
from .searcher import search_vault, SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
class RAGChain:
|
||||||
|
"""
|
||||||
|
Retrieval-Augmented Generation (RAG) chain for answering queries
|
||||||
|
using semantic search over an Obsidian vault and LLM.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
chroma_db_path (Path): Path to ChromaDB.
|
||||||
|
collection_name (str): Chroma collection name.
|
||||||
|
embedding_model (str): Embedding model name.
|
||||||
|
top_k (int): Number of chunks to send to the LLM.
|
||||||
|
min_score (float): Minimum similarity score for chunks.
|
||||||
|
system_prompt (str): System prompt to instruct the LLM.
|
||||||
|
llm_client (LLMClient): Internal LLM client instance.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DEFAULT_SYSTEM_PROMPT = (
|
||||||
|
"You are an assistant specialized in analyzing Obsidian notes.\n\n"
|
||||||
|
"INSTRUCTIONS:\n"
|
||||||
|
"- Answer based ONLY on the provided context\n"
|
||||||
|
"- Cite the sources (files) you use\n"
|
||||||
|
"- If the information is not in the context, say \"I did not find this information in your notes\"\n"
|
||||||
|
"- Be concise but thorough\n"
|
||||||
|
"- Structure your answer with sections if necessary"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
chroma_db_path: str,
|
||||||
|
api_key: str,
|
||||||
|
base_url: str,
|
||||||
|
collection_name: str = "obsidian_vault",
|
||||||
|
embedding_model: str = EMBEDDING_MODEL,
|
||||||
|
top_k: int = 5,
|
||||||
|
min_score: float = 0.0,
|
||||||
|
system_prompt: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
self.chroma_db_path = Path(chroma_db_path)
|
||||||
|
self.collection_name = collection_name
|
||||||
|
self.embedding_model = embedding_model
|
||||||
|
self.top_k = top_k
|
||||||
|
self.min_score = min_score
|
||||||
|
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
# Instantiate internal LLM client
|
||||||
|
self.llm_client = LLMClient(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
|
def answer_query(self, query: str) -> Tuple[str, List[SearchResult]]:
|
||||||
|
"""
|
||||||
|
Answer a user query using RAG: search vault, build context, call LLM.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query (str): User query.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, List[SearchResult]]:
|
||||||
|
- LLM answer (str)
|
||||||
|
- List of used SearchResult chunks
|
||||||
|
"""
|
||||||
|
# 1. Perform semantic search
|
||||||
|
chunks: List[SearchResult] = search_vault(
|
||||||
|
query=query,
|
||||||
|
chroma_db_path=str(self.chroma_db_path),
|
||||||
|
collection_name=self.collection_name,
|
||||||
|
embedding_model=self.embedding_model,
|
||||||
|
limit=self.top_k,
|
||||||
|
min_score=self.min_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Build context string with citations
|
||||||
|
context_parts: List[str] = []
|
||||||
|
for chunk in chunks:
|
||||||
|
chunk_text = chunk.text.strip()
|
||||||
|
citation = f"[{chunk.file_path}#L{chunk.line_start}-L{chunk.line_end}]"
|
||||||
|
context_parts.append(f"{chunk_text}\n{citation}")
|
||||||
|
|
||||||
|
context_str = "\n\n".join(context_parts) if context_parts else ""
|
||||||
|
|
||||||
|
# 3. Call LLM with context + question
|
||||||
|
llm_response = self.llm_client.generate(
|
||||||
|
system_prompt=self.system_prompt,
|
||||||
|
user_prompt=query,
|
||||||
|
context=context_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
answer_text = llm_response.get("answer", "")
|
||||||
|
return answer_text, chunks
|
||||||
obsidian_rag/searcher.py (new file, 131 lines)
@@ -0,0 +1,131 @@
|
|||||||
|
"""
|
||||||
|
Searcher module for Obsidian RAG Backend.
|
||||||
|
|
||||||
|
This module handles semantic search operations on the indexed ChromaDB collection.
|
||||||
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
import chromadb
|
||||||
|
from chromadb.config import Settings
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from .indexer import EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchResult:
|
||||||
|
"""
|
||||||
|
Represents a single search result with metadata and relevance score.
|
||||||
|
"""
|
||||||
|
file_path: str
|
||||||
|
section_title: str
|
||||||
|
line_start: int
|
||||||
|
line_end: int
|
||||||
|
score: float
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
def search_vault(
|
||||||
|
query: str,
|
||||||
|
chroma_db_path: str,
|
||||||
|
collection_name: str = "obsidian_vault",
|
||||||
|
embedding_model: str = EMBEDDING_MODEL,
|
||||||
|
limit: int = 5,
|
||||||
|
min_score: float = 0.0,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""
|
||||||
|
Search the indexed vault for semantically similar content.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Search query string
|
||||||
|
chroma_db_path: Path to ChromaDB data directory
|
||||||
|
collection_name: Name of the ChromaDB collection to search
|
||||||
|
embedding_model: Model used for embeddings (must match indexing model)
|
||||||
|
limit: Maximum number of results to return
|
||||||
|
min_score: Minimum similarity score threshold (0.0 to 1.0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects, sorted by relevance (highest score first)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the collection does not exist or query is empty
|
||||||
|
"""
|
||||||
|
if not query or not query.strip():
|
||||||
|
raise ValueError("Query cannot be empty")
|
||||||
|
|
||||||
|
# Initialize ChromaDB client
|
||||||
|
chroma_client = chromadb.PersistentClient(
|
||||||
|
path=chroma_db_path,
|
||||||
|
settings=Settings(anonymized_telemetry=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get collection (will raise if it doesn't exist)
|
||||||
|
try:
|
||||||
|
collection = chroma_client.get_collection(name=collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Collection '{collection_name}' not found. "
|
||||||
|
f"Please index your vault first using the index command."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Initialize embedding model (same as used during indexing)
|
||||||
|
model = SentenceTransformer(embedding_model)
|
||||||
|
|
||||||
|
# Generate query embedding
|
||||||
|
query_embedding = model.encode(query, show_progress_bar=False)
|
||||||
|
|
||||||
|
# Perform search
|
||||||
|
results = collection.query(
|
||||||
|
query_embeddings=[query_embedding.tolist()],
|
||||||
|
n_results=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse and format results
|
||||||
|
search_results = _parse_search_results(results, min_score)
|
||||||
|
|
||||||
|
return search_results
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_search_results(
|
||||||
|
raw_results: dict,
|
||||||
|
min_score: float,
|
||||||
|
) -> List[SearchResult]:
|
||||||
|
"""
|
||||||
|
Parse ChromaDB query results into SearchResult objects.
|
||||||
|
|
||||||
|
ChromaDB returns distances (lower = more similar). We convert to
|
||||||
|
similarity scores (higher = more similar) using: score = 1 - distance
|
||||||
|
|
||||||
|
Args:
|
||||||
|
raw_results: Raw results dictionary from ChromaDB query
|
||||||
|
min_score: Minimum similarity score to include
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of SearchResult objects filtered by min_score
|
||||||
|
"""
|
||||||
|
search_results = []
|
||||||
|
|
||||||
|
# ChromaDB returns results as lists of lists (one list per query)
|
||||||
|
# We only have one query, so we take the first element
|
||||||
|
documents = raw_results.get("documents", [[]])[0]
|
||||||
|
metadatas = raw_results.get("metadatas", [[]])[0]
|
||||||
|
distances = raw_results.get("distances", [[]])[0]
|
||||||
|
|
||||||
|
for doc, metadata, distance in zip(documents, metadatas, distances):
|
||||||
|
# Convert distance to similarity score (cosine distance -> cosine similarity)
|
||||||
|
score = 1.0 - distance
|
||||||
|
|
||||||
|
# Filter by minimum score
|
||||||
|
if score < min_score:
|
||||||
|
continue
|
||||||
|
|
||||||
|
search_results.append(SearchResult(
|
||||||
|
file_path=metadata["file_path"],
|
||||||
|
section_title=metadata["section_title"],
|
||||||
|
line_start=metadata["line_start"],
|
||||||
|
line_end=metadata["line_end"],
|
||||||
|
score=score,
|
||||||
|
text=doc,
|
||||||
|
))
|
||||||
|
|
||||||
|
return search_results
|
||||||
tests/__init__.py (new file, empty)
tests/test_cli.py (new file, 537 lines)
@@ -0,0 +1,537 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the CLI module.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from typer.testing import CliRunner
|
||||||
|
from obsidian_rag.cli import app, _display_index_results, _display_results_compact
|
||||||
|
from obsidian_rag.indexer import index_vault
|
||||||
|
from obsidian_rag.searcher import SearchResult
|
||||||
|
|
||||||
|
runner = CliRunner()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_vault(tmp_path):
|
||||||
|
"""
|
||||||
|
Create a temporary vault with sample markdown files.
|
||||||
|
"""
|
||||||
|
vault_path = tmp_path / "test_vault"
|
||||||
|
vault_path.mkdir()
|
||||||
|
|
||||||
|
# Create sample files
|
||||||
|
file1 = vault_path / "python.md"
|
||||||
|
file1.write_text("""# Python Programming
|
||||||
|
|
||||||
|
Python is a high-level programming language.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
Python has dynamic typing and automatic memory management.
|
||||||
|
""")
|
||||||
|
|
||||||
|
file2 = vault_path / "javascript.md"
|
||||||
|
file2.write_text("""# JavaScript
|
||||||
|
|
||||||
|
JavaScript is a scripting language for web development.
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
JavaScript runs in web browsers and Node.js environments.
|
||||||
|
""")
|
||||||
|
|
||||||
|
file3 = vault_path / "cooking.md"
|
||||||
|
file3.write_text("""# Cooking Tips
|
||||||
|
|
||||||
|
Learn how to cook delicious meals.
|
||||||
|
|
||||||
|
## Basics
|
||||||
|
|
||||||
|
Start with simple recipes and basic techniques.
|
||||||
|
""")
|
||||||
|
|
||||||
|
return vault_path
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for 'index' command - Passing tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_index_vault_successfully(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that we can index a vault successfully.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Found 3 markdown files to index" in result.stdout
|
||||||
|
assert "Indexing completed" in result.stdout
|
||||||
|
assert "Files processed:" in result.stdout
|
||||||
|
assert "Chunks created:" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_index_with_custom_chroma_path(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that we can specify a custom ChromaDB path.
|
||||||
|
"""
|
||||||
|
custom_chroma = tmp_path / "my_custom_db"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(custom_chroma),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert custom_chroma.exists()
|
||||||
|
assert (custom_chroma / "chroma.sqlite3").exists()
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_index_with_custom_collection_name(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that we can use a custom collection name.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
collection_name = "my_custom_collection"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--collection", collection_name,
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert f"Collection: {collection_name}" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_see_errors_in_index_results(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that errors during indexing are displayed.
|
||||||
|
"""
|
||||||
|
vault_path = tmp_path / "vault_with_errors"
|
||||||
|
vault_path.mkdir()
|
||||||
|
|
||||||
|
# Create a valid file
|
||||||
|
valid_file = vault_path / "valid.md"
|
||||||
|
valid_file.write_text("# Valid File\n\nThis is valid content.")
|
||||||
|
|
||||||
|
# Create an invalid file (will cause parsing error)
|
||||||
|
invalid_file = vault_path / "invalid.md"
|
||||||
|
invalid_file.write_bytes(b"\xff\xfe\x00\x00") # Invalid UTF-8
|
||||||
|
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(vault_path),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Should still complete (exit code 0) but show errors
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Indexing completed" in result.stdout
|
||||||
|
# Note: Error display might vary, just check it completed
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for 'index' command - Error tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_index_nonexistent_vault(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that indexing a nonexistent vault fails with clear error.
|
||||||
|
"""
|
||||||
|
nonexistent_path = tmp_path / "does_not_exist"
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(nonexistent_path),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "does not exist" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_index_file_instead_of_directory(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that indexing a file (not directory) fails.
|
||||||
|
"""
|
||||||
|
file_path = tmp_path / "somefile.txt"
|
||||||
|
file_path.write_text("I am a file")
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(file_path),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "not a directory" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_handle_empty_vault_gracefully(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that an empty vault (no .md files) is handled gracefully.
|
||||||
|
"""
|
||||||
|
empty_vault = tmp_path / "empty_vault"
|
||||||
|
empty_vault.mkdir()
|
||||||
|
|
||||||
|
# Create a non-markdown file
|
||||||
|
(empty_vault / "readme.txt").write_text("Not a markdown file")
|
||||||
|
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(empty_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "No markdown files found" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for 'search' command - Passing tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_indexed_vault(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that we can search an indexed vault.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# First, index the vault
|
||||||
|
index_result = runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
assert index_result.exit_code == 0
|
||||||
|
|
||||||
|
# Then search
|
||||||
|
search_result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"Python programming",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert search_result.exit_code == 0
|
||||||
|
assert "Found" in search_result.stdout
|
||||||
|
assert "result(s) for:" in search_result.stdout
|
||||||
|
assert "python.md" in search_result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_with_limit_option(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that the --limit option works.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# Index
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search with limit
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"programming",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--limit", "2",
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Count result numbers (1., 2., etc.)
|
||||||
|
result_count = result.stdout.count("[bold cyan]")
|
||||||
|
assert result_count <= 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_with_min_score_option(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that the --min-score option works.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# Index
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search with high min-score
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"Python",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--min-score", "0.5",
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Should have results (Python file should match well)
|
||||||
|
assert "Found" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_with_custom_collection(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that we can search in a custom collection.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
collection_name = "test_collection"
|
||||||
|
|
||||||
|
# Index with custom collection
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--collection", collection_name,
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search in same collection
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"Python",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--collection", collection_name,
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "Found" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_handle_no_results_gracefully(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that no results scenario is handled gracefully.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# Index
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search for something unlikely with high threshold
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"quantum physics relativity",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--min-score", "0.95",
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
assert "No results found" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_use_compact_format(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that compact format displays correctly.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# Index
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search with explicit compact format
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"Python",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--format", "compact",
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 0
|
||||||
|
# Check for compact format elements
|
||||||
|
assert "Section:" in result.stdout
|
||||||
|
assert "Lines:" in result.stdout
|
||||||
|
assert "score:" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for 'search' command - Error tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_search_without_index(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that searching without indexing fails with clear message.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "nonexistent_chroma"
|
||||||
|
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"test query",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "not found" in result.stdout
|
||||||
|
assert "index" in result.stdout.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_search_nonexistent_collection(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that searching in a nonexistent collection fails.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# Index with default collection
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search in different collection
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"Python",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--collection", "nonexistent_collection",
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "not found" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_use_invalid_format(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that an invalid format is rejected.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
# Index
|
||||||
|
runner.invoke(app, [
|
||||||
|
"index",
|
||||||
|
str(temp_vault),
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Search with invalid format
|
||||||
|
result = runner.invoke(app, [
|
||||||
|
"search",
|
||||||
|
"Python",
|
||||||
|
"--chroma-path", str(chroma_path),
|
||||||
|
"--format", "invalid_format",
|
||||||
|
])
|
||||||
|
|
||||||
|
assert result.exit_code == 1
|
||||||
|
assert "Invalid format" in result.stdout
|
||||||
|
assert "compact" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for helper functions
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_display_index_results(capsys):
|
||||||
|
"""
|
||||||
|
Test that index results are displayed correctly.
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"files_processed": 10,
|
||||||
|
"chunks_created": 50,
|
||||||
|
"collection_name": "test_collection",
|
||||||
|
"errors": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
_display_index_results(stats)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Indexing completed" in captured.out
|
||||||
|
assert "10" in captured.out
|
||||||
|
assert "50" in captured.out
|
||||||
|
assert "test_collection" in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_display_index_results_with_errors(capsys):
|
||||||
|
"""
|
||||||
|
Test that index results with errors are displayed correctly.
|
||||||
|
"""
|
||||||
|
stats = {
|
||||||
|
"files_processed": 8,
|
||||||
|
"chunks_created": 40,
|
||||||
|
"collection_name": "test_collection",
|
||||||
|
"errors": [
|
||||||
|
{"file": "broken.md", "error": "Invalid encoding"},
|
||||||
|
{"file": "corrupt.md", "error": "Parse error"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
_display_index_results(stats)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "Indexing completed" in captured.out
|
||||||
|
assert "2 file(s) skipped" in captured.out
|
||||||
|
assert "broken.md" in captured.out
|
||||||
|
assert "Invalid encoding" in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_display_results_compact(capsys):
|
||||||
|
"""
|
||||||
|
Test that compact results display correctly.
|
||||||
|
"""
|
||||||
|
results = [
|
||||||
|
SearchResult(
|
||||||
|
file_path="notes/python.md",
|
||||||
|
section_title="Introduction",
|
||||||
|
line_start=1,
|
||||||
|
line_end=5,
|
||||||
|
score=0.87,
|
||||||
|
text="Python is a high-level programming language.",
|
||||||
|
),
|
||||||
|
SearchResult(
|
||||||
|
file_path="notes/javascript.md",
|
||||||
|
section_title="Overview",
|
||||||
|
line_start=10,
|
||||||
|
line_end=15,
|
||||||
|
score=0.65,
|
||||||
|
text="JavaScript is used for web development.",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
_display_results_compact(results)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "python.md" in captured.out
|
||||||
|
assert "javascript.md" in captured.out
|
||||||
|
assert "0.87" in captured.out
|
||||||
|
assert "0.65" in captured.out
|
||||||
|
assert "Introduction" in captured.out
|
||||||
|
assert "Overview" in captured.out
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_display_results_compact_with_long_text(capsys):
|
||||||
|
"""
|
||||||
|
Test that long text is truncated in compact display.
|
||||||
|
"""
|
||||||
|
long_text = "A" * 300 # Text longer than 200 characters
|
||||||
|
|
||||||
|
results = [
|
||||||
|
SearchResult(
|
||||||
|
file_path="notes/long.md",
|
||||||
|
section_title="Long Section",
|
||||||
|
line_start=1,
|
||||||
|
line_end=10,
|
||||||
|
score=0.75,
|
||||||
|
text=long_text,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
_display_results_compact(results)
|
||||||
|
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "..." in captured.out # Should be truncated
|
||||||
|
assert len([line for line in captured.out.split('\n') if 'A' * 200 in line]) == 0 # Full text not shown
|
||||||
381
tests/test_indexer.py
Normal file
@@ -0,0 +1,381 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the indexer module.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import chromadb
|
||||||
|
import pytest
|
||||||
|
from chromadb.config import Settings
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
from obsidian_rag.indexer import (
    index_vault,
    _chunk_section,
    _create_chunks_from_document,
    _get_or_create_collection,
    EMBEDDING_MODEL,
)
|
||||||
|
from obsidian_rag.markdown_parser import ParsedDocument, MarkdownSection
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tokenizer():
|
||||||
|
"""Provide sentence-transformers tokenizer."""
|
||||||
|
model = SentenceTransformer(EMBEDDING_MODEL)
|
||||||
|
return model.tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def embedding_model():
|
||||||
|
"""Provide sentence-transformers model."""
|
||||||
|
return SentenceTransformer(EMBEDDING_MODEL)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def chroma_client(tmp_path):
|
||||||
|
"""Provide ChromaDB client with temporary storage."""
|
||||||
|
client = chromadb.PersistentClient(
|
||||||
|
path=str(tmp_path / "chroma_test"),
|
||||||
|
settings=Settings(anonymized_telemetry=False)
|
||||||
|
)
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_vault(tmp_path):
|
||||||
|
"""Create a temporary vault with test markdown files."""
|
||||||
|
vault_path = tmp_path / "test_vault"
|
||||||
|
vault_path.mkdir()
|
||||||
|
return vault_path
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for _chunk_section()
|
||||||
|
|
||||||
|
def test_i_can_chunk_short_section_into_single_chunk(tokenizer):
|
||||||
|
"""Test that a short section is not split."""
|
||||||
|
# Create text with ~100 tokens
|
||||||
|
short_text = " ".join(["word"] * 100)
|
||||||
|
|
||||||
|
chunks = _chunk_section(
|
||||||
|
section_text=short_text,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_chunk_tokens=200,
|
||||||
|
overlap_tokens=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0] == short_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_chunk_long_section_with_overlap(tokenizer):
|
||||||
|
"""Test splitting long section with overlap."""
|
||||||
|
# Create text with ~500 tokens
|
||||||
|
long_text = " ".join([f"word{i}" for i in range(500)])
|
||||||
|
|
||||||
|
chunks = _chunk_section(
|
||||||
|
section_text=long_text,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_chunk_tokens=200,
|
||||||
|
overlap_tokens=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should create multiple chunks
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
|
||||||
|
# Verify no chunk exceeds max tokens
|
||||||
|
for chunk in chunks:
|
||||||
|
tokens = tokenizer.encode(chunk, add_special_tokens=False)
|
||||||
|
assert len(tokens) <= 200
|
||||||
|
|
||||||
|
# Verify overlap exists between consecutive chunks
|
||||||
|
for i in range(len(chunks) - 1):
|
||||||
|
# Check that some words from end of chunk[i] appear in start of chunk[i+1]
|
||||||
|
words_chunk1 = chunks[i].split()[-10:] # Last 10 words
|
||||||
|
words_chunk2 = chunks[i + 1].split()[:10] # First 10 words
|
||||||
|
|
||||||
|
# At least some overlap should exist
|
||||||
|
overlap_found = any(word in words_chunk2 for word in words_chunk1)
|
||||||
|
assert overlap_found
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_chunk_empty_section(tokenizer):
|
||||||
|
"""Test chunking an empty section."""
|
||||||
|
empty_text = ""
|
||||||
|
|
||||||
|
chunks = _chunk_section(
|
||||||
|
section_text=empty_text,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_chunk_tokens=200,
|
||||||
|
overlap_tokens=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for _create_chunks_from_document()
|
||||||
|
|
||||||
|
def test_i_can_create_chunks_from_document_with_short_sections(tmp_path, tokenizer):
|
||||||
|
"""Test creating chunks from document with only short sections."""
|
||||||
|
vault_path = tmp_path / "vault"
|
||||||
|
vault_path.mkdir()
|
||||||
|
|
||||||
|
parsed_doc = ParsedDocument(
|
||||||
|
file_path=vault_path / "test.md",
|
||||||
|
title="test.md",
|
||||||
|
sections=[
|
||||||
|
MarkdownSection(1, "Section 1", "This is a short section with few words.", [], 1, 2),
|
||||||
|
MarkdownSection(2, "Section 2", "Another short section here.", ["Section 1"], 3, 4),
|
||||||
|
MarkdownSection(3, "Section 3", "Third short section.", ["Section 1", "Section 3"], 5, 6),
|
||||||
|
],
|
||||||
|
raw_content="" # not used in this test
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = _create_chunks_from_document(
|
||||||
|
parsed_doc=parsed_doc,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_chunk_tokens=200,
|
||||||
|
overlap_tokens=30,
|
||||||
|
vault_path=vault_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should create 3 chunks (one per section)
|
||||||
|
assert len(chunks) == 3
|
||||||
|
|
||||||
|
# Verify metadata
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
metadata = chunk.metadata
|
||||||
|
assert metadata.file_path == "test.md"
|
||||||
|
assert metadata.section_title == f"Section {i + 1}"
|
||||||
|
assert isinstance(metadata.line_start, int)
|
||||||
|
assert isinstance(metadata.line_end, int)
|
||||||
|
|
||||||
|
# Verify ID format
|
||||||
|
assert "test.md" in chunk.id
|
||||||
|
assert f"Section {i + 1}" in chunk.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_create_chunks_from_document_with_long_section(tmp_path, tokenizer):
|
||||||
|
"""Test creating chunks from document with a long section that needs splitting."""
|
||||||
|
vault_path = tmp_path / "vault"
|
||||||
|
vault_path.mkdir()
|
||||||
|
|
||||||
|
# Create long content (~500 tokens)
|
||||||
|
long_content = " ".join([f"word{i}" for i in range(500)])
|
||||||
|
|
||||||
|
parsed_doc = ParsedDocument(
|
||||||
|
file_path=vault_path / "test.md",
|
||||||
|
title="test.md",
|
||||||
|
sections=[
|
||||||
|
MarkdownSection(1, "Long Section", long_content, [], 1, 1)
|
||||||
|
],
|
||||||
|
raw_content=long_content,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = _create_chunks_from_document(
|
||||||
|
parsed_doc=parsed_doc,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_chunk_tokens=200,
|
||||||
|
overlap_tokens=30,
|
||||||
|
vault_path=vault_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should create multiple chunks
|
||||||
|
assert len(chunks) >= 2
|
||||||
|
|
||||||
|
# All chunks should have same section_title
|
||||||
|
for chunk in chunks:
|
||||||
|
assert chunk.metadata.section_title == "Long Section"
|
||||||
|
assert chunk.metadata.line_start == 1
|
||||||
|
assert chunk.metadata.line_end == 1
|
||||||
|
|
||||||
|
# IDs should include chunk numbers
|
||||||
|
assert "::chunk0" in chunks[0].id
|
||||||
|
assert "::chunk1" in chunks[1].id
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_create_chunks_with_correct_relative_paths(tmp_path, tokenizer):
|
||||||
|
"""Test that relative paths are correctly computed."""
|
||||||
|
vault_path = tmp_path / "vault"
|
||||||
|
vault_path.mkdir()
|
||||||
|
|
||||||
|
# Create subdirectory
|
||||||
|
subdir = vault_path / "subfolder"
|
||||||
|
subdir.mkdir()
|
||||||
|
|
||||||
|
parsed_doc = ParsedDocument(
|
||||||
|
file_path=subdir / "nested.md",
|
||||||
|
title=f"{subdir} nested.md",
|
||||||
|
sections=[
|
||||||
|
MarkdownSection(1, "Section", "Some content here.", [], 1, 2),
|
||||||
|
],
|
||||||
|
raw_content="",
|
||||||
|
)
|
||||||
|
|
||||||
|
chunks = _create_chunks_from_document(
|
||||||
|
parsed_doc=parsed_doc,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
max_chunk_tokens=200,
|
||||||
|
overlap_tokens=30,
|
||||||
|
vault_path=vault_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(chunks) == 1
|
||||||
|
assert chunks[0].metadata.file_path == "subfolder/nested.md"
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for _get_or_create_collection()
|
||||||
|
|
||||||
|
def test_i_can_create_new_collection(chroma_client):
|
||||||
|
"""Test creating a new collection that doesn't exist."""
|
||||||
|
collection_name = "test_collection"
|
||||||
|
|
||||||
|
collection = _get_or_create_collection(chroma_client, collection_name)
|
||||||
|
|
||||||
|
assert collection.name == collection_name
|
||||||
|
assert collection.count() == 0 # Should be empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_reset_existing_collection(chroma_client):
|
||||||
|
"""Test that an existing collection is deleted and recreated."""
|
||||||
|
collection_name = "test_collection"
|
||||||
|
|
||||||
|
# Create collection and add data
|
||||||
|
first_collection = chroma_client.create_collection(collection_name)
|
||||||
|
first_collection.add(
|
||||||
|
documents=["test document"],
|
||||||
|
ids=["test_id"],
|
||||||
|
)
|
||||||
|
assert first_collection.count() == 1
|
||||||
|
|
||||||
|
# Reset collection
|
||||||
|
new_collection = _get_or_create_collection(chroma_client, collection_name)
|
||||||
|
|
||||||
|
assert new_collection.name == collection_name
|
||||||
|
assert new_collection.count() == 0 # Should be empty after reset
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for index_vault()
|
||||||
|
|
||||||
|
def test_i_can_index_single_markdown_file(test_vault, tmp_path, embedding_model):
|
||||||
|
"""Test indexing a single markdown file."""
|
||||||
|
# Create test markdown file
|
||||||
|
test_file = test_vault / "test.md"
|
||||||
|
test_file.write_text(
|
||||||
|
"# Title\n\nThis is a test document with some content.\n\n## Section\n\nMore content here."
|
||||||
|
)
|
||||||
|
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
stats = index_vault(
|
||||||
|
vault_path=str(test_vault),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
collection_name="test_collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert stats["files_processed"] == 1
|
||||||
|
assert stats["chunks_created"] > 0
|
||||||
|
assert stats["errors"] == []
|
||||||
|
assert stats["collection_name"] == "test_collection"
|
||||||
|
|
||||||
|
# Verify collection contains data
|
||||||
|
client = chromadb.PersistentClient(
|
||||||
|
path=str(chroma_path),
|
||||||
|
settings=Settings(anonymized_telemetry=False)
|
||||||
|
)
|
||||||
|
collection = client.get_collection("test_collection")
|
||||||
|
assert collection.count() == stats["chunks_created"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_index_multiple_markdown_files(test_vault, tmp_path):
|
||||||
|
"""Test indexing multiple markdown files."""
|
||||||
|
# Create multiple test files
|
||||||
|
for i in range(3):
|
||||||
|
test_file = test_vault / f"test{i}.md"
|
||||||
|
test_file.write_text(f"# Document {i}\n\nContent for document {i}.")
|
||||||
|
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
stats = index_vault(
|
||||||
|
vault_path=str(test_vault),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert stats["files_processed"] == 3
|
||||||
|
assert stats["chunks_created"] >= 3 # At least one chunk per file
|
||||||
|
assert stats["errors"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_continue_indexing_after_file_error(test_vault, tmp_path, monkeypatch):
|
||||||
|
"""Test that indexing continues after encountering an error."""
|
||||||
|
# Create valid files
|
||||||
|
(test_vault / "valid1.md").write_text("# Valid 1\n\nContent here.")
|
||||||
|
(test_vault / "valid2.md").write_text("# Valid 2\n\nMore content.")
|
||||||
|
(test_vault / "problematic.md").write_text("# Problem\n\nThis will fail.")
|
||||||
|
|
||||||
|
# Mock parse_markdown_file to fail for problematic.md
|
||||||
|
from obsidian_rag import markdown_parser
|
||||||
|
original_parse = markdown_parser.parse_markdown_file
|
||||||
|
|
||||||
|
def mock_parse(file_path):
|
||||||
|
if "problematic.md" in str(file_path):
|
||||||
|
raise ValueError("Simulated parsing error")
|
||||||
|
return original_parse(file_path)
|
||||||
|
|
||||||
|
monkeypatch.setattr("indexer.parse_markdown_file", mock_parse)
|
||||||
|
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
stats = index_vault(
|
||||||
|
vault_path=str(test_vault),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should process 2 valid files
|
||||||
|
assert stats["files_processed"] == 2
|
||||||
|
assert len(stats["errors"]) == 1
|
||||||
|
assert "problematic.md" in stats["errors"][0]["file"]
|
||||||
|
assert "Simulated parsing error" in stats["errors"][0]["error"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_index_nonexistent_vault(tmp_path):
|
||||||
|
"""Test that indexing a nonexistent vault raises an error."""
|
||||||
|
nonexistent_path = tmp_path / "nonexistent_vault"
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Vault path does not exist"):
|
||||||
|
index_vault(
|
||||||
|
vault_path=str(nonexistent_path),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_verify_embeddings_are_generated(test_vault, tmp_path):
|
||||||
|
"""Test that embeddings are properly generated and stored."""
|
||||||
|
# Create test file
|
||||||
|
test_file = test_vault / "test.md"
|
||||||
|
test_file.write_text("# Test\n\nThis is test content for embedding generation.")
|
||||||
|
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
|
||||||
|
stats = index_vault(
|
||||||
|
vault_path=str(test_vault),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify embeddings in collection
|
||||||
|
client = chromadb.PersistentClient(
|
||||||
|
path=str(chroma_path),
|
||||||
|
settings=Settings(anonymized_telemetry=False)
|
||||||
|
)
|
||||||
|
collection = client.get_collection("obsidian_vault")
|
||||||
|
|
||||||
|
# Get all items
|
||||||
|
results = collection.get(include=["embeddings"])
|
||||||
|
|
||||||
|
assert len(results["ids"]) == stats["chunks_created"]
|
||||||
|
assert results["embeddings"] is not None
|
||||||
|
|
||||||
|
# Verify embeddings are non-zero vectors of correct dimension
|
||||||
|
for embedding in results["embeddings"]:
|
||||||
|
assert len(embedding) == 384 # all-MiniLM-L6-v2 dimension
|
||||||
|
assert any(val != 0 for val in embedding) # Not all zeros
|
||||||
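For context, the tests above pin down the behaviour expected of _chunk_section: windows of at most max_chunk_tokens with overlap_tokens of shared context between consecutive windows. A minimal sliding-window sketch of that idea follows; it operates on a pre-tokenized list, is illustrative only, and is not the committed implementation.

# Illustrative sketch (not part of this commit): sliding-window chunking over token ids.
from typing import List


def sliding_window_chunks(tokens: List[int], max_chunk_tokens: int, overlap_tokens: int) -> List[List[int]]:
    """Split tokens into windows of at most max_chunk_tokens; consecutive windows share overlap_tokens."""
    if not tokens:
        return []
    if len(tokens) <= max_chunk_tokens:
        return [tokens]
    chunks = []
    step = max_chunk_tokens - overlap_tokens
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start:start + max_chunk_tokens])
        if start + max_chunk_tokens >= len(tokens):
            break
    return chunks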
238
tests/test_markdown_parser.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""Unit tests for markdown_parser module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from obsidian_rag.markdown_parser import (
|
||||||
|
parse_markdown_file,
|
||||||
|
find_section_at_line,
|
||||||
|
MarkdownSection,
|
||||||
|
ParsedDocument
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_markdown_file(tmp_path):
|
||||||
|
"""Fixture to create temporary markdown files for testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tmp_path: pytest temporary directory fixture
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Function that creates a markdown file with given content
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _create_file(content: str, filename: str = "test.md") -> Path:
|
||||||
|
file_path = tmp_path / filename
|
||||||
|
file_path.write_text(content, encoding="utf-8")
|
||||||
|
return file_path
|
||||||
|
|
||||||
|
return _create_file
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for parse_markdown_file()
|
||||||
|
|
||||||
|
def test_i_can_parse_file_with_single_section(tmp_markdown_file):
|
||||||
|
"""Test parsing a file with a single header section."""
|
||||||
|
content = """# Main Title
|
||||||
|
This is the content of the section.
|
||||||
|
It has multiple lines."""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
assert len(doc.sections) == 1
|
||||||
|
assert doc.sections[0].level == 1
|
||||||
|
assert doc.sections[0].title == "Main Title"
|
||||||
|
assert "This is the content" in doc.sections[0].content
|
||||||
|
assert doc.sections[0].start_line == 1
|
||||||
|
assert doc.sections[0].end_line == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_parse_file_with_multiple_sections(tmp_markdown_file):
|
||||||
|
"""Test parsing a file with multiple sections at the same level."""
|
||||||
|
content = """# Section One
|
||||||
|
Content of section one.
|
||||||
|
|
||||||
|
# Section Two
|
||||||
|
Content of section two.
|
||||||
|
|
||||||
|
# Section Three
|
||||||
|
Content of section three."""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
assert len(doc.sections) == 3
|
||||||
|
assert doc.sections[0].title == "Section One"
|
||||||
|
assert doc.sections[1].title == "Section Two"
|
||||||
|
assert doc.sections[2].title == "Section Three"
|
||||||
|
assert all(section.level == 1 for section in doc.sections)
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_parse_file_with_nested_sections(tmp_markdown_file):
|
||||||
|
"""Test parsing a file with nested headers (different levels)."""
|
||||||
|
content = """# Main Title
|
||||||
|
Introduction text.
|
||||||
|
|
||||||
|
## Subsection A
|
||||||
|
Content A.
|
||||||
|
|
||||||
|
## Subsection B
|
||||||
|
Content B.
|
||||||
|
|
||||||
|
### Sub-subsection
|
||||||
|
Nested content."""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
assert len(doc.sections) == 4
|
||||||
|
assert doc.sections[0].level == 1
|
||||||
|
assert doc.sections[0].title == "Main Title"
|
||||||
|
assert doc.sections[1].level == 2
|
||||||
|
assert doc.sections[1].title == "Subsection A"
|
||||||
|
assert doc.sections[2].level == 2
|
||||||
|
assert doc.sections[2].title == "Subsection B"
|
||||||
|
assert doc.sections[3].level == 3
|
||||||
|
assert doc.sections[3].title == "Sub-subsection"
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_parse_file_without_headers(tmp_markdown_file):
|
||||||
|
"""Test parsing a file with no headers (plain text)."""
|
||||||
|
content = """This is a plain text file.
|
||||||
|
It has no headers at all.
|
||||||
|
Just regular content."""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
assert len(doc.sections) == 1
|
||||||
|
assert doc.sections[0].level == 0
|
||||||
|
assert doc.sections[0].title == ""
|
||||||
|
assert doc.sections[0].content == content
|
||||||
|
assert doc.sections[0].start_line == 1
|
||||||
|
assert doc.sections[0].end_line == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_parse_empty_file(tmp_markdown_file):
|
||||||
|
"""Test parsing an empty file."""
|
||||||
|
content = ""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
assert len(doc.sections) == 1
|
||||||
|
assert doc.sections[0].level == 0
|
||||||
|
assert doc.sections[0].title == ""
|
||||||
|
assert doc.sections[0].content == ""
|
||||||
|
assert doc.sections[0].start_line == 1
|
||||||
|
assert doc.sections[0].end_line == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_track_correct_line_numbers(tmp_markdown_file):
|
||||||
|
"""Test that line numbers are correctly tracked for each section."""
|
||||||
|
content = """# First Section
|
||||||
|
Line 2
|
||||||
|
Line 3
|
||||||
|
|
||||||
|
# Second Section
|
||||||
|
Line 6
|
||||||
|
Line 7
|
||||||
|
Line 8"""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
assert doc.sections[0].start_line == 1
|
||||||
|
assert doc.sections[0].end_line == 4
|
||||||
|
assert doc.sections[1].start_line == 5
|
||||||
|
assert doc.sections[1].end_line == 8
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_parse_nonexistent_file():
|
||||||
|
"""Test that parsing a non-existent file raises FileNotFoundError."""
|
||||||
|
fake_path = Path("/nonexistent/path/to/file.md")
|
||||||
|
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
parse_markdown_file(fake_path)
|
||||||
|
|
||||||
|
|
||||||
|
# Tests for find_section_at_line()
|
||||||
|
|
||||||
|
def test_i_can_find_section_at_specific_line(tmp_markdown_file):
|
||||||
|
"""Test finding a section at a line in the middle of content."""
|
||||||
|
content = """# Section One
|
||||||
|
Line 2
|
||||||
|
Line 3
|
||||||
|
|
||||||
|
# Section Two
|
||||||
|
Line 6
|
||||||
|
Line 7"""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
section = find_section_at_line(doc, 3)
|
||||||
|
|
||||||
|
assert section is not None
|
||||||
|
assert section.title == "Section One"
|
||||||
|
|
||||||
|
section = find_section_at_line(doc, 6)
|
||||||
|
|
||||||
|
assert section is not None
|
||||||
|
assert section.title == "Section Two"
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_find_section_at_first_line(tmp_markdown_file):
|
||||||
|
"""Test finding a section at the header line itself."""
|
||||||
|
content = """# Main Title
|
||||||
|
Content here."""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
section = find_section_at_line(doc, 1)
|
||||||
|
|
||||||
|
assert section is not None
|
||||||
|
assert section.title == "Main Title"
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_find_section_at_last_line(tmp_markdown_file):
|
||||||
|
"""Test finding a section at its last line."""
|
||||||
|
content = """# Section One
|
||||||
|
Line 2
|
||||||
|
Line 3
|
||||||
|
|
||||||
|
# Section Two
|
||||||
|
Line 6"""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
section = find_section_at_line(doc, 3)
|
||||||
|
|
||||||
|
assert section is not None
|
||||||
|
assert section.title == "Section One"
|
||||||
|
|
||||||
|
section = find_section_at_line(doc, 6)
|
||||||
|
|
||||||
|
assert section is not None
|
||||||
|
assert section.title == "Section Two"
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_find_section_for_invalid_line_number(tmp_markdown_file):
|
||||||
|
"""Test that invalid line numbers return None."""
|
||||||
|
content = """# Title
|
||||||
|
Content"""
|
||||||
|
|
||||||
|
file_path = tmp_markdown_file(content)
|
||||||
|
doc = parse_markdown_file(file_path)
|
||||||
|
|
||||||
|
# Negative line number
|
||||||
|
assert find_section_at_line(doc, -1) is None
|
||||||
|
|
||||||
|
# Zero line number
|
||||||
|
assert find_section_at_line(doc, 0) is None
|
||||||
|
|
||||||
|
# Line number beyond file length
|
||||||
|
assert find_section_at_line(doc, 1000) is None
|
||||||
337
tests/test_searcher.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the searcher module.
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from obsidian_rag.indexer import index_vault
|
||||||
|
from obsidian_rag.searcher import search_vault, _parse_search_results, SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def temp_vault(tmp_path):
|
||||||
|
"""
|
||||||
|
Create a temporary vault with sample markdown files.
|
||||||
|
"""
|
||||||
|
vault_path = tmp_path / "test_vault"
|
||||||
|
vault_path.mkdir()
|
||||||
|
|
||||||
|
# Create sample files
|
||||||
|
file1 = vault_path / "python_basics.md"
|
||||||
|
file1.write_text("""# Python Programming
|
||||||
|
|
||||||
|
Python is a high-level programming language known for its simplicity and readability.
|
||||||
|
|
||||||
|
## Variables and Data Types
|
||||||
|
|
||||||
|
In Python, you can create variables without declaring their type explicitly.
|
||||||
|
Numbers, strings, and booleans are the basic data types.
|
||||||
|
|
||||||
|
## Functions
|
||||||
|
|
||||||
|
Functions in Python are defined using the def keyword.
|
||||||
|
They help organize code into reusable blocks.
|
||||||
|
""")
|
||||||
|
|
||||||
|
file2 = vault_path / "machine_learning.md"
|
||||||
|
file2.write_text("""# Machine Learning
|
||||||
|
|
||||||
|
Machine learning is a subset of artificial intelligence.
|
||||||
|
|
||||||
|
## Supervised Learning
|
||||||
|
|
||||||
|
Supervised learning uses labeled data to train models.
|
||||||
|
Common algorithms include linear regression and decision trees.
|
||||||
|
|
||||||
|
## Deep Learning
|
||||||
|
|
||||||
|
Deep learning uses neural networks with multiple layers.
|
||||||
|
It's particularly effective for image and speech recognition.
|
||||||
|
""")
|
||||||
|
|
||||||
|
file3 = vault_path / "cooking.md"
|
||||||
|
file3.write_text("""# Italian Cuisine
|
||||||
|
|
||||||
|
Italian cooking emphasizes fresh ingredients and simple preparation.
|
||||||
|
|
||||||
|
## Pasta Dishes
|
||||||
|
|
||||||
|
Pasta is a staple of Italian cuisine.
|
||||||
|
There are hundreds of pasta shapes and sauce combinations.
|
||||||
|
|
||||||
|
## Pizza Making
|
||||||
|
|
||||||
|
Traditional Italian pizza uses a thin crust and fresh mozzarella.
|
||||||
|
""")
|
||||||
|
|
||||||
|
return vault_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def indexed_vault(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Create and index a temporary vault.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_db"
|
||||||
|
chroma_path.mkdir()
|
||||||
|
|
||||||
|
# Index the vault
|
||||||
|
stats = index_vault(
|
||||||
|
vault_path=str(temp_vault),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
collection_name="test_collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"vault_path": temp_vault,
|
||||||
|
"chroma_path": chroma_path,
|
||||||
|
"collection_name": "test_collection",
|
||||||
|
"stats": stats,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Passing tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_vault_with_valid_query(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that a basic search returns valid results.
|
||||||
|
"""
|
||||||
|
results = search_vault(
|
||||||
|
query="Python programming language",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return results
|
||||||
|
assert len(results) > 0
|
||||||
|
|
||||||
|
# All results should be SearchResult instances
|
||||||
|
for result in results:
|
||||||
|
assert isinstance(result, SearchResult)
|
||||||
|
|
||||||
|
# Check that all fields are present
|
||||||
|
assert isinstance(result.file_path, str)
|
||||||
|
assert isinstance(result.section_title, str)
|
||||||
|
assert isinstance(result.line_start, int)
|
||||||
|
assert isinstance(result.line_end, int)
|
||||||
|
assert isinstance(result.score, float)
|
||||||
|
assert isinstance(result.text, str)
|
||||||
|
|
||||||
|
# Scores should be between 0 and 1
|
||||||
|
assert 0.0 <= result.score <= 1.0
|
||||||
|
|
||||||
|
# Results should be sorted by score (descending)
|
||||||
|
scores = [r.score for r in results]
|
||||||
|
assert scores == sorted(scores, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_vault_with_limit_parameter(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that the limit parameter is respected.
|
||||||
|
"""
|
||||||
|
limit = 3
|
||||||
|
results = search_vault(
|
||||||
|
query="learning",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return at most 'limit' results
|
||||||
|
assert len(results) <= limit
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_vault_with_min_score_filter(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that only results above min_score are returned.
|
||||||
|
"""
|
||||||
|
min_score = 0.5
|
||||||
|
results = search_vault(
|
||||||
|
query="Python",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
min_score=min_score,
|
||||||
|
)
|
||||||
|
|
||||||
|
# All results should have score >= min_score
|
||||||
|
for result in results:
|
||||||
|
assert result.score >= min_score
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_get_correct_metadata_in_results(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that metadata in results is correct.
|
||||||
|
"""
|
||||||
|
results = search_vault(
|
||||||
|
query="Python programming",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
limit=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) > 0
|
||||||
|
top_result = results[0]
|
||||||
|
|
||||||
|
# Should find python_basics.md as most relevant
|
||||||
|
assert "python_basics.md" in top_result.file_path
|
||||||
|
|
||||||
|
# Should have a section title
|
||||||
|
assert len(top_result.section_title) > 0
|
||||||
|
|
||||||
|
# Line numbers should be positive
|
||||||
|
assert top_result.line_start > 0
|
||||||
|
assert top_result.line_end >= top_result.line_start
|
||||||
|
|
||||||
|
# Text should not be empty
|
||||||
|
assert len(top_result.text) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_search_with_different_collection_name(temp_vault, tmp_path):
|
||||||
|
"""
|
||||||
|
Test that we can search in a collection with a custom name.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "chroma_custom"
|
||||||
|
chroma_path.mkdir()
|
||||||
|
custom_collection = "my_custom_collection"
|
||||||
|
|
||||||
|
# Index with custom collection name
|
||||||
|
index_vault(
|
||||||
|
vault_path=str(temp_vault),
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
collection_name=custom_collection,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Search with the same custom collection name
|
||||||
|
results = search_vault(
|
||||||
|
query="Python",
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
collection_name=custom_collection,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_get_empty_results_when_no_match(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that a search with no matches returns an empty list.
|
||||||
|
"""
|
||||||
|
results = search_vault(
|
||||||
|
query="quantum physics relativity theory",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
min_score=0.9, # Very high threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should return empty list, not raise exception
|
||||||
|
assert isinstance(results, list)
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# Error tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_search_with_empty_query(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that an empty query raises ValueError.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError, match="Query cannot be empty"):
|
||||||
|
search_vault(
|
||||||
|
query="",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_search_nonexistent_collection(tmp_path):
|
||||||
|
"""
|
||||||
|
Test that searching a nonexistent collection raises ValueError.
|
||||||
|
"""
|
||||||
|
chroma_path = tmp_path / "empty_chroma"
|
||||||
|
chroma_path.mkdir()
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
search_vault(
|
||||||
|
query="test query",
|
||||||
|
chroma_db_path=str(chroma_path),
|
||||||
|
collection_name="nonexistent_collection",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_cannot_search_with_whitespace_only_query(indexed_vault):
|
||||||
|
"""
|
||||||
|
Test that a query with only whitespace raises ValueError.
|
||||||
|
"""
|
||||||
|
with pytest.raises(ValueError, match="Query cannot be empty"):
|
||||||
|
search_vault(
|
||||||
|
query=" ",
|
||||||
|
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||||
|
collection_name=indexed_vault["collection_name"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Helper function tests
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_parse_search_results_correctly():
|
||||||
|
"""
|
||||||
|
Test that ChromaDB results are parsed correctly.
|
||||||
|
"""
|
||||||
|
# Mock ChromaDB query results
|
||||||
|
raw_results = {
|
||||||
|
"documents": [[
|
||||||
|
"Python is a programming language",
|
||||||
|
"Machine learning basics",
|
||||||
|
]],
|
||||||
|
"metadatas": [[
|
||||||
|
{
|
||||||
|
"file_path": "notes/python.md",
|
||||||
|
"section_title": "Introduction",
|
||||||
|
"line_start": 1,
|
||||||
|
"line_end": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"file_path": "notes/ml.md",
|
||||||
|
"section_title": "Overview",
|
||||||
|
"line_start": 10,
|
||||||
|
"line_end": 15,
|
||||||
|
},
|
||||||
|
]],
|
||||||
|
"distances": [[0.2, 0.4]], # ChromaDB distances (lower = more similar)
|
||||||
|
}
|
||||||
|
|
||||||
|
results = _parse_search_results(raw_results, min_score=0.0)
|
||||||
|
|
||||||
|
assert len(results) == 2
|
||||||
|
|
||||||
|
# Check first result
|
||||||
|
assert results[0].file_path == "notes/python.md"
|
||||||
|
assert results[0].section_title == "Introduction"
|
||||||
|
assert results[0].line_start == 1
|
||||||
|
assert results[0].line_end == 5
|
||||||
|
assert results[0].text == "Python is a programming language"
|
||||||
|
assert results[0].score == pytest.approx(0.8) # 1 - 0.2
|
||||||
|
|
||||||
|
# Check second result
|
||||||
|
assert results[1].score == pytest.approx(0.6) # 1 - 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def test_i_can_filter_results_by_min_score():
|
||||||
|
"""
|
||||||
|
Test that results are filtered by min_score during parsing.
|
||||||
|
"""
|
||||||
|
raw_results = {
|
||||||
|
"documents": [["text1", "text2", "text3"]],
|
||||||
|
"metadatas": [[
|
||||||
|
{"file_path": "a.md", "section_title": "A", "line_start": 1, "line_end": 2},
|
||||||
|
{"file_path": "b.md", "section_title": "B", "line_start": 1, "line_end": 2},
|
||||||
|
{"file_path": "c.md", "section_title": "C", "line_start": 1, "line_end": 2},
|
||||||
|
]],
|
||||||
|
"distances": [[0.1, 0.5, 0.8]], # Scores will be: 0.9, 0.5, 0.2
|
||||||
|
}
|
||||||
|
|
||||||
|
results = _parse_search_results(raw_results, min_score=0.6)
|
||||||
|
|
||||||
|
# Only first result should pass (score 0.9 >= 0.6)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].file_path == "a.md"
|
||||||
|
assert results[0].score == pytest.approx(0.9)
|
||||||
Block a user