Initial commit
This commit is contained in:
3
.gitignore
vendored
Normal file
3
.gitignore
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
# Ignorer tous les fichiers .DS_Store quelle que soit leur profondeur
|
||||
**/.DS_Store
|
||||
prompts/spec/
|
||||
8
.idea/.gitignore
generated
vendored
Normal file
8
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
18
.idea/MyObsidianAI.iml
generated
Normal file
18
.idea/MyObsidianAI.iml
generated
Normal file
@@ -0,0 +1,18 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/obsidian_rag" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/tests" isTestSource="true" />
|
||||
</content>
|
||||
<orderEntry type="jdk" jdkName="Python 3.12.3 WSL (Ubuntu-24.04): (/home/kodjo/.virtualenvs/MyObsidianAI/bin/python)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
<component name="TestRunnerService">
|
||||
<option name="PROJECT_TEST_RUNNER" value="py.test" />
|
||||
</component>
|
||||
</module>
|
||||
6
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
6
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="PyInitNewSignatureInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
||||
</profile>
|
||||
</component>
|
||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.12 (MyObsidianAI)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12.3 WSL (Ubuntu-24.04): (/home/kodjo/.virtualenvs/MyObsidianAI/bin/python)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/MyObsidianAI.iml" filepath="$PROJECT_DIR$/.idea/MyObsidianAI.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
191
README.md
Normal file
191
README.md
Normal file
@@ -0,0 +1,191 @@
|
||||
# Obsidian RAG Backend
|
||||
|
||||
A local, semantic search backend for Obsidian markdown files.
|
||||
|
||||
## Project Overview
|
||||
|
||||
### Context
|
||||
|
||||
- **Target vault size**: ~1900 files, 480 MB
|
||||
- **Deployment**: 100% local (no external APIs)
|
||||
- **Usage**: Command-line interface (CLI)
|
||||
- **Language**: Python 3.12
|
||||
|
||||
### Phase 1 Scope (Current)
|
||||
|
||||
Semantic search system that:
|
||||
|
||||
- Indexes markdown files from an Obsidian vault
|
||||
- Performs semantic search using local embeddings
|
||||
- Returns relevant results with metadata
|
||||
|
||||
**Phase 2 (Future)**: Add LLM integration for answer generation using Phase 1 search results.
|
||||
|
||||
## Features
|
||||
|
||||
### Indexation
|
||||
|
||||
- Manual, on-demand indexing
|
||||
- Processes all `.md` files in vault
|
||||
- Extracts document structure (sections, line numbers)
|
||||
- Hybrid chunking strategy:
|
||||
- Short sections (≤200 tokens): indexed as-is
|
||||
- Long sections: split with sliding window (200 tokens, 30 tokens overlap)
|
||||
- Robust error handling: continues indexing even if individual files fail
|
||||
|
||||
### Search Results
|
||||
|
||||
Each search result includes:
|
||||
|
||||
- File path (relative to vault root)
|
||||
- Similarity score
|
||||
- Relevant text excerpt
|
||||
- Location in file (section and line number)
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
obsidian_rag/
|
||||
├── obsidian_rag/
|
||||
│ ├── __init__.py
|
||||
│ ├── markdown_parser.py # Parse .md files, extract structure
|
||||
│ ├── indexer.py # Generate embeddings and vector index
|
||||
│ ├── searcher.py # Perform semantic search
|
||||
│ └── cli.py # Typer CLI interface
|
||||
├── tests/
|
||||
│ ├── __init__.py
|
||||
│ ├── test_markdown_parser.py
|
||||
│ ├── test_indexer.py
|
||||
│ └── test_searcher.py
|
||||
├── pyproject.toml
|
||||
└── README.md
|
||||
```
|
||||
|
||||
## Technical Choices
|
||||
|
||||
### Technology Stack
|
||||
|
||||
| Component | Technology | Rationale |
|
||||
|---------------|--------------------------------------------|----------------------------------------------|
|
||||
| Embeddings | sentence-transformers (`all-MiniLM-L6-v2`) | Local, lightweight (~80MB), good performance |
|
||||
| Vector Store | ChromaDB | Simple, persistent, good Python integration |
|
||||
| CLI Framework | Typer | Modern, type-safe, excellent UX |
|
||||
| Testing | pytest | Standard, powerful, good ecosystem |
|
||||
|
||||
### Design Decisions
|
||||
|
||||
1. **Modular architecture**: Separate concerns (parsing, indexing, searching, CLI) for maintainability and testability
|
||||
2. **Local-only**: All processing happens on local machine, no data sent to external services
|
||||
3. **Manual indexing**: User triggers re-indexing when needed (incremental updates deferred to future phases)
|
||||
4. **Hybrid chunking**: Preserves small sections intact while handling large sections with sliding window
|
||||
5. **Token-based chunking**: Uses model's tokenizer for precise chunk sizing (max 200 tokens, 30 tokens overlap)
|
||||
6. **Robust error handling**: Indexing continues even if individual files fail, with detailed error reporting
|
||||
7. **Extensible design**: Architecture prepared for future LLM integration
|
||||
|
||||
### Chunking Strategy Details
|
||||
|
||||
The indexer uses a hybrid approach:
|
||||
|
||||
- **Short sections** (≤200 tokens): Indexed as a single chunk to preserve semantic coherence
|
||||
- **Long sections** (>200 tokens): Split using sliding window with:
|
||||
- Maximum chunk size: 200 tokens (safe margin under model's 256 token limit)
|
||||
- Overlap: 30 tokens (~15% overlap to preserve context at boundaries)
|
||||
- Token counting: Uses sentence-transformers' native tokenizer for accuracy
|
||||
|
||||
### Metadata Structure
|
||||
|
||||
Each chunk stored in ChromaDB includes:
|
||||
|
||||
```python
|
||||
{
|
||||
"file_path": str, # Relative path from vault root
|
||||
"section_title": str, # Markdown section heading
|
||||
"line_start": int, # Starting line number in file
|
||||
"line_end": int # Ending line number in file
|
||||
}
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
### Required
|
||||
|
||||
```bash
|
||||
sentence-transformers # Local embeddings model (includes tokenizer)
|
||||
chromadb # Vector database
|
||||
typer # CLI framework
|
||||
rich # Terminal formatting (Typer dependency)
|
||||
```
|
||||
|
||||
### Development
|
||||
|
||||
```bash
|
||||
pytest # Testing framework
|
||||
pytest-cov # Test coverage
|
||||
```
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip install sentence-transformers chromadb typer[all] pytest pytest-cov
|
||||
```
|
||||
|
||||
## Usage (Planned)
|
||||
|
||||
```bash
|
||||
# Index vault
|
||||
obsidian-rag index /path/to/vault
|
||||
|
||||
# Search
|
||||
obsidian-rag search "your query here"
|
||||
|
||||
# Search with options
|
||||
obsidian-rag search "query" --limit 10 --min-score 0.5
|
||||
```
|
||||
|
||||
## Development Standards
|
||||
|
||||
### Code Style
|
||||
|
||||
- Follow PEP 8 conventions
|
||||
- Use snake_case for variables and functions
|
||||
- Docstrings in Google or NumPy format
|
||||
- All code, comments, and documentation in English
|
||||
|
||||
### Testing Strategy
|
||||
|
||||
- Unit tests with pytest
|
||||
- Test function naming: `test_i_can_xxx` (passing tests) or `test_i_cannot_xxx` (error cases)
|
||||
- Functions over classes unless inheritance required
|
||||
- Test plan validation before implementation
|
||||
|
||||
### File Management
|
||||
|
||||
- All file modifications documented with full file path
|
||||
- Clear separation of concerns across modules
|
||||
|
||||
## Project Status
|
||||
|
||||
- [x] Requirements gathering
|
||||
- [x] Architecture design
|
||||
- [x] Chunking strategy validation
|
||||
- [ ] Implementation
|
||||
- [x] `markdown_parser.py`
|
||||
- [x] `indexer.py`
|
||||
- [x] `searcher.py`
|
||||
- [x] `cli.py`
|
||||
- [ ] Unit tests
|
||||
- [x] `test_markdown_parser.py`
|
||||
- [x] `test_indexer.py` (tests written, debugging in progress)
|
||||
- [x] `test_searcher.py`
|
||||
- [ ] `test_cli.py`
|
||||
- [ ] Integration testing
|
||||
- [ ] Documentation
|
||||
- [ ] Phase 2: LLM integration
|
||||
|
||||
## License
|
||||
|
||||
[To be determined]
|
||||
|
||||
## Contributing
|
||||
|
||||
[To be determined]
|
||||
16
main.py
Normal file
16
main.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# This is a sample Python script.
|
||||
|
||||
# Press Maj+F10 to execute it or replace it with your code.
|
||||
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.
|
||||
|
||||
|
||||
def print_hi(name):
|
||||
# Use a breakpoint in the code line below to debug your script.
|
||||
print(f'Hi, {name}') # Press Ctrl+F8 to toggle the breakpoint.
|
||||
|
||||
|
||||
# Press the green button in the gutter to run the script.
|
||||
if __name__ == '__main__':
|
||||
print_hi('PyCharm')
|
||||
|
||||
# See PyCharm help at https://www.jetbrains.com/help/pycharm/
|
||||
0
obsidian_rag/__init__.py
Normal file
0
obsidian_rag/__init__.py
Normal file
403
obsidian_rag/cli.py
Normal file
403
obsidian_rag/cli.py
Normal file
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
CLI module for Obsidian RAG Backend.
|
||||
|
||||
Provides command-line interface for indexing and searching the Obsidian vault.
|
||||
"""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import typer
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TaskProgressColumn
|
||||
|
||||
from indexer import index_vault
|
||||
from rag_chain import RAGChain
|
||||
from searcher import search_vault, SearchResult
|
||||
|
||||
app = typer.Typer(
|
||||
name="obsidian-rag",
|
||||
help="Local semantic search backend for Obsidian markdown files",
|
||||
add_completion=False,
|
||||
)
|
||||
console = Console()
|
||||
|
||||
# Default ChromaDB path
|
||||
DEFAULT_CHROMA_PATH = Path.home() / ".obsidian_rag" / "chroma_db"
|
||||
|
||||
|
||||
def _truncate_path(path: str, max_len: int = 60) -> str:
|
||||
"""Return a truncated version of the file path if too long."""
|
||||
if len(path) <= max_len:
|
||||
return path
|
||||
return "..." + path[-(max_len - 3):]
|
||||
|
||||
|
||||
@app.command()
|
||||
def index(
|
||||
vault_path: str = typer.Argument(
|
||||
...,
|
||||
help="Path to the Obsidian vault directory",
|
||||
),
|
||||
chroma_path: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--chroma-path",
|
||||
"-c",
|
||||
help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})",
|
||||
),
|
||||
collection_name: str = typer.Option(
|
||||
"obsidian_vault",
|
||||
"--collection",
|
||||
help="Name of the ChromaDB collection",
|
||||
),
|
||||
max_chunk_tokens: int = typer.Option(
|
||||
200,
|
||||
"--max-tokens",
|
||||
help="Maximum tokens per chunk",
|
||||
),
|
||||
overlap_tokens: int = typer.Option(
|
||||
30,
|
||||
"--overlap",
|
||||
help="Number of overlapping tokens between chunks",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Index all markdown files from the Obsidian vault into ChromaDB.
|
||||
"""
|
||||
vault_path_obj = Path(vault_path)
|
||||
chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH
|
||||
|
||||
if not vault_path_obj.exists():
|
||||
console.print(f"[red]✗ Error:[/red] Vault path does not exist: {vault_path}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
if not vault_path_obj.is_dir():
|
||||
console.print(f"[red]✗ Error:[/red] Vault path is not a directory: {vault_path}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
chroma_path_obj.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
md_files = list(vault_path_obj.rglob("*.md"))
|
||||
total_files = len(md_files)
|
||||
|
||||
if total_files == 0:
|
||||
console.print(f"[yellow]⚠ Warning:[/yellow] No markdown files found in {vault_path}")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
console.print(f"\n[cyan]Found {total_files} markdown files to index[/cyan]\n")
|
||||
|
||||
# One single stable progress bar
|
||||
with Progress(
|
||||
SpinnerColumn(),
|
||||
TextColumn("[progress.description]{task.description}"),
|
||||
BarColumn(),
|
||||
TaskProgressColumn(),
|
||||
console=console,
|
||||
) as progress:
|
||||
|
||||
main_task = progress.add_task("[cyan]Indexing vault...", total=total_files)
|
||||
|
||||
# Create a separate status line below the progress bar
|
||||
status_line = console.status("[dim]Preparing first file...")
|
||||
|
||||
def progress_callback(current_file: str, files_processed: int, total: int):
|
||||
"""Update progress bar and status message."""
|
||||
progress.update(main_task, completed=files_processed)
|
||||
|
||||
short_file = _truncate_path(current_file)
|
||||
status_line.update(f"[dim]Processing: {short_file}")
|
||||
|
||||
try:
|
||||
with status_line:
|
||||
stats = index_vault(
|
||||
vault_path=str(vault_path_obj),
|
||||
chroma_db_path=str(chroma_path_obj),
|
||||
collection_name=collection_name,
|
||||
max_chunk_tokens=max_chunk_tokens,
|
||||
overlap_tokens=overlap_tokens,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
progress.update(main_task, completed=total_files)
|
||||
status_line.update("[green]✓ Completed")
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]✗ Error during indexing:[/red] {str(e)}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
console.print()
|
||||
_display_index_results(stats)
|
||||
|
||||
|
||||
@app.command()
|
||||
def search(
|
||||
query: str = typer.Argument(
|
||||
...,
|
||||
help="Search query",
|
||||
),
|
||||
chroma_path: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--chroma-path",
|
||||
"-c",
|
||||
help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})",
|
||||
),
|
||||
collection_name: str = typer.Option(
|
||||
"obsidian_vault",
|
||||
"--collection",
|
||||
help="Name of the ChromaDB collection",
|
||||
),
|
||||
limit: int = typer.Option(
|
||||
5,
|
||||
"--limit",
|
||||
"-l",
|
||||
help="Maximum number of results to return",
|
||||
),
|
||||
min_score: float = typer.Option(
|
||||
0.0,
|
||||
"--min-score",
|
||||
"-s",
|
||||
help="Minimum similarity score (0.0 to 1.0)",
|
||||
),
|
||||
format: str = typer.Option(
|
||||
"compact",
|
||||
"--format",
|
||||
"-f",
|
||||
help="Output format: compact (default), panel, table",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Search the indexed vault for semantically similar content.
|
||||
|
||||
Returns relevant sections from your Obsidian notes based on
|
||||
semantic similarity to the query.
|
||||
"""
|
||||
# Resolve paths
|
||||
chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH
|
||||
|
||||
# Validate chroma path exists
|
||||
if not chroma_path_obj.exists():
|
||||
console.print(
|
||||
f"[red]✗ Error:[/red] ChromaDB not found at {chroma_path_obj}\n"
|
||||
f"Please run 'obsidian-rag index <vault_path>' first to create the index."
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Validate format
|
||||
valid_formats = ["compact", "panel", "table"]
|
||||
if format not in valid_formats:
|
||||
console.print(f"[red]✗ Error:[/red] Invalid format '{format}'. Valid options: {', '.join(valid_formats)}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Perform search
|
||||
try:
|
||||
with console.status("[cyan]Searching...", spinner="dots"):
|
||||
results = search_vault(
|
||||
query=query,
|
||||
chroma_db_path=str(chroma_path_obj),
|
||||
collection_name=collection_name,
|
||||
limit=limit,
|
||||
min_score=min_score,
|
||||
)
|
||||
except ValueError as e:
|
||||
console.print(f"[red]✗ Error:[/red] {str(e)}")
|
||||
raise typer.Exit(code=1)
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Unexpected error:[/red] {str(e)}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Display results
|
||||
if not results:
|
||||
console.print(f"\n[yellow]No results found for query:[/yellow] '{query}'")
|
||||
if min_score > 0:
|
||||
console.print(f"[dim]Try lowering --min-score (currently {min_score})[/dim]")
|
||||
raise typer.Exit(code=0)
|
||||
|
||||
console.print(f"\n[cyan]Found {len(results)} result(s) for:[/cyan] '{query}'\n")
|
||||
|
||||
# Display with selected format
|
||||
if format == "compact":
|
||||
_display_results_compact(results)
|
||||
elif format == "panel":
|
||||
_display_results_panel(results)
|
||||
elif format == "table":
|
||||
_display_results_table(results)
|
||||
|
||||
|
||||
@app.command()
|
||||
def ask(
|
||||
query: str = typer.Argument(
|
||||
...,
|
||||
help="Question to ask the LLM based on your Obsidian notes."
|
||||
),
|
||||
chroma_path: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--chroma-path",
|
||||
"-c",
|
||||
help=f"Path to ChromaDB storage (default: {DEFAULT_CHROMA_PATH})",
|
||||
),
|
||||
collection_name: str = typer.Option(
|
||||
"obsidian_vault",
|
||||
"--collection",
|
||||
help="Name of the ChromaDB collection",
|
||||
),
|
||||
top_k: int = typer.Option(
|
||||
5,
|
||||
"--top-k",
|
||||
"-k",
|
||||
help="Number of top chunks to use for context",
|
||||
),
|
||||
min_score: float = typer.Option(
|
||||
0.0,
|
||||
"--min-score",
|
||||
"-s",
|
||||
help="Minimum similarity score for chunks",
|
||||
),
|
||||
api_key: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--api-key",
|
||||
help="Clovis API key (or set CLOVIS_API_KEY environment variable)",
|
||||
),
|
||||
base_url: Optional[str] = typer.Option(
|
||||
None,
|
||||
"--base-url",
|
||||
help="Clovis base URL (or set CLOVIS_BASE_URL environment variable)",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Ask a question to the LLM using RAG over your Obsidian vault.
|
||||
"""
|
||||
|
||||
# Resolve ChromaDB path
|
||||
chroma_path_obj = Path(chroma_path) if chroma_path else DEFAULT_CHROMA_PATH
|
||||
if not chroma_path_obj.exists():
|
||||
console.print(
|
||||
f"[red]✗ Error:[/red] ChromaDB not found at {chroma_path_obj}\n"
|
||||
f"Please run 'obsidian-rag index <vault_path>' first to create the index."
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Resolve API key and base URL
|
||||
api_key = api_key or os.getenv("CLOVIS_API_KEY")
|
||||
base_url = base_url or os.getenv("CLOVIS_BASE_URL")
|
||||
if not api_key or not base_url:
|
||||
console.print(
|
||||
"[red]✗ Error:[/red] API key or base URL not provided.\n"
|
||||
"Set them via --api-key / --base-url or environment variables CLOVIS_API_KEY and CLOVIS_BASE_URL."
|
||||
)
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Instantiate RAGChain
|
||||
rag = RAGChain(
|
||||
chroma_db_path=str(chroma_path_obj),
|
||||
collection_name=collection_name,
|
||||
top_k=top_k,
|
||||
min_score=min_score,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
|
||||
# Get answer from RAG
|
||||
try:
|
||||
with console.status("[cyan]Querying LLM...", spinner="dots"):
|
||||
answer, used_chunks = rag.answer_query(query)
|
||||
except Exception as e:
|
||||
console.print(f"[red]✗ Error:[/red] {str(e)}")
|
||||
raise typer.Exit(code=1)
|
||||
|
||||
# Display answer
|
||||
console.print("\n[bold green]Answer:[/bold green]\n")
|
||||
console.print(answer + "\n")
|
||||
|
||||
# Display sources used
|
||||
if used_chunks:
|
||||
sources = ", ".join(f"{c.file_path}#L{c.line_start}-L{c.line_end}" for c in used_chunks)
|
||||
console.print(f"[bold cyan]Sources:[/bold cyan] {sources}\n")
|
||||
else:
|
||||
console.print("[bold cyan]Sources:[/bold cyan] None\n")
|
||||
|
||||
|
||||
def _display_index_results(stats: dict):
|
||||
"""
|
||||
Display indexing results with rich formatting.
|
||||
|
||||
Args:
|
||||
stats: Statistics dictionary from index_vault
|
||||
"""
|
||||
files_processed = stats["files_processed"]
|
||||
chunks_created = stats["chunks_created"]
|
||||
errors = stats["errors"]
|
||||
|
||||
# Success summary
|
||||
console.print(Panel(
|
||||
f"[green]✓[/green] Indexing completed\n\n"
|
||||
f"Files processed: [cyan]{files_processed}[/cyan]\n"
|
||||
f"Chunks created: [cyan]{chunks_created}[/cyan]\n"
|
||||
f"Collection: [cyan]{stats['collection_name']}[/cyan]",
|
||||
title="[bold]Indexing Results[/bold]",
|
||||
border_style="green",
|
||||
))
|
||||
|
||||
# Display errors if any
|
||||
if errors:
|
||||
console.print(f"\n[yellow]⚠ {len(errors)} file(s) skipped due to errors:[/yellow]\n")
|
||||
for error in errors:
|
||||
console.print(f" [red]•[/red] {error['file']}: [dim]{error['error']}[/dim]")
|
||||
|
||||
|
||||
def _display_results_compact(results: list[SearchResult]):
|
||||
"""
|
||||
Display search results in compact format.
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
"""
|
||||
for idx, result in enumerate(results, 1):
|
||||
# Format score as stars (0-5 scale)
|
||||
stars = "⭐" * int(result.score * 5)
|
||||
|
||||
console.print(f"[bold cyan]{idx}.[/bold cyan] {result.file_path} [dim](score: {result.score:.2f} {stars})[/dim]")
|
||||
console.print(
|
||||
f" Section: [yellow]{result.section_title}[/yellow] | Lines: [dim]{result.line_start}-{result.line_end}[/dim]")
|
||||
|
||||
# Truncate text if too long
|
||||
text = result.text
|
||||
if len(text) > 200:
|
||||
text = text[:200] + "..."
|
||||
|
||||
console.print(f" {text}\n")
|
||||
|
||||
|
||||
def _display_results_panel(results: list[SearchResult]):
|
||||
"""
|
||||
Display search results in panel format (rich boxes).
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
"""
|
||||
# TODO: Implement panel format in future
|
||||
console.print("[yellow]Panel format not yet implemented. Using compact format.[/yellow]\n")
|
||||
_display_results_compact(results)
|
||||
|
||||
|
||||
def _display_results_table(results: list[SearchResult]):
|
||||
"""
|
||||
Display search results in table format.
|
||||
|
||||
Args:
|
||||
results: List of SearchResult objects
|
||||
"""
|
||||
# TODO: Implement table format in future
|
||||
console.print("[yellow]Table format not yet implemented. Using compact format.[/yellow]\n")
|
||||
_display_results_compact(results)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Entry point for the CLI application.
|
||||
"""
|
||||
app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
284
obsidian_rag/indexer.py
Normal file
284
obsidian_rag/indexer.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Indexer module for Obsidian RAG Backend.
|
||||
|
||||
This module handles the indexing of markdown files into a ChromaDB vector store
|
||||
using local embeddings from sentence-transformers.
|
||||
"""
|
||||
from dataclasses import dataclass, asdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Callable
|
||||
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from markdown_parser import ParsedDocument
|
||||
from markdown_parser import parse_markdown_file
|
||||
|
||||
# EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
||||
EMBEDDING_MODEL = "all-MiniLM-L12-v2"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkMetadata:
|
||||
file_path: str
|
||||
section_title: str
|
||||
line_start: int
|
||||
line_end: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chunk:
|
||||
id: str
|
||||
text: str
|
||||
metadata: ChunkMetadata
|
||||
|
||||
|
||||
def index_vault(
|
||||
vault_path: str,
|
||||
chroma_db_path: str,
|
||||
collection_name: str = "obsidian_vault",
|
||||
embedding_model: str = EMBEDDING_MODEL,
|
||||
max_chunk_tokens: int = 200,
|
||||
overlap_tokens: int = 30,
|
||||
progress_callback: Optional[Callable[[str, int, int], None]] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Index all markdown files from vault into ChromaDB.
|
||||
|
||||
Args:
|
||||
vault_path: Path to the Obsidian vault directory
|
||||
chroma_db_path: Path where ChromaDB will store its data
|
||||
collection_name: Name of the ChromaDB collection
|
||||
embedding_model: Name of the sentence-transformers model to use
|
||||
max_chunk_tokens: Maximum tokens per chunk
|
||||
overlap_tokens: Number of overlapping tokens between chunks
|
||||
progress_callback: Optional callback function called for each file processed.
|
||||
Signature: callback(current_file: str, files_processed: int, total_files: int)
|
||||
|
||||
Returns:
|
||||
Dictionary with indexing statistics:
|
||||
- files_processed: Number of files successfully processed
|
||||
- chunks_created: Total number of chunks created
|
||||
- errors: List of errors encountered (file path and error message)
|
||||
- collection_name: Name of the collection used
|
||||
"""
|
||||
|
||||
vault_path_obj = Path(vault_path)
|
||||
if not vault_path_obj.exists():
|
||||
raise ValueError(f"Vault path does not exist: {vault_path}")
|
||||
|
||||
# Initialize embedding model and tokenizer
|
||||
model = SentenceTransformer(embedding_model)
|
||||
tokenizer = model.tokenizer
|
||||
|
||||
# Initialize ChromaDB client and collection
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=chroma_db_path,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
collection = _get_or_create_collection(chroma_client, collection_name)
|
||||
|
||||
# Find all markdown files
|
||||
md_files = list(vault_path_obj.rglob("*.md"))
|
||||
total_files = len(md_files)
|
||||
|
||||
# Statistics tracking
|
||||
stats = {
|
||||
"files_processed": 0,
|
||||
"chunks_created": 0,
|
||||
"errors": [],
|
||||
"collection_name": collection_name,
|
||||
}
|
||||
|
||||
# Process each file
|
||||
for md_file in md_files:
|
||||
# Get relative path for display
|
||||
relative_path = md_file.relative_to(vault_path_obj)
|
||||
|
||||
# Notify callback that we're starting this file
|
||||
if progress_callback:
|
||||
progress_callback(str(relative_path), stats["files_processed"], total_files)
|
||||
|
||||
try:
|
||||
# Parse markdown file
|
||||
parsed_doc = parse_markdown_file(md_file)
|
||||
|
||||
# Create chunks from document
|
||||
chunks = _create_chunks_from_document(
|
||||
parsed_doc,
|
||||
tokenizer,
|
||||
max_chunk_tokens,
|
||||
overlap_tokens,
|
||||
vault_path_obj,
|
||||
)
|
||||
|
||||
if chunks:
|
||||
# Extract data for ChromaDB
|
||||
documents = [chunk.text for chunk in chunks]
|
||||
metadatas = [asdict(chunk.metadata) for chunk in chunks]
|
||||
ids = [chunk.id for chunk in chunks]
|
||||
|
||||
# Generate embeddings and add to collection
|
||||
embeddings = model.encode(documents, show_progress_bar=False)
|
||||
collection.add(
|
||||
documents=documents,
|
||||
metadatas=metadatas,
|
||||
ids=ids,
|
||||
embeddings=embeddings.tolist(),
|
||||
)
|
||||
|
||||
stats["chunks_created"] += len(chunks)
|
||||
|
||||
stats["files_processed"] += 1
|
||||
|
||||
except Exception as e:
|
||||
# Log error but continue processing
|
||||
stats["errors"].append({
|
||||
"file": str(relative_path),
|
||||
"error": str(e),
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def _get_or_create_collection(
|
||||
chroma_client: chromadb.PersistentClient,
|
||||
collection_name: str,
|
||||
) -> chromadb.Collection:
|
||||
"""
|
||||
Get or create a ChromaDB collection, resetting it if it already exists.
|
||||
|
||||
Args:
|
||||
chroma_client: ChromaDB client instance
|
||||
collection_name: Name of the collection
|
||||
|
||||
Returns:
|
||||
ChromaDB collection instance
|
||||
"""
|
||||
try:
|
||||
# Try to delete existing collection
|
||||
chroma_client.delete_collection(name=collection_name)
|
||||
except Exception:
|
||||
# Collection doesn't exist, that's fine
|
||||
pass
|
||||
|
||||
# Create fresh collection
|
||||
collection = chroma_client.create_collection(
|
||||
name=collection_name,
|
||||
metadata={"hnsw:space": "cosine"} # Use cosine similarity
|
||||
)
|
||||
|
||||
return collection
|
||||
|
||||
|
||||
def _create_chunks_from_document(
|
||||
parsed_doc: ParsedDocument,
|
||||
tokenizer,
|
||||
max_chunk_tokens: int,
|
||||
overlap_tokens: int,
|
||||
vault_path: Path,
|
||||
) -> List[Chunk]:
|
||||
"""
|
||||
Transform a parsed document into chunks with metadata.
|
||||
|
||||
Implements hybrid chunking strategy:
|
||||
- Short sections (≤max_chunk_tokens): one chunk per section
|
||||
- Long sections (>max_chunk_tokens): split with sliding window
|
||||
|
||||
Args:
|
||||
parsed_doc: Parsed document from markdown_parser
|
||||
tokenizer: Tokenizer from sentence-transformers model
|
||||
max_chunk_tokens: Maximum tokens per chunk
|
||||
overlap_tokens: Number of overlapping tokens between chunks
|
||||
vault_path: Path to vault root (for relative path calculation)
|
||||
|
||||
Returns:
|
||||
List of chunk dictionaries with 'text', 'metadata', and 'id' keys
|
||||
"""
|
||||
chunks = []
|
||||
file_path = parsed_doc.file_path
|
||||
relative_path = file_path.relative_to(vault_path)
|
||||
|
||||
for section in parsed_doc.sections:
|
||||
section_text = f"{parsed_doc.title} {section.title} {section.content}"
|
||||
section_title = section.title
|
||||
line_start = section.start_line
|
||||
line_end = section.end_line
|
||||
|
||||
# Tokenize section to check length
|
||||
tokens = tokenizer.encode(section_text, add_special_tokens=False)
|
||||
|
||||
if len(tokens) <= max_chunk_tokens:
|
||||
# Short section: create single chunk
|
||||
chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}"
|
||||
chunks.append(Chunk(chunk_id, section_text, ChunkMetadata(str(relative_path),
|
||||
section_title,
|
||||
line_start,
|
||||
line_start
|
||||
)))
|
||||
else:
|
||||
# Long section: split with sliding window
|
||||
sub_chunks = _chunk_section(
|
||||
section_text,
|
||||
tokenizer,
|
||||
max_chunk_tokens,
|
||||
overlap_tokens,
|
||||
)
|
||||
|
||||
# Create chunk for each sub-chunk
|
||||
for idx, sub_chunk_text in enumerate(sub_chunks):
|
||||
chunk_id = f"{relative_path}::{section_title}::{line_start}-{line_end}::chunk{idx}"
|
||||
chunks.append(Chunk(chunk_id, sub_chunk_text, ChunkMetadata(str(relative_path),
|
||||
section_title,
|
||||
line_start,
|
||||
line_start
|
||||
)))
|
||||
return chunks
|
||||
|
||||
|
||||
def _chunk_section(
|
||||
section_text: str,
|
||||
tokenizer,
|
||||
max_chunk_tokens: int,
|
||||
overlap_tokens: int,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split a section into overlapping chunks using sliding window.
|
||||
|
||||
Args:
|
||||
section_text: Text content to chunk
|
||||
tokenizer: Tokenizer from sentence-transformers model
|
||||
max_chunk_tokens: Maximum tokens per chunk
|
||||
overlap_tokens: Number of overlapping tokens between chunks
|
||||
|
||||
Returns:
|
||||
List of text chunks
|
||||
"""
|
||||
# Apply safety margin to prevent decode/encode inconsistencies
|
||||
# from exceeding the max token limit
|
||||
max_chunk_tokens_to_use = int(max_chunk_tokens * 0.98)
|
||||
|
||||
# Tokenize the full text
|
||||
tokens = tokenizer.encode(section_text, add_special_tokens=False)
|
||||
|
||||
chunks = []
|
||||
start_idx = 0
|
||||
|
||||
while start_idx < len(tokens):
|
||||
# Extract chunk tokens
|
||||
end_idx = start_idx + max_chunk_tokens_to_use
|
||||
chunk_tokens = tokens[start_idx:end_idx]
|
||||
|
||||
# Decode back to text
|
||||
chunk_text = tokenizer.decode(chunk_tokens, skip_special_tokens=True)
|
||||
chunks.append(chunk_text)
|
||||
|
||||
# Move window forward (with overlap)
|
||||
start_idx += max_chunk_tokens_to_use - overlap_tokens
|
||||
|
||||
# Avoid infinite loop if overlap >= max_chunk_tokens
|
||||
if start_idx <= start_idx - (max_chunk_tokens_to_use - overlap_tokens):
|
||||
break
|
||||
|
||||
return chunks
|
||||
74
obsidian_rag/llm_client.py
Normal file
74
obsidian_rag/llm_client.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Dict
|
||||
|
||||
import openai
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""
|
||||
Minimalist client for interacting with Clovis LLM via OpenAI SDK.
|
||||
|
||||
Attributes:
|
||||
api_key (str): API key for Clovis.
|
||||
base_url (str): Base URL for Clovis LLM gateway.
|
||||
model (str): Model name to use. Defaults to 'ClovisLLM'.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, model: str = "ClovisLLM") -> None:
|
||||
if not api_key:
|
||||
raise ValueError("API key is required for LLMClient.")
|
||||
if not base_url:
|
||||
raise ValueError("Base URL is required for LLMClient.")
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.client = openai.OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
def generate(self, system_prompt: str, user_prompt: str, context: str) -> Dict[str, object]:
|
||||
"""
|
||||
Generate a response from the LLM given a system prompt, user prompt, and context.
|
||||
|
||||
Args:
|
||||
system_prompt (str): Instructions for the assistant.
|
||||
user_prompt (str): The user's query.
|
||||
context (str): Concatenated chunks from RAG search.
|
||||
|
||||
Returns:
|
||||
Dict[str, object]: Contains:
|
||||
- "answer" (str): Text generated by the LLM.
|
||||
- "usage" (int): Total tokens used in the completion.
|
||||
"""
|
||||
# Construct user message with explicit CONTEXT / QUESTION separation
|
||||
user_message_content = f"CONTEXT:\n{context}\n\nQUESTION:\n{user_prompt}"
|
||||
|
||||
try:
|
||||
response = self.client.chat.completions.create(model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message_content}
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=2000,
|
||||
top_p=1.0,
|
||||
n=1,
|
||||
# stream=False,
|
||||
# presence_penalty=0.0,
|
||||
# frequency_penalty=0.0,
|
||||
# stop=None,
|
||||
# logit_bias={},
|
||||
user="obsidian_rag",
|
||||
)
|
||||
except Exception as e:
|
||||
# For now, propagate exceptions (C1 minimal)
|
||||
raise e
|
||||
|
||||
# Extract text and usage
|
||||
try:
|
||||
answer_text = response.choices[0].message.content
|
||||
total_tokens = response.usage.total_tokens
|
||||
except AttributeError:
|
||||
# Fallback if response structure is unexpected
|
||||
answer_text = ""
|
||||
total_tokens = 0
|
||||
|
||||
return {"answer": answer_text, "usage": total_tokens}
|
||||
213
obsidian_rag/markdown_parser.py
Normal file
213
obsidian_rag/markdown_parser.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Markdown parser for Obsidian vault files.
|
||||
|
||||
This module provides functionality to parse markdown files and extract
|
||||
their structure (sections, line numbers) for semantic search indexing.
|
||||
"""
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarkdownSection:
|
||||
"""Represents a section in a markdown document.
|
||||
|
||||
Attributes:
|
||||
level: Header level (0 for no header, 1 for #, 2 for ##, etc.)
|
||||
title: Section title (empty string if level=0)
|
||||
content: Text content without the header line
|
||||
start_line: Line number where section starts (1-indexed)
|
||||
end_line: Line number where section ends (1-indexed, inclusive)
|
||||
"""
|
||||
level: int
|
||||
title: str
|
||||
content: str
|
||||
parents: list[str]
|
||||
start_line: int
|
||||
end_line: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedDocument:
|
||||
"""Represents a parsed markdown document.
|
||||
|
||||
Attributes:
|
||||
file_path: Path to the markdown file
|
||||
sections: List of sections extracted from the document
|
||||
raw_content: Full file content as string
|
||||
"""
|
||||
file_path: Path
|
||||
title: str
|
||||
sections: List[MarkdownSection]
|
||||
raw_content: str
|
||||
|
||||
|
||||
def _compute_parents(current_parents, previous_level, previous_title, current_level):
|
||||
"""Computes the parents of `current_parents`."""
|
||||
return current_parents
|
||||
|
||||
|
||||
def parse_markdown_file(file_path: Path, vault_path=None) -> ParsedDocument:
|
||||
"""Parse a markdown file and extract its structure.
|
||||
|
||||
This function reads a markdown file, identifies all header sections,
|
||||
and extracts their content with precise line number tracking.
|
||||
Files without headers are treated as a single section with level 0.
|
||||
|
||||
Args:
|
||||
file_path: Path to the markdown file to parse
|
||||
vault_path: Path to the vault file.
|
||||
|
||||
Returns:
|
||||
ParsedDocument containing the file structure and content
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file does not exist
|
||||
|
||||
Example:
|
||||
>>> doc = parse_markdown_file(Path("notes/example.md"))
|
||||
>>> print(f"Found {len(doc.sections)} sections")
|
||||
>>> print(doc.sections[0].title)
|
||||
"""
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
if vault_path:
|
||||
title = str(file_path.relative_to(vault_path)).replace(".md", "")
|
||||
title = title.replace("\\", " ").replace("/", " ")
|
||||
else:
|
||||
title = file_path.stem
|
||||
|
||||
raw_content = file_path.read_text(encoding="utf-8")
|
||||
lines = raw_content.splitlines()
|
||||
|
||||
sections: List[MarkdownSection] = []
|
||||
current_section_start = 1
|
||||
current_level = 0
|
||||
current_title = ""
|
||||
current_parents = []
|
||||
current_content_lines: List[str] = []
|
||||
|
||||
header_pattern = re.compile(r"^(#{1,6})\s+(.+)$")
|
||||
|
||||
for line_num, line in enumerate(lines, start=1):
|
||||
match = header_pattern.match(line)
|
||||
|
||||
if match:
|
||||
# Save the previous section only if it actually has content.
|
||||
if current_content_lines:
|
||||
content = "\n".join(current_content_lines)
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
level=current_level,
|
||||
title=current_title,
|
||||
content=content,
|
||||
parents=current_parents,
|
||||
start_line=current_section_start,
|
||||
end_line=line_num - 1,
|
||||
)
|
||||
)
|
||||
|
||||
# Start a new section with the detected header.
|
||||
previous_level = current_level
|
||||
previous_title = current_title
|
||||
current_level = len(match.group(1))
|
||||
current_title = match.group(2).strip()
|
||||
current_section_start = line_num
|
||||
current_parents = _compute_parents(current_parents, previous_level, previous_title, current_level)
|
||||
current_content_lines = []
|
||||
else:
|
||||
current_content_lines.append(line)
|
||||
|
||||
# Handle the final section (or whole file if no headers were found).
|
||||
if lines:
|
||||
content = "\n".join(current_content_lines)
|
||||
end_line = len(lines)
|
||||
|
||||
# Case 1 – no header was ever found.
|
||||
if not sections and current_level == 0:
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
level=0,
|
||||
title="",
|
||||
content=content,
|
||||
parents=current_parents,
|
||||
start_line=1,
|
||||
end_line=end_line,
|
||||
)
|
||||
)
|
||||
# Case 2 – a single header was found (sections empty but we have a title).
|
||||
elif not sections:
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
level=current_level,
|
||||
title=current_title,
|
||||
content=content,
|
||||
parents=current_parents,
|
||||
start_line=current_section_start,
|
||||
end_line=end_line,
|
||||
)
|
||||
)
|
||||
# Case 3 – multiple headers were found (sections already contains earlier ones).
|
||||
else:
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
level=current_level,
|
||||
title=current_title,
|
||||
content=content,
|
||||
parents=current_parents,
|
||||
start_line=current_section_start,
|
||||
end_line=end_line,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Empty file: create a single empty level‑0 section.
|
||||
sections.append(
|
||||
MarkdownSection(
|
||||
level=0,
|
||||
title="",
|
||||
content="",
|
||||
parents=[],
|
||||
start_line=1,
|
||||
end_line=1,
|
||||
)
|
||||
)
|
||||
|
||||
return ParsedDocument(
|
||||
file_path=file_path,
|
||||
title=title,
|
||||
sections=sections,
|
||||
raw_content=raw_content,
|
||||
)
|
||||
|
||||
|
||||
def find_section_at_line(
|
||||
document: ParsedDocument,
|
||||
line_number: int,
|
||||
) -> Optional[MarkdownSection]:
|
||||
"""Find which section contains a given line number.
|
||||
|
||||
This function searches through the document's sections to find
|
||||
which section contains the specified line number.
|
||||
|
||||
Args:
|
||||
document: Parsed markdown document
|
||||
line_number: Line number to search for (1-indexed)
|
||||
|
||||
Returns:
|
||||
MarkdownSection containing the line, or None if line number
|
||||
is invalid or out of range
|
||||
|
||||
Example:
|
||||
>>> section = find_section_at_line(doc, 42)
|
||||
>>> if section:
|
||||
... print(f"Line 42 is in section: {section.title}")
|
||||
"""
|
||||
if line_number < 1:
|
||||
return None
|
||||
|
||||
for section in document.sections:
|
||||
if section.start_line <= line_number <= section.end_line:
|
||||
return section
|
||||
96
obsidian_rag/rag_chain.py
Normal file
96
obsidian_rag/rag_chain.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# File: obsidian_rag/rag_chain.py
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from indexer import EMBEDDING_MODEL
|
||||
from llm_client import LLMClient
|
||||
from searcher import search_vault, SearchResult
|
||||
|
||||
|
||||
class RAGChain:
|
||||
"""
|
||||
Retrieval-Augmented Generation (RAG) chain for answering queries
|
||||
using semantic search over an Obsidian vault and LLM.
|
||||
|
||||
Attributes:
|
||||
chroma_db_path (Path): Path to ChromaDB.
|
||||
collection_name (str): Chroma collection name.
|
||||
embedding_model (str): Embedding model name.
|
||||
top_k (int): Number of chunks to send to the LLM.
|
||||
min_score (float): Minimum similarity score for chunks.
|
||||
system_prompt (str): System prompt to instruct the LLM.
|
||||
llm_client (LLMClient): Internal LLM client instance.
|
||||
"""
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"You are an assistant specialized in analyzing Obsidian notes.\n\n"
|
||||
"INSTRUCTIONS:\n"
|
||||
"- Answer based ONLY on the provided context\n"
|
||||
"- Cite the sources (files) you use\n"
|
||||
"- If the information is not in the context, say \"I did not find this information in your notes\"\n"
|
||||
"- Be concise but thorough\n"
|
||||
"- Structure your answer with sections if necessary"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chroma_db_path: str,
|
||||
api_key: str,
|
||||
base_url: str,
|
||||
collection_name: str = "obsidian_vault",
|
||||
embedding_model: str = EMBEDDING_MODEL,
|
||||
top_k: int = 5,
|
||||
min_score: float = 0.0,
|
||||
system_prompt: str = None,
|
||||
) -> None:
|
||||
self.chroma_db_path = Path(chroma_db_path)
|
||||
self.collection_name = collection_name
|
||||
self.embedding_model = embedding_model
|
||||
self.top_k = top_k
|
||||
self.min_score = min_score
|
||||
self.system_prompt = system_prompt or self.DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
# Instantiate internal LLM client
|
||||
self.llm_client = LLMClient(api_key=api_key, base_url=base_url)
|
||||
|
||||
def answer_query(self, query: str) -> Tuple[str, List[SearchResult]]:
|
||||
"""
|
||||
Answer a user query using RAG: search vault, build context, call LLM.
|
||||
|
||||
Args:
|
||||
query (str): User query.
|
||||
|
||||
Returns:
|
||||
Tuple[str, List[SearchResult]]:
|
||||
- LLM answer (str)
|
||||
- List of used SearchResult chunks
|
||||
"""
|
||||
# 1. Perform semantic search
|
||||
chunks: List[SearchResult] = search_vault(
|
||||
query=query,
|
||||
chroma_db_path=str(self.chroma_db_path),
|
||||
collection_name=self.collection_name,
|
||||
embedding_model=self.embedding_model,
|
||||
limit=self.top_k,
|
||||
min_score=self.min_score,
|
||||
)
|
||||
|
||||
# 2. Build context string with citations
|
||||
context_parts: List[str] = []
|
||||
for chunk in chunks:
|
||||
chunk_text = chunk.text.strip()
|
||||
citation = f"[{chunk.file_path}#L{chunk.line_start}-L{chunk.line_end}]"
|
||||
context_parts.append(f"{chunk_text}\n{citation}")
|
||||
|
||||
context_str = "\n\n".join(context_parts) if context_parts else ""
|
||||
|
||||
# 3. Call LLM with context + question
|
||||
llm_response = self.llm_client.generate(
|
||||
system_prompt=self.system_prompt,
|
||||
user_prompt=query,
|
||||
context=context_str,
|
||||
)
|
||||
|
||||
answer_text = llm_response.get("answer", "")
|
||||
return answer_text, chunks
|
||||
131
obsidian_rag/searcher.py
Normal file
131
obsidian_rag/searcher.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Searcher module for Obsidian RAG Backend.
|
||||
|
||||
This module handles semantic search operations on the indexed ChromaDB collection.
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
import chromadb
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from indexer import EMBEDDING_MODEL
|
||||
|
||||
|
||||
@dataclass
|
||||
class SearchResult:
|
||||
"""
|
||||
Represents a single search result with metadata and relevance score.
|
||||
"""
|
||||
file_path: str
|
||||
section_title: str
|
||||
line_start: int
|
||||
line_end: int
|
||||
score: float
|
||||
text: str
|
||||
|
||||
|
||||
def search_vault(
|
||||
query: str,
|
||||
chroma_db_path: str,
|
||||
collection_name: str = "obsidian_vault",
|
||||
embedding_model: str = EMBEDDING_MODEL,
|
||||
limit: int = 5,
|
||||
min_score: float = 0.0,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search the indexed vault for semantically similar content.
|
||||
|
||||
Args:
|
||||
query: Search query string
|
||||
chroma_db_path: Path to ChromaDB data directory
|
||||
collection_name: Name of the ChromaDB collection to search
|
||||
embedding_model: Model used for embeddings (must match indexing model)
|
||||
limit: Maximum number of results to return
|
||||
min_score: Minimum similarity score threshold (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects, sorted by relevance (highest score first)
|
||||
|
||||
Raises:
|
||||
ValueError: If the collection does not exist or query is empty
|
||||
"""
|
||||
if not query or not query.strip():
|
||||
raise ValueError("Query cannot be empty")
|
||||
|
||||
# Initialize ChromaDB client
|
||||
chroma_client = chromadb.PersistentClient(
|
||||
path=chroma_db_path,
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
|
||||
# Get collection (will raise if it doesn't exist)
|
||||
try:
|
||||
collection = chroma_client.get_collection(name=collection_name)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Collection '{collection_name}' not found. "
|
||||
f"Please index your vault first using the index command."
|
||||
) from e
|
||||
|
||||
# Initialize embedding model (same as used during indexing)
|
||||
model = SentenceTransformer(embedding_model)
|
||||
|
||||
# Generate query embedding
|
||||
query_embedding = model.encode(query, show_progress_bar=False)
|
||||
|
||||
# Perform search
|
||||
results = collection.query(
|
||||
query_embeddings=[query_embedding.tolist()],
|
||||
n_results=limit,
|
||||
)
|
||||
|
||||
# Parse and format results
|
||||
search_results = _parse_search_results(results, min_score)
|
||||
|
||||
return search_results
|
||||
|
||||
|
||||
def _parse_search_results(
|
||||
raw_results: dict,
|
||||
min_score: float,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Parse ChromaDB query results into SearchResult objects.
|
||||
|
||||
ChromaDB returns distances (lower = more similar). We convert to
|
||||
similarity scores (higher = more similar) using: score = 1 - distance
|
||||
|
||||
Args:
|
||||
raw_results: Raw results dictionary from ChromaDB query
|
||||
min_score: Minimum similarity score to include
|
||||
|
||||
Returns:
|
||||
List of SearchResult objects filtered by min_score
|
||||
"""
|
||||
search_results = []
|
||||
|
||||
# ChromaDB returns results as lists of lists (one list per query)
|
||||
# We only have one query, so we take the first element
|
||||
documents = raw_results.get("documents", [[]])[0]
|
||||
metadatas = raw_results.get("metadatas", [[]])[0]
|
||||
distances = raw_results.get("distances", [[]])[0]
|
||||
|
||||
for doc, metadata, distance in zip(documents, metadatas, distances):
|
||||
# Convert distance to similarity score (cosine distance -> cosine similarity)
|
||||
score = 1.0 - distance
|
||||
|
||||
# Filter by minimum score
|
||||
if score < min_score:
|
||||
continue
|
||||
|
||||
search_results.append(SearchResult(
|
||||
file_path=metadata["file_path"],
|
||||
section_title=metadata["section_title"],
|
||||
line_start=metadata["line_start"],
|
||||
line_end=metadata["line_end"],
|
||||
score=score,
|
||||
text=doc,
|
||||
))
|
||||
|
||||
return search_results
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
537
tests/test_cli.py
Normal file
537
tests/test_cli.py
Normal file
@@ -0,0 +1,537 @@
|
||||
"""
|
||||
Unit tests for the CLI module.
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from typer.testing import CliRunner
|
||||
from obsidian_rag.cli import app, _display_index_results, _display_results_compact
|
||||
from obsidian_rag.indexer import index_vault
|
||||
from obsidian_rag.searcher import SearchResult
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_vault(tmp_path):
|
||||
"""
|
||||
Create a temporary vault with sample markdown files.
|
||||
"""
|
||||
vault_path = tmp_path / "test_vault"
|
||||
vault_path.mkdir()
|
||||
|
||||
# Create sample files
|
||||
file1 = vault_path / "python.md"
|
||||
file1.write_text("""# Python Programming
|
||||
|
||||
Python is a high-level programming language.
|
||||
|
||||
## Features
|
||||
|
||||
Python has dynamic typing and automatic memory management.
|
||||
""")
|
||||
|
||||
file2 = vault_path / "javascript.md"
|
||||
file2.write_text("""# JavaScript
|
||||
|
||||
JavaScript is a scripting language for web development.
|
||||
|
||||
## Usage
|
||||
|
||||
JavaScript runs in web browsers and Node.js environments.
|
||||
""")
|
||||
|
||||
file3 = vault_path / "cooking.md"
|
||||
file3.write_text("""# Cooking Tips
|
||||
|
||||
Learn how to cook delicious meals.
|
||||
|
||||
## Basics
|
||||
|
||||
Start with simple recipes and basic techniques.
|
||||
""")
|
||||
|
||||
return vault_path
|
||||
|
||||
|
||||
# Tests for 'index' command - Passing tests
|
||||
|
||||
|
||||
def test_i_can_index_vault_successfully(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that we can index a vault successfully.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Found 3 markdown files to index" in result.stdout
|
||||
assert "Indexing completed" in result.stdout
|
||||
assert "Files processed:" in result.stdout
|
||||
assert "Chunks created:" in result.stdout
|
||||
|
||||
|
||||
def test_i_can_index_with_custom_chroma_path(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that we can specify a custom ChromaDB path.
|
||||
"""
|
||||
custom_chroma = tmp_path / "my_custom_db"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(custom_chroma),
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert custom_chroma.exists()
|
||||
assert (custom_chroma / "chroma.sqlite3").exists()
|
||||
|
||||
|
||||
def test_i_can_index_with_custom_collection_name(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that we can use a custom collection name.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
collection_name = "my_custom_collection"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--collection", collection_name,
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"Collection: {collection_name}" in result.stdout
|
||||
|
||||
|
||||
def test_i_can_see_errors_in_index_results(tmp_path):
|
||||
"""
|
||||
Test that errors during indexing are displayed.
|
||||
"""
|
||||
vault_path = tmp_path / "vault_with_errors"
|
||||
vault_path.mkdir()
|
||||
|
||||
# Create a valid file
|
||||
valid_file = vault_path / "valid.md"
|
||||
valid_file.write_text("# Valid File\n\nThis is valid content.")
|
||||
|
||||
# Create an invalid file (will cause parsing error)
|
||||
invalid_file = vault_path / "invalid.md"
|
||||
invalid_file.write_bytes(b"\xff\xfe\x00\x00") # Invalid UTF-8
|
||||
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(vault_path),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Should still complete (exit code 0) but show errors
|
||||
assert result.exit_code == 0
|
||||
assert "Indexing completed" in result.stdout
|
||||
# Note: Error display might vary, just check it completed
|
||||
|
||||
|
||||
# Tests for 'index' command - Error tests
|
||||
|
||||
|
||||
def test_i_cannot_index_nonexistent_vault(tmp_path):
|
||||
"""
|
||||
Test that indexing a nonexistent vault fails with clear error.
|
||||
"""
|
||||
nonexistent_path = tmp_path / "does_not_exist"
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(nonexistent_path),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "does not exist" in result.stdout
|
||||
|
||||
|
||||
def test_i_cannot_index_file_instead_of_directory(tmp_path):
|
||||
"""
|
||||
Test that indexing a file (not directory) fails.
|
||||
"""
|
||||
file_path = tmp_path / "somefile.txt"
|
||||
file_path.write_text("I am a file")
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(file_path),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "not a directory" in result.stdout
|
||||
|
||||
|
||||
def test_i_can_handle_empty_vault_gracefully(tmp_path):
|
||||
"""
|
||||
Test that an empty vault (no .md files) is handled gracefully.
|
||||
"""
|
||||
empty_vault = tmp_path / "empty_vault"
|
||||
empty_vault.mkdir()
|
||||
|
||||
# Create a non-markdown file
|
||||
(empty_vault / "readme.txt").write_text("Not a markdown file")
|
||||
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"index",
|
||||
str(empty_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No markdown files found" in result.stdout
|
||||
|
||||
|
||||
# Tests for 'search' command - Passing tests
|
||||
|
||||
|
||||
def test_i_can_search_indexed_vault(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that we can search an indexed vault.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# First, index the vault
|
||||
index_result = runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
assert index_result.exit_code == 0
|
||||
|
||||
# Then search
|
||||
search_result = runner.invoke(app, [
|
||||
"search",
|
||||
"Python programming",
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
assert search_result.exit_code == 0
|
||||
assert "Found" in search_result.stdout
|
||||
assert "result(s) for:" in search_result.stdout
|
||||
assert "python.md" in search_result.stdout
|
||||
|
||||
|
||||
def test_i_can_search_with_limit_option(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that the --limit option works.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# Index
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Search with limit
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"programming",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--limit", "2",
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Count result numbers (1., 2., etc.)
|
||||
result_count = result.stdout.count("[bold cyan]")
|
||||
assert result_count <= 2
|
||||
|
||||
|
||||
def test_i_can_search_with_min_score_option(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that the --min-score option works.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# Index
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Search with high min-score
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"Python",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--min-score", "0.5",
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Should have results (Python file should match well)
|
||||
assert "Found" in result.stdout
|
||||
|
||||
|
||||
def test_i_can_search_with_custom_collection(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that we can search in a custom collection.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
collection_name = "test_collection"
|
||||
|
||||
# Index with custom collection
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--collection", collection_name,
|
||||
])
|
||||
|
||||
# Search in same collection
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"Python",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--collection", collection_name,
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Found" in result.stdout
|
||||
|
||||
|
||||
def test_i_can_handle_no_results_gracefully(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that no results scenario is handled gracefully.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# Index
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Search for something unlikely with high threshold
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"quantum physics relativity",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--min-score", "0.95",
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No results found" in result.stdout
|
||||
|
||||
|
||||
def test_i_can_use_compact_format(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that compact format displays correctly.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# Index
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Search with explicit compact format
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"Python",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--format", "compact",
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Check for compact format elements
|
||||
assert "Section:" in result.stdout
|
||||
assert "Lines:" in result.stdout
|
||||
assert "score:" in result.stdout
|
||||
|
||||
|
||||
# Tests for 'search' command - Error tests
|
||||
|
||||
|
||||
def test_i_cannot_search_without_index(tmp_path):
|
||||
"""
|
||||
Test that searching without indexing fails with clear message.
|
||||
"""
|
||||
chroma_path = tmp_path / "nonexistent_chroma"
|
||||
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"test query",
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "not found" in result.stdout
|
||||
assert "index" in result.stdout.lower()
|
||||
|
||||
|
||||
def test_i_cannot_search_nonexistent_collection(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that searching in a nonexistent collection fails.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# Index with default collection
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Search in different collection
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"Python",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--collection", "nonexistent_collection",
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "not found" in result.stdout
|
||||
|
||||
|
||||
def test_i_cannot_use_invalid_format(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that an invalid format is rejected.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
# Index
|
||||
runner.invoke(app, [
|
||||
"index",
|
||||
str(temp_vault),
|
||||
"--chroma-path", str(chroma_path),
|
||||
])
|
||||
|
||||
# Search with invalid format
|
||||
result = runner.invoke(app, [
|
||||
"search",
|
||||
"Python",
|
||||
"--chroma-path", str(chroma_path),
|
||||
"--format", "invalid_format",
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Invalid format" in result.stdout
|
||||
assert "compact" in result.stdout
|
||||
|
||||
|
||||
# Tests for helper functions
|
||||
|
||||
|
||||
def test_i_can_display_index_results(capsys):
|
||||
"""
|
||||
Test that index results are displayed correctly.
|
||||
"""
|
||||
stats = {
|
||||
"files_processed": 10,
|
||||
"chunks_created": 50,
|
||||
"collection_name": "test_collection",
|
||||
"errors": [],
|
||||
}
|
||||
|
||||
_display_index_results(stats)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Indexing completed" in captured.out
|
||||
assert "10" in captured.out
|
||||
assert "50" in captured.out
|
||||
assert "test_collection" in captured.out
|
||||
|
||||
|
||||
def test_i_can_display_index_results_with_errors(capsys):
|
||||
"""
|
||||
Test that index results with errors are displayed correctly.
|
||||
"""
|
||||
stats = {
|
||||
"files_processed": 8,
|
||||
"chunks_created": 40,
|
||||
"collection_name": "test_collection",
|
||||
"errors": [
|
||||
{"file": "broken.md", "error": "Invalid encoding"},
|
||||
{"file": "corrupt.md", "error": "Parse error"},
|
||||
],
|
||||
}
|
||||
|
||||
_display_index_results(stats)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "Indexing completed" in captured.out
|
||||
assert "2 file(s) skipped" in captured.out
|
||||
assert "broken.md" in captured.out
|
||||
assert "Invalid encoding" in captured.out
|
||||
|
||||
|
||||
def test_i_can_display_results_compact(capsys):
|
||||
"""
|
||||
Test that compact results display correctly.
|
||||
"""
|
||||
results = [
|
||||
SearchResult(
|
||||
file_path="notes/python.md",
|
||||
section_title="Introduction",
|
||||
line_start=1,
|
||||
line_end=5,
|
||||
score=0.87,
|
||||
text="Python is a high-level programming language.",
|
||||
),
|
||||
SearchResult(
|
||||
file_path="notes/javascript.md",
|
||||
section_title="Overview",
|
||||
line_start=10,
|
||||
line_end=15,
|
||||
score=0.65,
|
||||
text="JavaScript is used for web development.",
|
||||
),
|
||||
]
|
||||
|
||||
_display_results_compact(results)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "python.md" in captured.out
|
||||
assert "javascript.md" in captured.out
|
||||
assert "0.87" in captured.out
|
||||
assert "0.65" in captured.out
|
||||
assert "Introduction" in captured.out
|
||||
assert "Overview" in captured.out
|
||||
|
||||
|
||||
def test_i_can_display_results_compact_with_long_text(capsys):
|
||||
"""
|
||||
Test that long text is truncated in compact display.
|
||||
"""
|
||||
long_text = "A" * 300 # Text longer than 200 characters
|
||||
|
||||
results = [
|
||||
SearchResult(
|
||||
file_path="notes/long.md",
|
||||
section_title="Long Section",
|
||||
line_start=1,
|
||||
line_end=10,
|
||||
score=0.75,
|
||||
text=long_text,
|
||||
),
|
||||
]
|
||||
|
||||
_display_results_compact(results)
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "..." in captured.out # Should be truncated
|
||||
assert len([line for line in captured.out.split('\n') if 'A' * 200 in line]) == 0 # Full text not shown
|
||||
381
tests/test_indexer.py
Normal file
381
tests/test_indexer.py
Normal file
@@ -0,0 +1,381 @@
|
||||
"""
|
||||
Unit tests for the indexer module.
|
||||
"""
|
||||
|
||||
import chromadb
|
||||
import pytest
|
||||
from chromadb.config import Settings
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
from indexer import (
|
||||
index_vault,
|
||||
_chunk_section,
|
||||
_create_chunks_from_document,
|
||||
_get_or_create_collection, EMBEDDING_MODEL,
|
||||
)
|
||||
from obsidian_rag.markdown_parser import ParsedDocument, MarkdownSection
|
||||
|
||||
|
||||
# Fixtures
|
||||
|
||||
@pytest.fixture
|
||||
def tokenizer():
|
||||
"""Provide sentence-transformers tokenizer."""
|
||||
model = SentenceTransformer(EMBEDDING_MODEL)
|
||||
return model.tokenizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_model():
|
||||
"""Provide sentence-transformers model."""
|
||||
return SentenceTransformer(EMBEDDING_MODEL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_client(tmp_path):
|
||||
"""Provide ChromaDB client with temporary storage."""
|
||||
client = chromadb.PersistentClient(
|
||||
path=str(tmp_path / "chroma_test"),
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_vault(tmp_path):
|
||||
"""Create a temporary vault with test markdown files."""
|
||||
vault_path = tmp_path / "test_vault"
|
||||
vault_path.mkdir()
|
||||
return vault_path
|
||||
|
||||
|
||||
# Tests for _chunk_section()
|
||||
|
||||
def test_i_can_chunk_short_section_into_single_chunk(tokenizer):
|
||||
"""Test that a short section is not split."""
|
||||
# Create text with ~100 tokens
|
||||
short_text = " ".join(["word"] * 100)
|
||||
|
||||
chunks = _chunk_section(
|
||||
section_text=short_text,
|
||||
tokenizer=tokenizer,
|
||||
max_chunk_tokens=200,
|
||||
overlap_tokens=30,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0] == short_text
|
||||
|
||||
|
||||
def test_i_can_chunk_long_section_with_overlap(tokenizer):
|
||||
"""Test splitting long section with overlap."""
|
||||
# Create text with ~500 tokens
|
||||
long_text = " ".join([f"word{i}" for i in range(500)])
|
||||
|
||||
chunks = _chunk_section(
|
||||
section_text=long_text,
|
||||
tokenizer=tokenizer,
|
||||
max_chunk_tokens=200,
|
||||
overlap_tokens=30,
|
||||
)
|
||||
|
||||
# Should create multiple chunks
|
||||
assert len(chunks) >= 2
|
||||
|
||||
# Verify no chunk exceeds max tokens
|
||||
for chunk in chunks:
|
||||
tokens = tokenizer.encode(chunk, add_special_tokens=False)
|
||||
assert len(tokens) <= 200
|
||||
|
||||
# Verify overlap exists between consecutive chunks
|
||||
for i in range(len(chunks) - 1):
|
||||
# Check that some words from end of chunk[i] appear in start of chunk[i+1]
|
||||
words_chunk1 = chunks[i].split()[-10:] # Last 10 words
|
||||
words_chunk2 = chunks[i + 1].split()[:10] # First 10 words
|
||||
|
||||
# At least some overlap should exist
|
||||
overlap_found = any(word in words_chunk2 for word in words_chunk1)
|
||||
assert overlap_found
|
||||
|
||||
|
||||
def test_i_can_chunk_empty_section(tokenizer):
|
||||
"""Test chunking an empty section."""
|
||||
empty_text = ""
|
||||
|
||||
chunks = _chunk_section(
|
||||
section_text=empty_text,
|
||||
tokenizer=tokenizer,
|
||||
max_chunk_tokens=200,
|
||||
overlap_tokens=30,
|
||||
)
|
||||
|
||||
assert len(chunks) == 0
|
||||
|
||||
|
||||
# Tests for _create_chunks_from_document()
|
||||
|
||||
def test_i_can_create_chunks_from_document_with_short_sections(tmp_path, tokenizer):
|
||||
"""Test creating chunks from document with only short sections."""
|
||||
vault_path = tmp_path / "vault"
|
||||
vault_path.mkdir()
|
||||
|
||||
parsed_doc = ParsedDocument(
|
||||
file_path=vault_path / "test.md",
|
||||
title="test.md",
|
||||
sections=[
|
||||
MarkdownSection(1, "Section 1", "This is a short section with few words.", [], 1, 2),
|
||||
MarkdownSection(2, "Section 2", "Another short section here.", ["Section 1"], 3, 4),
|
||||
MarkdownSection(3, "Section 3", "Third short section.", ["Section 1", "Section 3"], 5, 6),
|
||||
],
|
||||
raw_content="" # not used in this test
|
||||
)
|
||||
|
||||
chunks = _create_chunks_from_document(
|
||||
parsed_doc=parsed_doc,
|
||||
tokenizer=tokenizer,
|
||||
max_chunk_tokens=200,
|
||||
overlap_tokens=30,
|
||||
vault_path=vault_path,
|
||||
)
|
||||
|
||||
# Should create 3 chunks (one per section)
|
||||
assert len(chunks) == 3
|
||||
|
||||
# Verify metadata
|
||||
for i, chunk in enumerate(chunks):
|
||||
metadata = chunk.metadata
|
||||
assert metadata.file_path == "test.md"
|
||||
assert metadata.section_title == f"Section {i + 1}"
|
||||
assert isinstance(metadata.line_start, int)
|
||||
assert isinstance(metadata.line_end, int)
|
||||
|
||||
# Verify ID format
|
||||
assert "test.md" in chunk.id
|
||||
assert f"Section {i + 1}" in chunk.id
|
||||
|
||||
|
||||
def test_i_can_create_chunks_from_document_with_long_section(tmp_path, tokenizer):
|
||||
"""Test creating chunks from document with a long section that needs splitting."""
|
||||
vault_path = tmp_path / "vault"
|
||||
vault_path.mkdir()
|
||||
|
||||
# Create long content (~500 tokens)
|
||||
long_content = " ".join([f"word{i}" for i in range(500)])
|
||||
|
||||
parsed_doc = ParsedDocument(
|
||||
file_path=vault_path / "test.md",
|
||||
title="test.md",
|
||||
sections=[
|
||||
MarkdownSection(1, "Long Section", long_content, [], 1, 1)
|
||||
],
|
||||
raw_content=long_content,
|
||||
)
|
||||
|
||||
chunks = _create_chunks_from_document(
|
||||
parsed_doc=parsed_doc,
|
||||
tokenizer=tokenizer,
|
||||
max_chunk_tokens=200,
|
||||
overlap_tokens=30,
|
||||
vault_path=vault_path,
|
||||
)
|
||||
|
||||
# Should create multiple chunks
|
||||
assert len(chunks) >= 2
|
||||
|
||||
# All chunks should have same section_title
|
||||
for chunk in chunks:
|
||||
assert chunk.metadata.section_title == "Long Section"
|
||||
assert chunk.metadata.line_start == 1
|
||||
assert chunk.metadata.line_end == 1
|
||||
|
||||
# IDs should include chunk numbers
|
||||
assert "::chunk0" in chunks[0].id
|
||||
assert "::chunk1" in chunks[1].id
|
||||
|
||||
|
||||
def test_i_can_create_chunks_with_correct_relative_paths(tmp_path, tokenizer):
|
||||
"""Test that relative paths are correctly computed."""
|
||||
vault_path = tmp_path / "vault"
|
||||
vault_path.mkdir()
|
||||
|
||||
# Create subdirectory
|
||||
subdir = vault_path / "subfolder"
|
||||
subdir.mkdir()
|
||||
|
||||
parsed_doc = ParsedDocument(
|
||||
file_path=subdir / "nested.md",
|
||||
title=f"{subdir} nested.md",
|
||||
sections=[
|
||||
MarkdownSection(1, "Section", "Some content here.", [], 1, 2),
|
||||
],
|
||||
raw_content="",
|
||||
)
|
||||
|
||||
chunks = _create_chunks_from_document(
|
||||
parsed_doc=parsed_doc,
|
||||
tokenizer=tokenizer,
|
||||
max_chunk_tokens=200,
|
||||
overlap_tokens=30,
|
||||
vault_path=vault_path,
|
||||
)
|
||||
|
||||
assert len(chunks) == 1
|
||||
assert chunks[0].metadata.file_path == "subfolder/nested.md"
|
||||
|
||||
|
||||
# Tests for _get_or_create_collection()
|
||||
|
||||
def test_i_can_create_new_collection(chroma_client):
|
||||
"""Test creating a new collection that doesn't exist."""
|
||||
collection_name = "test_collection"
|
||||
|
||||
collection = _get_or_create_collection(chroma_client, collection_name)
|
||||
|
||||
assert collection.name == collection_name
|
||||
assert collection.count() == 0 # Should be empty
|
||||
|
||||
|
||||
def test_i_can_reset_existing_collection(chroma_client):
|
||||
"""Test that an existing collection is deleted and recreated."""
|
||||
collection_name = "test_collection"
|
||||
|
||||
# Create collection and add data
|
||||
first_collection = chroma_client.create_collection(collection_name)
|
||||
first_collection.add(
|
||||
documents=["test document"],
|
||||
ids=["test_id"],
|
||||
)
|
||||
assert first_collection.count() == 1
|
||||
|
||||
# Reset collection
|
||||
new_collection = _get_or_create_collection(chroma_client, collection_name)
|
||||
|
||||
assert new_collection.name == collection_name
|
||||
assert new_collection.count() == 0 # Should be empty after reset
|
||||
|
||||
|
||||
# Tests for index_vault()
|
||||
|
||||
def test_i_can_index_single_markdown_file(test_vault, tmp_path, embedding_model):
|
||||
"""Test indexing a single markdown file."""
|
||||
# Create test markdown file
|
||||
test_file = test_vault / "test.md"
|
||||
test_file.write_text(
|
||||
"# Title\n\nThis is a test document with some content.\n\n## Section\n\nMore content here."
|
||||
)
|
||||
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
stats = index_vault(
|
||||
vault_path=str(test_vault),
|
||||
chroma_db_path=str(chroma_path),
|
||||
collection_name="test_collection",
|
||||
)
|
||||
|
||||
assert stats["files_processed"] == 1
|
||||
assert stats["chunks_created"] > 0
|
||||
assert stats["errors"] == []
|
||||
assert stats["collection_name"] == "test_collection"
|
||||
|
||||
# Verify collection contains data
|
||||
client = chromadb.PersistentClient(
|
||||
path=str(chroma_path),
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
collection = client.get_collection("test_collection")
|
||||
assert collection.count() == stats["chunks_created"]
|
||||
|
||||
|
||||
def test_i_can_index_multiple_markdown_files(test_vault, tmp_path):
|
||||
"""Test indexing multiple markdown files."""
|
||||
# Create multiple test files
|
||||
for i in range(3):
|
||||
test_file = test_vault / f"test{i}.md"
|
||||
test_file.write_text(f"# Document {i}\n\nContent for document {i}.")
|
||||
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
stats = index_vault(
|
||||
vault_path=str(test_vault),
|
||||
chroma_db_path=str(chroma_path),
|
||||
)
|
||||
|
||||
assert stats["files_processed"] == 3
|
||||
assert stats["chunks_created"] >= 3 # At least one chunk per file
|
||||
assert stats["errors"] == []
|
||||
|
||||
|
||||
def test_i_can_continue_indexing_after_file_error(test_vault, tmp_path, monkeypatch):
|
||||
"""Test that indexing continues after encountering an error."""
|
||||
# Create valid files
|
||||
(test_vault / "valid1.md").write_text("# Valid 1\n\nContent here.")
|
||||
(test_vault / "valid2.md").write_text("# Valid 2\n\nMore content.")
|
||||
(test_vault / "problematic.md").write_text("# Problem\n\nThis will fail.")
|
||||
|
||||
# Mock parse_markdown_file to fail for problematic.md
|
||||
from obsidian_rag import markdown_parser
|
||||
original_parse = markdown_parser.parse_markdown_file
|
||||
|
||||
def mock_parse(file_path):
|
||||
if "problematic.md" in str(file_path):
|
||||
raise ValueError("Simulated parsing error")
|
||||
return original_parse(file_path)
|
||||
|
||||
monkeypatch.setattr("indexer.parse_markdown_file", mock_parse)
|
||||
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
stats = index_vault(
|
||||
vault_path=str(test_vault),
|
||||
chroma_db_path=str(chroma_path),
|
||||
)
|
||||
|
||||
# Should process 2 valid files
|
||||
assert stats["files_processed"] == 2
|
||||
assert len(stats["errors"]) == 1
|
||||
assert "problematic.md" in stats["errors"][0]["file"]
|
||||
assert "Simulated parsing error" in stats["errors"][0]["error"]
|
||||
|
||||
|
||||
def test_i_cannot_index_nonexistent_vault(tmp_path):
|
||||
"""Test that indexing a nonexistent vault raises an error."""
|
||||
nonexistent_path = tmp_path / "nonexistent_vault"
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
with pytest.raises(ValueError, match="Vault path does not exist"):
|
||||
index_vault(
|
||||
vault_path=str(nonexistent_path),
|
||||
chroma_db_path=str(chroma_path),
|
||||
)
|
||||
|
||||
|
||||
def test_i_can_verify_embeddings_are_generated(test_vault, tmp_path):
|
||||
"""Test that embeddings are properly generated and stored."""
|
||||
# Create test file
|
||||
test_file = test_vault / "test.md"
|
||||
test_file.write_text("# Test\n\nThis is test content for embedding generation.")
|
||||
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
|
||||
stats = index_vault(
|
||||
vault_path=str(test_vault),
|
||||
chroma_db_path=str(chroma_path),
|
||||
)
|
||||
|
||||
# Verify embeddings in collection
|
||||
client = chromadb.PersistentClient(
|
||||
path=str(chroma_path),
|
||||
settings=Settings(anonymized_telemetry=False)
|
||||
)
|
||||
collection = client.get_collection("obsidian_vault")
|
||||
|
||||
# Get all items
|
||||
results = collection.get(include=["embeddings"])
|
||||
|
||||
assert len(results["ids"]) == stats["chunks_created"]
|
||||
assert results["embeddings"] is not None
|
||||
|
||||
# Verify embeddings are non-zero vectors of correct dimension
|
||||
for embedding in results["embeddings"]:
|
||||
assert len(embedding) == 384 # all-MiniLM-L6-v2 dimension
|
||||
assert any(val != 0 for val in embedding) # Not all zeros
|
||||
238
tests/test_markdown_parser.py
Normal file
238
tests/test_markdown_parser.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Unit tests for markdown_parser module."""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from markdown_parser import (
|
||||
parse_markdown_file,
|
||||
find_section_at_line,
|
||||
MarkdownSection,
|
||||
ParsedDocument
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_markdown_file(tmp_path):
|
||||
"""Fixture to create temporary markdown files for testing.
|
||||
|
||||
Args:
|
||||
tmp_path: pytest temporary directory fixture
|
||||
|
||||
Returns:
|
||||
Function that creates a markdown file with given content
|
||||
"""
|
||||
|
||||
def _create_file(content: str, filename: str = "test.md") -> Path:
|
||||
file_path = tmp_path / filename
|
||||
file_path.write_text(content, encoding="utf-8")
|
||||
return file_path
|
||||
|
||||
return _create_file
|
||||
|
||||
|
||||
# Tests for parse_markdown_file()
|
||||
|
||||
def test_i_can_parse_file_with_single_section(tmp_markdown_file):
|
||||
"""Test parsing a file with a single header section."""
|
||||
content = """# Main Title
|
||||
This is the content of the section.
|
||||
It has multiple lines."""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
assert len(doc.sections) == 1
|
||||
assert doc.sections[0].level == 1
|
||||
assert doc.sections[0].title == "Main Title"
|
||||
assert "This is the content" in doc.sections[0].content
|
||||
assert doc.sections[0].start_line == 1
|
||||
assert doc.sections[0].end_line == 3
|
||||
|
||||
|
||||
def test_i_can_parse_file_with_multiple_sections(tmp_markdown_file):
|
||||
"""Test parsing a file with multiple sections at the same level."""
|
||||
content = """# Section One
|
||||
Content of section one.
|
||||
|
||||
# Section Two
|
||||
Content of section two.
|
||||
|
||||
# Section Three
|
||||
Content of section three."""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
assert len(doc.sections) == 3
|
||||
assert doc.sections[0].title == "Section One"
|
||||
assert doc.sections[1].title == "Section Two"
|
||||
assert doc.sections[2].title == "Section Three"
|
||||
assert all(section.level == 1 for section in doc.sections)
|
||||
|
||||
|
||||
def test_i_can_parse_file_with_nested_sections(tmp_markdown_file):
|
||||
"""Test parsing a file with nested headers (different levels)."""
|
||||
content = """# Main Title
|
||||
Introduction text.
|
||||
|
||||
## Subsection A
|
||||
Content A.
|
||||
|
||||
## Subsection B
|
||||
Content B.
|
||||
|
||||
### Sub-subsection
|
||||
Nested content."""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
assert len(doc.sections) == 4
|
||||
assert doc.sections[0].level == 1
|
||||
assert doc.sections[0].title == "Main Title"
|
||||
assert doc.sections[1].level == 2
|
||||
assert doc.sections[1].title == "Subsection A"
|
||||
assert doc.sections[2].level == 2
|
||||
assert doc.sections[2].title == "Subsection B"
|
||||
assert doc.sections[3].level == 3
|
||||
assert doc.sections[3].title == "Sub-subsection"
|
||||
|
||||
|
||||
def test_i_can_parse_file_without_headers(tmp_markdown_file):
|
||||
"""Test parsing a file with no headers (plain text)."""
|
||||
content = """This is a plain text file.
|
||||
It has no headers at all.
|
||||
Just regular content."""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
assert len(doc.sections) == 1
|
||||
assert doc.sections[0].level == 0
|
||||
assert doc.sections[0].title == ""
|
||||
assert doc.sections[0].content == content
|
||||
assert doc.sections[0].start_line == 1
|
||||
assert doc.sections[0].end_line == 3
|
||||
|
||||
|
||||
def test_i_can_parse_empty_file(tmp_markdown_file):
|
||||
"""Test parsing an empty file."""
|
||||
content = ""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
assert len(doc.sections) == 1
|
||||
assert doc.sections[0].level == 0
|
||||
assert doc.sections[0].title == ""
|
||||
assert doc.sections[0].content == ""
|
||||
assert doc.sections[0].start_line == 1
|
||||
assert doc.sections[0].end_line == 1
|
||||
|
||||
|
||||
def test_i_can_track_correct_line_numbers(tmp_markdown_file):
|
||||
"""Test that line numbers are correctly tracked for each section."""
|
||||
content = """# First Section
|
||||
Line 2
|
||||
Line 3
|
||||
|
||||
# Second Section
|
||||
Line 6
|
||||
Line 7
|
||||
Line 8"""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
assert doc.sections[0].start_line == 1
|
||||
assert doc.sections[0].end_line == 4
|
||||
assert doc.sections[1].start_line == 5
|
||||
assert doc.sections[1].end_line == 8
|
||||
|
||||
|
||||
def test_i_cannot_parse_nonexistent_file():
|
||||
"""Test that parsing a non-existent file raises FileNotFoundError."""
|
||||
fake_path = Path("/nonexistent/path/to/file.md")
|
||||
|
||||
with pytest.raises(FileNotFoundError):
|
||||
parse_markdown_file(fake_path)
|
||||
|
||||
|
||||
# Tests for find_section_at_line()
|
||||
|
||||
def test_i_can_find_section_at_specific_line(tmp_markdown_file):
|
||||
"""Test finding a section at a line in the middle of content."""
|
||||
content = """# Section One
|
||||
Line 2
|
||||
Line 3
|
||||
|
||||
# Section Two
|
||||
Line 6
|
||||
Line 7"""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
section = find_section_at_line(doc, 3)
|
||||
|
||||
assert section is not None
|
||||
assert section.title == "Section One"
|
||||
|
||||
section = find_section_at_line(doc, 6)
|
||||
|
||||
assert section is not None
|
||||
assert section.title == "Section Two"
|
||||
|
||||
|
||||
def test_i_can_find_section_at_first_line(tmp_markdown_file):
|
||||
"""Test finding a section at the header line itself."""
|
||||
content = """# Main Title
|
||||
Content here."""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
section = find_section_at_line(doc, 1)
|
||||
|
||||
assert section is not None
|
||||
assert section.title == "Main Title"
|
||||
|
||||
|
||||
def test_i_can_find_section_at_last_line(tmp_markdown_file):
|
||||
"""Test finding a section at its last line."""
|
||||
content = """# Section One
|
||||
Line 2
|
||||
Line 3
|
||||
|
||||
# Section Two
|
||||
Line 6"""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
section = find_section_at_line(doc, 3)
|
||||
|
||||
assert section is not None
|
||||
assert section.title == "Section One"
|
||||
|
||||
section = find_section_at_line(doc, 6)
|
||||
|
||||
assert section is not None
|
||||
assert section.title == "Section Two"
|
||||
|
||||
|
||||
def test_i_cannot_find_section_for_invalid_line_number(tmp_markdown_file):
|
||||
"""Test that invalid line numbers return None."""
|
||||
content = """# Title
|
||||
Content"""
|
||||
|
||||
file_path = tmp_markdown_file(content)
|
||||
doc = parse_markdown_file(file_path)
|
||||
|
||||
# Negative line number
|
||||
assert find_section_at_line(doc, -1) is None
|
||||
|
||||
# Zero line number
|
||||
assert find_section_at_line(doc, 0) is None
|
||||
|
||||
# Line number beyond file length
|
||||
assert find_section_at_line(doc, 1000) is None
|
||||
337
tests/test_searcher.py
Normal file
337
tests/test_searcher.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""
|
||||
Unit tests for the searcher module.
|
||||
"""
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from indexer import index_vault
|
||||
from searcher import search_vault, _parse_search_results, SearchResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_vault(tmp_path):
|
||||
"""
|
||||
Create a temporary vault with sample markdown files.
|
||||
"""
|
||||
vault_path = tmp_path / "test_vault"
|
||||
vault_path.mkdir()
|
||||
|
||||
# Create sample files
|
||||
file1 = vault_path / "python_basics.md"
|
||||
file1.write_text("""# Python Programming
|
||||
|
||||
Python is a high-level programming language known for its simplicity and readability.
|
||||
|
||||
## Variables and Data Types
|
||||
|
||||
In Python, you can create variables without declaring their type explicitly.
|
||||
Numbers, strings, and booleans are the basic data types.
|
||||
|
||||
## Functions
|
||||
|
||||
Functions in Python are defined using the def keyword.
|
||||
They help organize code into reusable blocks.
|
||||
""")
|
||||
|
||||
file2 = vault_path / "machine_learning.md"
|
||||
file2.write_text("""# Machine Learning
|
||||
|
||||
Machine learning is a subset of artificial intelligence.
|
||||
|
||||
## Supervised Learning
|
||||
|
||||
Supervised learning uses labeled data to train models.
|
||||
Common algorithms include linear regression and decision trees.
|
||||
|
||||
## Deep Learning
|
||||
|
||||
Deep learning uses neural networks with multiple layers.
|
||||
It's particularly effective for image and speech recognition.
|
||||
""")
|
||||
|
||||
file3 = vault_path / "cooking.md"
|
||||
file3.write_text("""# Italian Cuisine
|
||||
|
||||
Italian cooking emphasizes fresh ingredients and simple preparation.
|
||||
|
||||
## Pasta Dishes
|
||||
|
||||
Pasta is a staple of Italian cuisine.
|
||||
There are hundreds of pasta shapes and sauce combinations.
|
||||
|
||||
## Pizza Making
|
||||
|
||||
Traditional Italian pizza uses a thin crust and fresh mozzarella.
|
||||
""")
|
||||
|
||||
return vault_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def indexed_vault(temp_vault, tmp_path):
|
||||
"""
|
||||
Create and index a temporary vault.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_db"
|
||||
chroma_path.mkdir()
|
||||
|
||||
# Index the vault
|
||||
stats = index_vault(
|
||||
vault_path=str(temp_vault),
|
||||
chroma_db_path=str(chroma_path),
|
||||
collection_name="test_collection",
|
||||
)
|
||||
|
||||
return {
|
||||
"vault_path": temp_vault,
|
||||
"chroma_path": chroma_path,
|
||||
"collection_name": "test_collection",
|
||||
"stats": stats,
|
||||
}
|
||||
|
||||
|
||||
# Passing tests
|
||||
|
||||
|
||||
def test_i_can_search_vault_with_valid_query(indexed_vault):
|
||||
"""
|
||||
Test that a basic search returns valid results.
|
||||
"""
|
||||
results = search_vault(
|
||||
query="Python programming language",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
)
|
||||
|
||||
# Should return results
|
||||
assert len(results) > 0
|
||||
|
||||
# All results should be SearchResult instances
|
||||
for result in results:
|
||||
assert isinstance(result, SearchResult)
|
||||
|
||||
# Check that all fields are present
|
||||
assert isinstance(result.file_path, str)
|
||||
assert isinstance(result.section_title, str)
|
||||
assert isinstance(result.line_start, int)
|
||||
assert isinstance(result.line_end, int)
|
||||
assert isinstance(result.score, float)
|
||||
assert isinstance(result.text, str)
|
||||
|
||||
# Scores should be between 0 and 1
|
||||
assert 0.0 <= result.score <= 1.0
|
||||
|
||||
# Results should be sorted by score (descending)
|
||||
scores = [r.score for r in results]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
|
||||
def test_i_can_search_vault_with_limit_parameter(indexed_vault):
|
||||
"""
|
||||
Test that the limit parameter is respected.
|
||||
"""
|
||||
limit = 3
|
||||
results = search_vault(
|
||||
query="learning",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# Should return at most 'limit' results
|
||||
assert len(results) <= limit
|
||||
|
||||
|
||||
def test_i_can_search_vault_with_min_score_filter(indexed_vault):
|
||||
"""
|
||||
Test that only results above min_score are returned.
|
||||
"""
|
||||
min_score = 0.5
|
||||
results = search_vault(
|
||||
query="Python",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
min_score=min_score,
|
||||
)
|
||||
|
||||
# All results should have score >= min_score
|
||||
for result in results:
|
||||
assert result.score >= min_score
|
||||
|
||||
|
||||
def test_i_can_get_correct_metadata_in_results(indexed_vault):
|
||||
"""
|
||||
Test that metadata in results is correct.
|
||||
"""
|
||||
results = search_vault(
|
||||
query="Python programming",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
limit=1,
|
||||
)
|
||||
|
||||
assert len(results) > 0
|
||||
top_result = results[0]
|
||||
|
||||
# Should find python_basics.md as most relevant
|
||||
assert "python_basics.md" in top_result.file_path
|
||||
|
||||
# Should have a section title
|
||||
assert len(top_result.section_title) > 0
|
||||
|
||||
# Line numbers should be positive
|
||||
assert top_result.line_start > 0
|
||||
assert top_result.line_end >= top_result.line_start
|
||||
|
||||
# Text should not be empty
|
||||
assert len(top_result.text) > 0
|
||||
|
||||
|
||||
def test_i_can_search_with_different_collection_name(temp_vault, tmp_path):
|
||||
"""
|
||||
Test that we can search in a collection with a custom name.
|
||||
"""
|
||||
chroma_path = tmp_path / "chroma_custom"
|
||||
chroma_path.mkdir()
|
||||
custom_collection = "my_custom_collection"
|
||||
|
||||
# Index with custom collection name
|
||||
index_vault(
|
||||
vault_path=str(temp_vault),
|
||||
chroma_db_path=str(chroma_path),
|
||||
collection_name=custom_collection,
|
||||
)
|
||||
|
||||
# Search with the same custom collection name
|
||||
results = search_vault(
|
||||
query="Python",
|
||||
chroma_db_path=str(chroma_path),
|
||||
collection_name=custom_collection,
|
||||
)
|
||||
|
||||
assert len(results) > 0
|
||||
|
||||
|
||||
def test_i_can_get_empty_results_when_no_match(indexed_vault):
|
||||
"""
|
||||
Test that a search with no matches returns an empty list.
|
||||
"""
|
||||
results = search_vault(
|
||||
query="quantum physics relativity theory",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
min_score=0.9, # Very high threshold
|
||||
)
|
||||
|
||||
# Should return empty list, not raise exception
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
# Error tests
|
||||
|
||||
|
||||
def test_i_cannot_search_with_empty_query(indexed_vault):
|
||||
"""
|
||||
Test that an empty query raises ValueError.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="Query cannot be empty"):
|
||||
search_vault(
|
||||
query="",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
)
|
||||
|
||||
|
||||
def test_i_cannot_search_nonexistent_collection(tmp_path):
|
||||
"""
|
||||
Test that searching a nonexistent collection raises ValueError.
|
||||
"""
|
||||
chroma_path = tmp_path / "empty_chroma"
|
||||
chroma_path.mkdir()
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
search_vault(
|
||||
query="test query",
|
||||
chroma_db_path=str(chroma_path),
|
||||
collection_name="nonexistent_collection",
|
||||
)
|
||||
|
||||
|
||||
def test_i_cannot_search_with_whitespace_only_query(indexed_vault):
|
||||
"""
|
||||
Test that a query with only whitespace raises ValueError.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="Query cannot be empty"):
|
||||
search_vault(
|
||||
query=" ",
|
||||
chroma_db_path=str(indexed_vault["chroma_path"]),
|
||||
collection_name=indexed_vault["collection_name"],
|
||||
)
|
||||
|
||||
|
||||
# Helper function tests
|
||||
|
||||
|
||||
def test_i_can_parse_search_results_correctly():
|
||||
"""
|
||||
Test that ChromaDB results are parsed correctly.
|
||||
"""
|
||||
# Mock ChromaDB query results
|
||||
raw_results = {
|
||||
"documents": [[
|
||||
"Python is a programming language",
|
||||
"Machine learning basics",
|
||||
]],
|
||||
"metadatas": [[
|
||||
{
|
||||
"file_path": "notes/python.md",
|
||||
"section_title": "Introduction",
|
||||
"line_start": 1,
|
||||
"line_end": 5,
|
||||
},
|
||||
{
|
||||
"file_path": "notes/ml.md",
|
||||
"section_title": "Overview",
|
||||
"line_start": 10,
|
||||
"line_end": 15,
|
||||
},
|
||||
]],
|
||||
"distances": [[0.2, 0.4]], # ChromaDB distances (lower = more similar)
|
||||
}
|
||||
|
||||
results = _parse_search_results(raw_results, min_score=0.0)
|
||||
|
||||
assert len(results) == 2
|
||||
|
||||
# Check first result
|
||||
assert results[0].file_path == "notes/python.md"
|
||||
assert results[0].section_title == "Introduction"
|
||||
assert results[0].line_start == 1
|
||||
assert results[0].line_end == 5
|
||||
assert results[0].text == "Python is a programming language"
|
||||
assert results[0].score == pytest.approx(0.8) # 1 - 0.2
|
||||
|
||||
# Check second result
|
||||
assert results[1].score == pytest.approx(0.6) # 1 - 0.4
|
||||
|
||||
|
||||
def test_i_can_filter_results_by_min_score():
|
||||
"""
|
||||
Test that results are filtered by min_score during parsing.
|
||||
"""
|
||||
raw_results = {
|
||||
"documents": [["text1", "text2", "text3"]],
|
||||
"metadatas": [[
|
||||
{"file_path": "a.md", "section_title": "A", "line_start": 1, "line_end": 2},
|
||||
{"file_path": "b.md", "section_title": "B", "line_start": 1, "line_end": 2},
|
||||
{"file_path": "c.md", "section_title": "C", "line_start": 1, "line_end": 2},
|
||||
]],
|
||||
"distances": [[0.1, 0.5, 0.8]], # Scores will be: 0.9, 0.5, 0.2
|
||||
}
|
||||
|
||||
results = _parse_search_results(raw_results, min_score=0.6)
|
||||
|
||||
# Only first result should pass (score 0.9 >= 0.6)
|
||||
assert len(results) == 1
|
||||
assert results[0].file_path == "a.md"
|
||||
assert results[0].score == pytest.approx(0.9)
|
||||
Reference in New Issue
Block a user