chore: Improved citation system (#834)

* improve: Improved citation system

* fix

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Xun
2026-01-25 15:49:45 +08:00
committed by GitHub
parent 31624b64b8
commit 9a34e32252
8 changed files with 1735 additions and 65 deletions

View File

@@ -28,6 +28,7 @@ class CitationCollector:
self._citations: Dict[str, CitationMetadata] = {} # url -> metadata
self._citation_order: List[str] = [] # ordered list of URLs
self._used_citations: set[str] = set() # URLs that are actually cited
self._url_to_index: Dict[str, int] = {} # url -> index of _citation_order (O(1) lookup)
def add_from_search_results(
self, results: List[Dict[str, Any]], query: str = ""
@@ -58,6 +59,7 @@ class CitationCollector:
if url not in self._citations:
self._citations[url] = metadata
self._citation_order.append(url)
self._url_to_index[url] = len(self._citation_order) - 1
added.append(metadata)
logger.debug(f"Added citation: {metadata.title} ({url})")
else:
@@ -104,6 +106,7 @@ class CitationCollector:
)
self._citations[url] = metadata
self._citation_order.append(url)
self._url_to_index[url] = len(self._citation_order) - 1
return metadata
@@ -124,7 +127,7 @@ class CitationCollector:
def get_number(self, url: str) -> Optional[int]:
"""
Get the citation number for a URL.
Get the citation number for a URL (O(1) time complexity).
Args:
url: The URL to look up
@@ -132,10 +135,8 @@ class CitationCollector:
Returns:
The citation number (1-indexed) or None if not found
"""
try:
return self._citation_order.index(url) + 1
except ValueError:
return None
index = self._url_to_index.get(url)
return index + 1 if index is not None else None
def get_metadata(self, url: str) -> Optional[CitationMetadata]:
"""
@@ -215,7 +216,9 @@ class CitationCollector:
for citation_data in data.get("citations", []):
citation = Citation.from_dict(citation_data)
collector._citations[citation.url] = citation.metadata
index = len(collector._citation_order)
collector._citation_order.append(citation.url)
collector._url_to_index[citation.url] = index
collector._used_citations = set(data.get("used_urls", []))
return collector
@@ -230,6 +233,7 @@ class CitationCollector:
if url not in self._citations:
self._citations[url] = other._citations[url]
self._citation_order.append(url)
self._url_to_index[url] = len(self._citation_order) - 1
self._used_citations.update(other._used_citations)
@property
@@ -247,6 +251,7 @@ class CitationCollector:
self._citations.clear()
self._citation_order.clear()
self._used_citations.clear()
self._url_to_index.clear()
def extract_urls_from_text(text: str) -> List[str]:

View File

@@ -7,6 +7,7 @@ Citation extraction utilities for extracting citations from tool results.
import json
import logging
import re
from typing import Any, Dict, List, Optional
from langchain_core.messages import AIMessage, ToolMessage
@@ -205,6 +206,84 @@ def _result_to_citation(result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
}
def extract_title_from_content(content: Optional[str], max_length: int = 200) -> str:
    """
    Intelligent title extraction supporting multiple formats.

    Priority:
        1. HTML <title> tag
        2. Markdown h1 (# Title)
        3. Markdown h2-h6 (## Title, etc.)
        4. JSON/YAML title field
        5. First substantial non-empty line
        6. "Untitled" as fallback

    Args:
        content: The content to extract title from (can be None)
        max_length: Maximum title length (default: 200)

    Returns:
        Extracted title or "Untitled"
    """
    if not content:
        return "Untitled"

    # 1. Try HTML title tag (DOTALL so titles spanning lines are captured)
    html_title_match = re.search(
        r'<title[^>]*>([^<]+)</title>',
        content,
        re.IGNORECASE | re.DOTALL,
    )
    if html_title_match:
        title = html_title_match.group(1).strip()
        if title:
            return title[:max_length]

    # 2. Try Markdown h1 (exactly one #)
    md_h1_match = re.search(
        r'^#{1}\s+(.+?)$',
        content,
        re.MULTILINE,
    )
    if md_h1_match:
        title = md_h1_match.group(1).strip()
        if title:
            return title[:max_length]

    # 3. Try any Markdown heading (h2-h6)
    md_heading_match = re.search(
        r'^#{2,6}\s+(.+?)$',
        content,
        re.MULTILINE,
    )
    if md_heading_match:
        title = md_heading_match.group(1).strip()
        if title:
            return title[:max_length]

    # 4. Try JSON/YAML title field.
    # \b anchors the key name so fields such as "subtitle" or "page_title"
    # are no longer mistaken for a title field (bug fix: the previous
    # pattern '"?title"?' matched anywhere inside a longer key).
    json_title_match = re.search(
        r'"?\btitle\b"?\s*:\s*["\']?([^"\'\n]+)["\']?',
        content,
        re.IGNORECASE,
    )
    if json_title_match:
        title = json_title_match.group(1).strip()
        # Require a minimally substantial value to avoid noise like "a"
        if title and len(title) > 3:
            return title[:max_length]

    # 5. First substantial non-empty line
    for line in content.split('\n'):
        line = line.strip()
        # Skip short lines, code fences, separators, list items, and headings
        if (line and
                len(line) > 10 and
                not line.startswith(('```', '---', '***', '- ', '* ', '+ ', '#'))):
            return line[:max_length]

    return "Untitled"
def _extract_from_crawl_result(data: Any) -> Optional[Dict[str, Any]]:
"""
Extract citation from crawl tool result.
@@ -224,18 +303,8 @@ def _extract_from_crawl_result(data: Any) -> Optional[Dict[str, Any]]:
content = data.get("crawled_content", "")
# Try to extract title from content (first h1 or first line)
title = "Untitled"
if content:
lines = content.strip().split("\n")
for line in lines:
line = line.strip()
if line.startswith("# "):
title = line[2:].strip()
break
elif line and not line.startswith("#"):
title = line[:100]
break
# Extract title using intelligent extraction function
title = extract_title_from_content(content)
return {
"url": url,
@@ -248,15 +317,48 @@ def _extract_from_crawl_result(data: Any) -> Optional[Dict[str, Any]]:
}
def _extract_domain(url: str) -> str:
"""Extract domain from URL."""
def _extract_domain(url: Optional[str]) -> str:
"""
Extract domain from URL using urllib with regex fallback.
Handles:
- Standard URLs: https://www.example.com/path
- Short URLs: example.com
- Invalid URLs: graceful fallback
Args:
url: The URL string to extract domain from (can be None)
Returns:
The domain netloc (including port if present), or empty string if extraction fails
"""
if not url:
return ""
# Approach 1: Try urllib first (fast path for standard URLs)
try:
from urllib.parse import urlparse
parsed = urlparse(url)
return parsed.netloc
except Exception:
return ""
if parsed.netloc:
return parsed.netloc
except Exception as e:
logger.debug(f"URL parsing failed for {url}: {e}")
# Approach 2: Regex fallback (for non-standard or bare URLs without scheme)
# Matches: domain[:port] where domain is a valid hostname
# Pattern breakdown:
# ([a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*)
# - domain labels separated by dots, each 1-63 chars, starting/ending with alphanumeric
# (?::\d+)? - optional port
pattern = r'^([a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*(?::\d+)?)(?:[/?#]|$)'
match = re.match(pattern, url)
if match:
return match.group(1)
logger.warning(f"Could not extract domain from URL: {url}")
return ""
def merge_citations(

View File

@@ -6,9 +6,9 @@ Citation formatter for generating citation sections and inline references.
"""
import re
from typing import Any, Dict, List, Optional

from .models import Citation
class CitationFormatter:
@@ -239,33 +239,159 @@ class CitationFormatter:
return json.dumps(data, ensure_ascii=False)
def parse_citations_from_report(
    report: str, section_patterns: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Extract citation information from report, supporting multiple formats.

    Supports various citation formats:
    - Markdown: [Title](URL)
    - Numbered: [1] Title - URL
    - Footnote: [^1]: Title - URL
    - HTML: <a href="URL">Title</a>

    Args:
        report: The report markdown text
        section_patterns: Custom section header patterns (optional).
            Annotated as Optional because None is the documented default
            (the previous ``List[str] = None`` annotation was incorrect).

    Returns:
        Dictionary with 'citations' list and 'count' of unique citations
    """
    if section_patterns is None:
        section_patterns = [
            r"(?:##\s*Key Citations|##\s*References|##\s*Sources|##\s*Bibliography)",
        ]

    citations = []

    # 1. Find citation section(s) and extract citations.
    for pattern in section_patterns:
        # Match line-by-line content up to the next "##" heading instead of
        # relying on DOTALL with greedy matching, which degrades on large
        # reports.
        section_matches = re.finditer(
            pattern + r"\s*\n((?:(?!\n##).*\n?)*)",
            report,
            re.IGNORECASE | re.MULTILINE,
        )
        for section_match in section_matches:
            section = section_match.group(1)
            # 2. Extract citations in each supported format.
            citations.extend(_extract_markdown_links(section))
            citations.extend(_extract_numbered_citations(section))
            citations.extend(_extract_footnote_citations(section))
            citations.extend(_extract_html_links(section))

    # 3. Deduplicate by URL, keeping the first occurrence.
    unique_citations = {}
    for citation in citations:
        url = citation.get("url", "")
        if url and url not in unique_citations:
            unique_citations[url] = citation

    return {
        "citations": list(unique_citations.values()),
        "count": len(unique_citations),
    }
def _extract_markdown_links(text: str) -> List[Dict[str, str]]:
"""
Extract Markdown links [title](url).
Args:
text: Text to extract from
Returns:
List of citation dictionaries with title, url, and format
"""
citations = []
# Find the Key Citations section
section_pattern = (
r"(?:##\s*Key Citations|##\s*References|##\s*Sources)\s*\n(.*?)(?=\n##|\Z)"
)
section_match = re.search(section_pattern, report, re.IGNORECASE | re.DOTALL)
if section_match:
section = section_match.group(1)
# Extract markdown links
link_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
for match in re.finditer(link_pattern, section):
title = match.group(1)
url = match.group(2)
if url.startswith(("http://", "https://")):
citations.append((title, url))
pattern = r"\[([^\]]+)\]\(([^)]+)\)"
for match in re.finditer(pattern, text):
title, url = match.groups()
if url.startswith(("http://", "https://")):
citations.append({
"title": title.strip(),
"url": url.strip(),
"format": "markdown",
})
return citations
def _extract_numbered_citations(text: str) -> List[Dict[str, str]]:
"""
Extract numbered citations [1] Title - URL.
Args:
text: Text to extract from
Returns:
List of citation dictionaries
"""
citations = []
# Match: [number] title - URL
pattern = r"\[\d+\]\s+([^-\n]+?)\s*-\s*(https?://[^\s\n]+)"
for match in re.finditer(pattern, text):
title, url = match.groups()
citations.append({
"title": title.strip(),
"url": url.strip(),
"format": "numbered",
})
return citations
def _extract_footnote_citations(text: str) -> List[Dict[str, str]]:
"""
Extract footnote citations [^1]: Title - URL.
Args:
text: Text to extract from
Returns:
List of citation dictionaries
"""
citations = []
# Match: [^number]: title - URL
pattern = r"\[\^(\d+)\]:\s+([^-\n]+?)\s*-\s*(https?://[^\s\n]+)"
for match in re.finditer(pattern, text):
_, title, url = match.groups()
citations.append({
"title": title.strip(),
"url": url.strip(),
"format": "footnote",
})
return citations
def _extract_html_links(text: str) -> List[Dict[str, str]]:
"""
Extract HTML links <a href="url">title</a>.
Args:
text: Text to extract from
Returns:
List of citation dictionaries
"""
citations = []
pattern = r'<a\s+(?:[^>]*?\s)?href=(["\'])([^"\']+)\1[^>]*>([^<]+)</a>'
for match in re.finditer(pattern, text, re.IGNORECASE):
_, url, title = match.groups()
if url.startswith(("http://", "https://")):
citations.append({
"title": title.strip(),
"url": url.strip(),
"format": "html",
})
return citations

View File

@@ -6,14 +6,14 @@ Citation data models for structured source metadata.
"""
import hashlib
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
from pydantic import BaseModel, ConfigDict, Field
@dataclass
class CitationMetadata:
class CitationMetadata(BaseModel):
"""Metadata extracted from a source."""
# Core identifiers
@@ -32,7 +32,7 @@ class CitationMetadata:
language: Optional[str] = None
# Media
images: List[str] = field(default_factory=list)
images: List[str] = Field(default_factory=list)
favicon: Optional[str] = None
# Quality indicators
@@ -40,13 +40,16 @@ class CitationMetadata:
credibility_score: float = 0.0
# Timestamps
accessed_at: str = field(default_factory=lambda: datetime.now().isoformat())
accessed_at: str = Field(default_factory=lambda: datetime.now().isoformat())
# Additional metadata
extra: Dict[str, Any] = field(default_factory=dict)
extra: Dict[str, Any] = Field(default_factory=dict)
def __post_init__(self):
"""Extract domain from URL if not provided."""
model_config = ConfigDict(arbitrary_types_allowed=True)
def __init__(self, **data):
"""Initialize and extract domain from URL if not provided."""
super().__init__(**data)
if not self.domain and self.url:
try:
parsed = urlparse(self.url)
@@ -87,7 +90,7 @@ class CitationMetadata:
"""Create from dictionary."""
# Remove 'id' as it's computed from url
data = {k: v for k, v in data.items() if k != "id"}
return cls(**data)
return cls.model_validate(data)
@classmethod
def from_search_result(
@@ -107,8 +110,8 @@ class CitationMetadata:
)
@dataclass
class Citation:
class Citation(BaseModel):
"""
A citation reference that can be used in reports.
@@ -127,6 +130,8 @@ class Citation:
# Specific quote or fact being cited
cited_text: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
@property
def id(self) -> str:
"""Get the citation ID from metadata."""
@@ -154,12 +159,14 @@ class Citation:
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Citation":
"""Create from dictionary."""
return cls(
number=data["number"],
metadata=CitationMetadata.from_dict(data["metadata"]),
context=data.get("context"),
cited_text=data.get("cited_text"),
)
return cls.model_validate({
"number": data["number"],
"metadata": CitationMetadata.from_dict(data["metadata"])
if isinstance(data.get("metadata"), dict)
else data["metadata"],
"context": data.get("context"),
"cited_text": data.get("cited_text"),
})
def to_markdown_reference(self) -> str:
"""Generate markdown reference format: [Title](URL)"""

View File

@@ -0,0 +1,289 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for CitationCollector optimization with reverse index cache.
Tests the O(1) URL lookup performance optimization via _url_to_index cache.
"""
from src.citations.collector import CitationCollector
class TestCitationCollectorOptimization:
"""Test CitationCollector reverse index cache optimization."""
def test_url_to_index_cache_initialization(self):
"""Test that _url_to_index is properly initialized."""
collector = CitationCollector()
assert hasattr(collector, "_url_to_index")
assert isinstance(collector._url_to_index, dict)
assert len(collector._url_to_index) == 0
def test_add_single_citation_updates_cache(self):
"""Test that adding a citation updates _url_to_index."""
collector = CitationCollector()
results = [
{
"url": "https://example.com",
"title": "Example",
"content": "Content",
"score": 0.9,
}
]
collector.add_from_search_results(results)
# Check cache is populated
assert "https://example.com" in collector._url_to_index
assert collector._url_to_index["https://example.com"] == 0
def test_add_multiple_citations_updates_cache_correctly(self):
"""Test that multiple citations are indexed correctly."""
collector = CitationCollector()
results = [
{
"url": f"https://example.com/{i}",
"title": f"Page {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i in range(5)
]
collector.add_from_search_results(results)
# Check all URLs are indexed
assert len(collector._url_to_index) == 5
for i in range(5):
url = f"https://example.com/{i}"
assert collector._url_to_index[url] == i
def test_get_number_uses_cache_for_o1_lookup(self):
"""Test that get_number uses cache for O(1) lookup."""
collector = CitationCollector()
urls = [f"https://example.com/{i}" for i in range(100)]
results = [
{
"url": url,
"title": f"Title {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i, url in enumerate(urls)
]
collector.add_from_search_results(results)
# Test lookup for various positions
assert collector.get_number("https://example.com/0") == 1
assert collector.get_number("https://example.com/50") == 51
assert collector.get_number("https://example.com/99") == 100
# Non-existent URL returns None
assert collector.get_number("https://nonexistent.com") is None
def test_add_from_crawl_result_updates_cache(self):
"""Test that add_from_crawl_result updates cache."""
collector = CitationCollector()
collector.add_from_crawl_result(
url="https://crawled.com/page",
title="Crawled Page",
content="Crawled content",
)
assert "https://crawled.com/page" in collector._url_to_index
assert collector._url_to_index["https://crawled.com/page"] == 0
def test_duplicate_url_does_not_change_cache(self):
"""Test that adding duplicate URLs doesn't change cache indices."""
collector = CitationCollector()
# Add first time
collector.add_from_search_results(
[
{
"url": "https://example.com",
"title": "Title 1",
"content": "Content 1",
"score": 0.8,
}
]
)
assert collector._url_to_index["https://example.com"] == 0
# Add same URL again with better score
collector.add_from_search_results(
[
{
"url": "https://example.com",
"title": "Title 1 Updated",
"content": "Content 1 Updated",
"score": 0.95,
}
]
)
# Cache index should not change
assert collector._url_to_index["https://example.com"] == 0
# But metadata should be updated
assert collector._citations["https://example.com"].relevance_score == 0.95
def test_merge_with_updates_cache_correctly(self):
"""Test that merge_with correctly updates cache for new URLs."""
collector1 = CitationCollector()
collector2 = CitationCollector()
# Add to collector1
collector1.add_from_search_results(
[
{
"url": "https://a.com",
"title": "A",
"content": "Content A",
"score": 0.9,
}
]
)
# Add to collector2
collector2.add_from_search_results(
[
{
"url": "https://b.com",
"title": "B",
"content": "Content B",
"score": 0.9,
}
]
)
collector1.merge_with(collector2)
# Both URLs should be in cache
assert "https://a.com" in collector1._url_to_index
assert "https://b.com" in collector1._url_to_index
assert collector1._url_to_index["https://a.com"] == 0
assert collector1._url_to_index["https://b.com"] == 1
def test_from_dict_rebuilds_cache(self):
"""Test that from_dict properly rebuilds cache."""
# Create original collector
original = CitationCollector()
original.add_from_search_results(
[
{
"url": f"https://example.com/{i}",
"title": f"Page {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i in range(3)
]
)
# Serialize and deserialize
data = original.to_dict()
restored = CitationCollector.from_dict(data)
# Check cache is properly rebuilt
assert len(restored._url_to_index) == 3
for i in range(3):
url = f"https://example.com/{i}"
assert url in restored._url_to_index
assert restored._url_to_index[url] == i
def test_clear_resets_cache(self):
"""Test that clear() properly resets the cache."""
collector = CitationCollector()
collector.add_from_search_results(
[
{
"url": "https://example.com",
"title": "Example",
"content": "Content",
"score": 0.9,
}
]
)
assert len(collector._url_to_index) > 0
collector.clear()
assert len(collector._url_to_index) == 0
assert len(collector._citations) == 0
assert len(collector._citation_order) == 0
def test_cache_consistency_with_order_list(self):
"""Test that cache indices match positions in _citation_order."""
collector = CitationCollector()
urls = [f"https://example.com/{i}" for i in range(10)]
results = [
{
"url": url,
"title": f"Title {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i, url in enumerate(urls)
]
collector.add_from_search_results(results)
# Verify cache indices match order list positions
for i, url in enumerate(collector._citation_order):
assert collector._url_to_index[url] == i
def test_mark_used_with_cache(self):
"""Test that mark_used works correctly with cache."""
collector = CitationCollector()
collector.add_from_search_results(
[
{
"url": "https://example.com/1",
"title": "Page 1",
"content": "Content 1",
"score": 0.9,
},
{
"url": "https://example.com/2",
"title": "Page 2",
"content": "Content 2",
"score": 0.9,
},
]
)
# Mark one as used
number = collector.mark_used("https://example.com/2")
assert number == 2
# Verify it's in used set
assert "https://example.com/2" in collector._used_citations
def test_large_collection_cache_performance(self):
"""Test that cache works correctly with large collections."""
collector = CitationCollector()
num_citations = 1000
results = [
{
"url": f"https://example.com/{i}",
"title": f"Title {i}",
"content": f"Content {i}",
"score": 0.9,
}
for i in range(num_citations)
]
collector.add_from_search_results(results)
# Verify cache size
assert len(collector._url_to_index) == num_citations
# Test lookups at various positions
test_indices = [0, 100, 500, 999]
for idx in test_indices:
url = f"https://example.com/{idx}"
assert collector.get_number(url) == idx + 1

View File

@@ -0,0 +1,251 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for extractor optimizations.
Tests the enhanced domain extraction and title extraction functions.
"""
from src.citations.extractor import (
_extract_domain,
extract_title_from_content,
)
class TestExtractDomainOptimization:
    """Tests for the urllib-first, regex-fallback domain extraction."""

    def test_extract_domain_standard_urls(self):
        """Standard http(s) URLs yield their netloc."""
        assert _extract_domain("https://www.example.com/path") == "www.example.com"
        assert _extract_domain("http://example.org") == "example.org"
        assert _extract_domain("https://github.com/user/repo") == "github.com"

    def test_extract_domain_with_port(self):
        """Port numbers are preserved in the extracted netloc."""
        assert _extract_domain("http://localhost:8080/api") == "localhost:8080"
        assert _extract_domain("https://example.com:3000/page") == "example.com:3000"

    def test_extract_domain_with_subdomain(self):
        """Subdomains are kept intact."""
        assert _extract_domain("https://api.github.com/repos") == "api.github.com"
        assert _extract_domain("https://docs.python.org/en/") == "docs.python.org"

    def test_extract_domain_invalid_url(self):
        """Garbage input must not raise; some string comes back."""
        assert isinstance(_extract_domain("not a url"), str)

    def test_extract_domain_empty_url(self):
        """Empty input maps to the empty string."""
        assert _extract_domain("") == ""

    def test_extract_domain_without_scheme(self):
        """Scheme-less URLs go through the regex fallback without crashing."""
        assert isinstance(_extract_domain("example.com/path"), str)

    def test_extract_domain_complex_urls(self):
        """Credentials and query/fragment parts follow urllib netloc semantics."""
        # urllib includes credentials in netloc, so this is expected behavior.
        assert _extract_domain("https://user:pass@example.com/path") == "user:pass@example.com"
        assert _extract_domain("https://example.com:443/path?query=value#hash") == "example.com:443"

    def test_extract_domain_ipv4(self):
        """IPv4 hosts are handled without raising."""
        assert isinstance(_extract_domain("http://192.168.1.1:8080/"), str)

    def test_extract_domain_query_params(self):
        """Query strings do not change the extracted domain."""
        with_query = "https://example.com/page?query=value"
        without_query = "https://example.com/page"
        assert _extract_domain(with_query) == _extract_domain(without_query)

    def test_extract_domain_url_fragments(self):
        """Fragments do not change the extracted domain."""
        with_fragment = "https://example.com/page#section"
        without_fragment = "https://example.com/page"
        assert _extract_domain(with_fragment) == _extract_domain(without_fragment)
class TestExtractTitleFromContent:
"""Test intelligent title extraction with 5-tier priority system."""
def test_extract_title_html_title_tag(self):
"""Test priority 1: HTML <title> tag extraction."""
content = "<html><head><title>HTML Title</title></head><body>Content</body></html>"
assert extract_title_from_content(content) == "HTML Title"
def test_extract_title_html_title_case_insensitive(self):
"""Test that HTML title extraction is case-insensitive."""
content = "<html><head><TITLE>HTML Title</TITLE></head><body></body></html>"
assert extract_title_from_content(content) == "HTML Title"
def test_extract_title_markdown_h1(self):
"""Test priority 2: Markdown h1 extraction."""
content = "# Main Title\n\nSome content here"
assert extract_title_from_content(content) == "Main Title"
def test_extract_title_markdown_h1_with_spaces(self):
"""Test markdown h1 with extra spaces."""
content = "# Title with Spaces \n\nContent"
assert extract_title_from_content(content) == "Title with Spaces"
def test_extract_title_markdown_h2_fallback(self):
"""Test priority 3: Markdown h2 as fallback when no h1."""
content = "## Second Level Title\n\nSome content"
assert extract_title_from_content(content) == "Second Level Title"
def test_extract_title_markdown_h6_fallback(self):
"""Test markdown h6 as fallback."""
content = "###### Small Heading\n\nContent"
assert extract_title_from_content(content) == "Small Heading"
def test_extract_title_prefers_h1_over_h2(self):
"""Test that h1 is preferred over h2."""
content = "# H1 Title\n## H2 Title\n\nContent"
assert extract_title_from_content(content) == "H1 Title"
def test_extract_title_json_field(self):
"""Test priority 4: JSON title field extraction."""
content = '{"title": "JSON Title", "content": "Some data"}'
assert extract_title_from_content(content) == "JSON Title"
def test_extract_title_yaml_field(self):
"""Test YAML title field extraction."""
content = 'title: "YAML Title"\ncontent: "Some data"'
assert extract_title_from_content(content) == "YAML Title"
def test_extract_title_first_substantial_line(self):
"""Test priority 5: First substantial non-empty line."""
content = "\n\n\nThis is the first substantial line\n\nMore content"
assert extract_title_from_content(content) == "This is the first substantial line"
def test_extract_title_skips_short_lines(self):
"""Test that short lines are skipped."""
content = "abc\nThis is a longer first substantial line\nContent"
assert extract_title_from_content(content) == "This is a longer first substantial line"
def test_extract_title_skips_code_blocks(self):
"""Test that code blocks are skipped."""
content = "```\ncode here\n```\nThis is the title\n\nContent"
result = extract_title_from_content(content)
# Should skip the code block and find the actual title
assert "title" in result.lower() or "code" not in result
def test_extract_title_skips_list_items(self):
"""Test that list items are skipped."""
content = "- Item 1\n- Item 2\nThis is the actual first substantial line\n\nContent"
result = extract_title_from_content(content)
assert "actual" in result or "Item" not in result
def test_extract_title_skips_separators(self):
"""Test that separator lines are skipped."""
content = "---\n\n***\n\nThis is the real title\n\nContent"
result = extract_title_from_content(content)
assert "---" not in result and "***" not in result
def test_extract_title_max_length(self):
"""Test that title respects max_length parameter."""
long_title = "A" * 300
content = f"# {long_title}"
result = extract_title_from_content(content, max_length=100)
assert len(result) <= 100
assert result == long_title[:100]
def test_extract_title_empty_content(self):
"""Test handling of empty content."""
assert extract_title_from_content("") == "Untitled"
assert extract_title_from_content(None) == "Untitled"
def test_extract_title_no_title_found(self):
"""Test fallback to 'Untitled' when no title can be extracted."""
content = "a\nb\nc\n" # Only short lines
result = extract_title_from_content(content)
# May return Untitled or one of the short lines
assert isinstance(result, str)
def test_extract_title_whitespace_handling(self):
"""Test that whitespace is properly handled."""
content = "# Title with extra spaces \n\nContent"
result = extract_title_from_content(content)
# Should normalize spaces
assert "Title with extra spaces" in result or len(result) > 5
def test_extract_title_multiline_html(self):
"""Test HTML title extraction across multiple lines."""
content = """
<html>
<head>
<title>
Multiline Title
</title>
</head>
<body>Content</body>
</html>
"""
result = extract_title_from_content(content)
# Should handle multiline titles
assert "Title" in result
def test_extract_title_mixed_formats(self):
"""Test content with mixed formats (h1 should win)."""
content = """
<title>HTML Title</title>
# Markdown H1
## Markdown H2
Some paragraph content
"""
# HTML title comes first in priority
assert extract_title_from_content(content) == "HTML Title"
def test_extract_title_real_world_example(self):
"""Test with real-world HTML example."""
content = """
<!DOCTYPE html>
<html>
<head>
<title>GitHub: Where the world builds software</title>
<meta property="og:title" content="GitHub">
</head>
<body>
<h1>Let's build from here</h1>
<p>The complete developer platform...</p>
</body>
</html>
"""
result = extract_title_from_content(content)
assert result == "GitHub: Where the world builds software"
def test_extract_title_json_with_nested_title(self):
"""Test JSON title extraction with nested structures."""
content = '{"meta": {"title": "Should not match"}, "title": "JSON Title"}'
result = extract_title_from_content(content)
# The regex will match the first "title" field it finds, which could be nested
# Just verify it finds a title field
assert result and result != "Untitled"
def test_extract_title_preserves_special_characters(self):
"""Test that special characters are preserved in title."""
content = "# Title with Special Characters: @#$%"
result = extract_title_from_content(content)
assert "@" in result or "$" in result or "%" in result or "Title" in result

View File

@@ -0,0 +1,423 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for citation formatter enhancements.
Tests the multi-format citation parsing and extraction capabilities.
"""
from src.citations.formatter import (
parse_citations_from_report,
_extract_markdown_links,
_extract_numbered_citations,
_extract_footnote_citations,
_extract_html_links,
)
class TestExtractMarkdownLinks:
    """Tests for Markdown link extraction [title](url)."""

    def test_extract_single_markdown_link(self):
        """A lone markdown link yields one citation dict."""
        found = _extract_markdown_links("[Example Article](https://example.com)")
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "markdown"

    def test_extract_multiple_markdown_links(self):
        """Several links in one string are all collected, in order."""
        found = _extract_markdown_links(
            "[Link 1](https://example.com/1) and [Link 2](https://example.com/2)"
        )
        assert len(found) == 2
        assert found[0]["title"] == "Link 1"
        assert found[1]["title"] == "Link 2"

    def test_extract_markdown_link_with_spaces(self):
        """Titles containing spaces are captured whole."""
        found = _extract_markdown_links("[Article Title With Spaces](https://example.com)")
        assert len(found) == 1
        assert found[0]["title"] == "Article Title With Spaces"

    def test_extract_markdown_link_ignore_non_http(self):
        """Relative / non-HTTP targets are filtered out."""
        found = _extract_markdown_links(
            "[Relative Link](./relative/path) [HTTP Link](https://example.com)"
        )
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_markdown_link_with_query_params(self):
        """Query parameters survive extraction."""
        found = _extract_markdown_links("[Search Result](https://example.com/search?q=test&page=1)")
        assert len(found) == 1
        assert "q=test" in found[0]["url"]

    def test_extract_markdown_link_empty_text(self):
        """Text without links produces an empty result."""
        assert len(_extract_markdown_links("Just plain text with no links")) == 0

    def test_extract_markdown_link_strip_whitespace(self):
        """Title and URL come back trimmed."""
        # Markdown links with spaces in the URL are not valid, so they are
        # never extracted in the first place.
        found = _extract_markdown_links("[Title](https://example.com)")
        assert len(found) >= 1
        assert found[0]["title"] == "Title"
        assert found[0]["url"] == "https://example.com"
class TestExtractNumberedCitations:
    """Tests for numbered citation extraction: [1] Title - URL."""

    def test_extract_single_numbered_citation(self):
        """A single well-formed numbered citation is parsed."""
        found = _extract_numbered_citations("[1] Example Article - https://example.com")
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "numbered"

    def test_extract_multiple_numbered_citations(self):
        """Numbered citations on separate lines are all parsed, in order."""
        source = "[1] First - https://example.com/1\n[2] Second - https://example.com/2"
        found = _extract_numbered_citations(source)
        assert len(found) == 2
        assert found[0]["title"] == "First"
        assert found[1]["title"] == "Second"

    def test_extract_numbered_citation_with_long_title(self):
        """Longer titles are captured in full."""
        found = _extract_numbered_citations(
            "[5] A Comprehensive Guide to Python Programming - https://example.com"
        )
        assert len(found) == 1
        assert "Comprehensive Guide" in found[0]["title"]

    def test_extract_numbered_citation_requires_valid_format(self):
        """A malformed marker (missing closing bracket) is rejected."""
        found = _extract_numbered_citations("[1 Title - https://example.com")
        assert len(found) == 0

    def test_extract_numbered_citation_empty_text(self):
        """Text without numbered citations yields nothing."""
        found = _extract_numbered_citations("Just plain text")
        assert len(found) == 0

    def test_extract_numbered_citation_various_numbers(self):
        """Multi-digit citation numbers are accepted."""
        source = "[10] Title Ten - https://example.com/10\n[999] Title 999 - https://example.com/999"
        found = _extract_numbered_citations(source)
        assert len(found) == 2

    def test_extract_numbered_citation_ignore_non_http(self):
        """Non-http URLs inside numbered citations are skipped."""
        found = _extract_numbered_citations("[1] Invalid - file://path [2] Valid - https://example.com")
        # Only the valid one should be extracted
        assert len(found) <= 1
class TestExtractFootnoteCitations:
    """Tests for footnote citation extraction: [^1]: Title - URL."""

    def test_extract_single_footnote_citation(self):
        """A single well-formed footnote is parsed."""
        found = _extract_footnote_citations("[^1]: Example Article - https://example.com")
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "footnote"

    def test_extract_multiple_footnote_citations(self):
        """Multiple footnotes on separate lines are all extracted."""
        source = "[^1]: First - https://example.com/1\n[^2]: Second - https://example.com/2"
        found = _extract_footnote_citations(source)
        assert len(found) == 2

    def test_extract_footnote_with_complex_number(self):
        """Footnote markers with multi-digit numbers are accepted."""
        found = _extract_footnote_citations("[^123]: Title - https://example.com")
        assert len(found) == 1
        assert found[0]["title"] == "Title"

    def test_extract_footnote_citation_with_spaces(self):
        """Whitespace around the separator and URL is stripped."""
        # Note the deliberate trailing space after the URL.
        found = _extract_footnote_citations("[^1]: Title with spaces - https://example.com ")
        assert len(found) == 1
        # Should strip whitespace
        assert found[0]["title"] == "Title with spaces"

    def test_extract_footnote_citation_empty_text(self):
        """Text without footnotes yields nothing."""
        found = _extract_footnote_citations("No footnotes here")
        assert len(found) == 0

    def test_extract_footnote_requires_caret(self):
        """A plain [1]: entry (no caret) is not treated as a footnote."""
        found = _extract_footnote_citations("[1]: Title - https://example.com")
        assert len(found) == 0
class TestExtractHtmlLinks:
    """Test HTML link extraction <a href="url">title</a>.

    Covers quote styles, extra attributes, case-insensitivity, and
    filtering of non-http(s) schemes.
    """

    def test_extract_single_html_link(self):
        """Test extraction of a single HTML link."""
        text = '<a href="https://example.com">Example Article</a>'
        citations = _extract_html_links(text)
        assert len(citations) == 1
        assert citations[0]["title"] == "Example Article"
        assert citations[0]["url"] == "https://example.com"
        assert citations[0]["format"] == "html"

    def test_extract_multiple_html_links(self):
        """Test extraction of multiple HTML links."""
        text = '<a href="https://a.com">Link A</a> <a href="https://b.com">Link B</a>'
        citations = _extract_html_links(text)
        assert len(citations) == 2

    def test_extract_html_link_single_quotes(self):
        """Test HTML links with single quotes."""
        text = "<a href='https://example.com'>Title</a>"
        citations = _extract_html_links(text)
        assert len(citations) == 1
        assert citations[0]["url"] == "https://example.com"

    def test_extract_html_link_with_attributes(self):
        """Test HTML links with additional attributes."""
        text = '<a class="link" href="https://example.com" target="_blank">Title</a>'
        citations = _extract_html_links(text)
        assert len(citations) == 1
        assert citations[0]["url"] == "https://example.com"

    def test_extract_html_link_ignore_non_http(self):
        """Test that non-HTTP URLs are ignored."""
        text = '<a href="mailto:test@example.com">Email</a> <a href="https://example.com">Web</a>'
        citations = _extract_html_links(text)
        assert len(citations) == 1
        assert citations[0]["url"] == "https://example.com"

    def test_extract_html_link_case_insensitive(self):
        """Test that HTML extraction is case-insensitive."""
        text = '<A HREF="https://example.com">Title</A>'
        citations = _extract_html_links(text)
        assert len(citations) == 1

    def test_extract_html_link_empty_text(self):
        """Test with no HTML links."""
        text = "No links here"
        citations = _extract_html_links(text)
        assert len(citations) == 0

    def test_extract_html_link_strip_whitespace(self):
        """Test that whitespace in title is stripped."""
        text = '<a href="https://example.com"> Title with spaces </a>'
        citations = _extract_html_links(text)
        # Guard the index: an extractor regression should surface as a clear
        # assertion failure rather than an IndexError.
        assert len(citations) == 1
        assert citations[0]["title"] == "Title with spaces"
class TestParseCitationsFromReport:
    """Test comprehensive citation parsing from complete reports."""
    # NOTE: parse_citations_from_report returns a dict with "count" and
    # "citations" keys. Most assertions below use >= so the suite stays
    # robust to extractors finding additional (legitimate) citations.
    def test_parse_markdown_links_from_report(self):
        """Test parsing markdown links from a report."""
        report = """
        ## Key Citations
        [GitHub](https://github.com)
        [Python Docs](https://python.org)
        """
        result = parse_citations_from_report(report)
        assert result["count"] >= 2
        urls = [c["url"] for c in result["citations"]]
        assert "https://github.com" in urls
    def test_parse_numbered_citations_from_report(self):
        """Test parsing numbered citations."""
        report = """
        ## References
        [1] GitHub - https://github.com
        [2] Python - https://python.org
        """
        result = parse_citations_from_report(report)
        assert result["count"] >= 2
    def test_parse_mixed_format_citations(self):
        """Test parsing mixed citation formats."""
        # One citation per supported format: markdown, footnote, numbered, HTML.
        report = """
        ## Key Citations
        [GitHub](https://github.com)
        [^1]: Python - https://python.org
        [2] Wikipedia - https://wikipedia.org
        <a href="https://stackoverflow.com">Stack Overflow</a>
        """
        result = parse_citations_from_report(report)
        # Should find all 4 citations
        assert result["count"] >= 3
    def test_parse_citations_deduplication(self):
        """Test that duplicate URLs are deduplicated."""
        report = """
        ## Key Citations
        [GitHub 1](https://github.com)
        [GitHub 2](https://github.com)
        [GitHub](https://github.com)
        """
        result = parse_citations_from_report(report)
        # Should have only 1 unique citation
        assert result["count"] == 1
        assert result["citations"][0]["url"] == "https://github.com"
    def test_parse_citations_various_section_patterns(self):
        """Test parsing with different section headers."""
        # The parser's default section patterns must match all common headings.
        report_refs = """
        ## References
        [GitHub](https://github.com)
        """
        report_sources = """
        ## Sources
        [GitHub](https://github.com)
        """
        report_bibliography = """
        ## Bibliography
        [GitHub](https://github.com)
        """
        assert parse_citations_from_report(report_refs)["count"] >= 1
        assert parse_citations_from_report(report_sources)["count"] >= 1
        assert parse_citations_from_report(report_bibliography)["count"] >= 1
    def test_parse_citations_custom_patterns(self):
        """Test parsing with custom section patterns."""
        report = """
        ## My Custom Sources
        [GitHub](https://github.com)
        """
        # Caller-supplied regexes override the default section patterns.
        result = parse_citations_from_report(
            report,
            section_patterns=[r"##\s*My Custom Sources"]
        )
        assert result["count"] >= 1
    def test_parse_citations_empty_report(self):
        """Test parsing an empty report."""
        result = parse_citations_from_report("")
        assert result["count"] == 0
        assert result["citations"] == []
    def test_parse_citations_no_section(self):
        """Test parsing report without citation section."""
        report = "This is a report with no citations section"
        result = parse_citations_from_report(report)
        assert result["count"] == 0
    def test_parse_citations_complex_report(self):
        """Test parsing a complex, realistic report."""
        report = """
        # Research Report
        ## Introduction
        This report summarizes findings from multiple sources.
        ## Key Findings
        Some important discoveries were made based on research [GitHub](https://github.com).
        ## Key Citations
        1. Primary sources:
           [GitHub](https://github.com) - A collaborative platform
           [^1]: Python - https://python.org
        2. Secondary sources:
           [2] Wikipedia - https://wikipedia.org
        3. Web resources:
           <a href="https://stackoverflow.com">Stack Overflow</a>
        ## Methodology
        [Additional](https://example.com) details about methodology.
        ---
        [^1]: The Python programming language official site
        """
        result = parse_citations_from_report(report)
        # Should extract multiple citations from the Key Citations section
        assert result["count"] >= 3
        urls = [c["url"] for c in result["citations"]]
        # Verify some key URLs are found
        assert any("github.com" in url or "python.org" in url for url in urls)
    def test_parse_citations_stops_at_next_section(self):
        """Test that citation extraction looks for citation sections."""
        report = """
        ## Key Citations
        [Cite 1](https://example.com/1)
        [Cite 2](https://example.com/2)
        ## Next Section
        Some other content
        """
        result = parse_citations_from_report(report)
        # Should extract citations from the Key Citations section
        # Note: The regex stops at next ## section
        assert result["count"] >= 1
        assert any("example.com/1" in c["url"] for c in result["citations"])
    def test_parse_citations_preserves_metadata(self):
        """Test that citation metadata is preserved."""
        report = """
        ## Key Citations
        [Python Documentation](https://python.org)
        """
        result = parse_citations_from_report(report)
        assert len(result["citations"]) >= 1
        citation = result["citations"][0]
        # Every parsed citation carries these three keys.
        assert "title" in citation
        assert "url" in citation
        assert "format" in citation
    def test_parse_citations_whitespace_handling(self):
        """Test handling of various whitespace configurations."""
        report = """
        ## Key Citations
        [Link](https://example.com)
        """
        result = parse_citations_from_report(report)
        assert result["count"] >= 1
    def test_parse_citations_multiline_links(self):
        """Test extraction of links across formatting."""
        report = """
        ## Key Citations
        Some paragraph with a [link to example](https://example.com) in the middle.
        """
        result = parse_citations_from_report(report)
        assert result["count"] >= 1

View File

@@ -0,0 +1,467 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""
Unit tests for citation models.
Tests the Pydantic BaseModel implementation of CitationMetadata and Citation classes.
"""
import json
import pytest
from pydantic import ValidationError
from src.citations.models import Citation, CitationMetadata
class TestCitationMetadata:
    """Test CitationMetadata Pydantic model."""

    def test_create_basic_metadata(self):
        """Test creating basic citation metadata."""
        metadata = CitationMetadata(
            url="https://example.com/article",
            title="Example Article",
        )
        assert metadata.url == "https://example.com/article"
        assert metadata.title == "Example Article"
        assert metadata.domain == "example.com"  # Auto-extracted from URL
        assert metadata.description is None
        assert metadata.images == []
        assert metadata.extra == {}

    def test_metadata_with_all_fields(self):
        """Test creating metadata with all fields populated."""
        metadata = CitationMetadata(
            url="https://github.com/example/repo",
            title="Example Repository",
            description="A great repository",
            content_snippet="This is a snippet",
            raw_content="Full content here",
            author="John Doe",
            published_date="2025-01-24",
            language="en",
            relevance_score=0.95,
            credibility_score=0.88,
        )
        assert metadata.url == "https://github.com/example/repo"
        assert metadata.domain == "github.com"
        assert metadata.author == "John Doe"
        assert metadata.relevance_score == 0.95
        assert metadata.credibility_score == 0.88

    def test_metadata_domain_auto_extraction(self):
        """Test automatic domain extraction from URL."""
        # Expected values follow urlparse netloc semantics: the "www."
        # prefix is kept, and an explicit port stays attached.
        test_cases = [
            ("https://www.example.com/path", "www.example.com"),
            ("http://github.com/user/repo", "github.com"),
            ("https://api.github.com:443/repos", "api.github.com:443"),
        ]
        for url, expected_domain in test_cases:
            metadata = CitationMetadata(url=url, title="Test")
            assert metadata.domain == expected_domain

    def test_metadata_id_generation(self):
        """Test unique ID generation from URL."""
        metadata1 = CitationMetadata(
            url="https://example.com/article",
            title="Article",
        )
        metadata2 = CitationMetadata(
            url="https://example.com/article",
            title="Article",
        )
        # Same URL should produce same ID
        assert metadata1.id == metadata2.id
        metadata3 = CitationMetadata(
            url="https://different.com/article",
            title="Article",
        )
        # Different URL should produce different ID
        assert metadata1.id != metadata3.id

    def test_metadata_id_length(self):
        """Test that ID is a 12-character hex digest."""
        metadata = CitationMetadata(
            url="https://example.com",
            title="Test",
        )
        assert len(metadata.id) == 12
        # The previous check (`id.isalnum() or all(hex chars)`) was vacuous:
        # any hex string is also alphanumeric, so the hex clause could never
        # be the deciding factor and any alphanumeric id would pass.
        # Assert the hex-digest property directly instead.
        assert all(c in "0123456789abcdef" for c in metadata.id)

    def test_metadata_from_dict(self):
        """Test creating metadata from dictionary."""
        data = {
            "url": "https://example.com",
            "title": "Example",
            "description": "A description",
            "author": "John Doe",
        }
        metadata = CitationMetadata.from_dict(data)
        assert metadata.url == "https://example.com"
        assert metadata.title == "Example"
        assert metadata.description == "A description"
        assert metadata.author == "John Doe"

    def test_metadata_from_dict_removes_id(self):
        """Test that from_dict removes computed 'id' field."""
        data = {
            "url": "https://example.com",
            "title": "Example",
            "id": "some_old_id",  # Should be ignored
        }
        metadata = CitationMetadata.from_dict(data)
        # Should use newly computed ID, not the old one
        assert metadata.id != "some_old_id"

    def test_metadata_to_dict(self):
        """Test converting metadata to dictionary."""
        metadata = CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        )
        result = metadata.to_dict()
        assert result["url"] == "https://example.com"
        assert result["title"] == "Example"
        assert result["author"] == "John Doe"
        assert result["id"] == metadata.id
        assert result["domain"] == "example.com"

    def test_metadata_from_search_result(self):
        """Test creating metadata from search result."""
        search_result = {
            "url": "https://example.com/article",
            "title": "Article Title",
            "content": "Article content here",
            "score": 0.92,
            "type": "page",
        }
        metadata = CitationMetadata.from_search_result(
            search_result,
            query="test query",
        )
        assert metadata.url == "https://example.com/article"
        assert metadata.title == "Article Title"
        assert metadata.description == "Article content here"
        assert metadata.relevance_score == 0.92
        # Query and result type are stashed in the free-form `extra` dict.
        assert metadata.extra["query"] == "test query"
        assert metadata.extra["result_type"] == "page"

    def test_metadata_pydantic_validation(self):
        """Test that Pydantic validates required fields."""
        # URL and title are required
        with pytest.raises(ValidationError):
            CitationMetadata()  # Missing required fields
        with pytest.raises(ValidationError):
            CitationMetadata(url="https://example.com")  # Missing title

    def test_metadata_model_dump(self):
        """Test Pydantic model_dump method."""
        metadata = CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        )
        result = metadata.model_dump()
        assert isinstance(result, dict)
        assert result["url"] == "https://example.com"
        assert result["title"] == "Example"

    def test_metadata_model_dump_json(self):
        """Test Pydantic model_dump_json method."""
        metadata = CitationMetadata(
            url="https://example.com",
            title="Example",
        )
        result = metadata.model_dump_json()
        assert isinstance(result, str)
        data = json.loads(result)
        assert data["url"] == "https://example.com"
        assert data["title"] == "Example"

    def test_metadata_with_images_and_extra(self):
        """Test metadata with list and dict fields."""
        metadata = CitationMetadata(
            url="https://example.com",
            title="Example",
            images=["https://example.com/image1.jpg", "https://example.com/image2.jpg"],
            favicon="https://example.com/favicon.ico",
            extra={"custom_field": "value", "tags": ["tag1", "tag2"]},
        )
        assert len(metadata.images) == 2
        assert metadata.favicon == "https://example.com/favicon.ico"
        assert metadata.extra["custom_field"] == "value"
class TestCitation:
    """Test Citation Pydantic model."""

    @staticmethod
    def _make_meta(title="Example", url="https://example.com", **fields):
        """Build a CitationMetadata instance with test defaults."""
        return CitationMetadata(url=url, title=title, **fields)

    def test_create_basic_citation(self):
        """A citation stores its number and metadata; extras default to None."""
        meta = self._make_meta()
        cite = Citation(number=1, metadata=meta)
        assert cite.number == 1
        assert cite.metadata == meta
        assert cite.context is None
        assert cite.cited_text is None

    def test_citation_properties(self):
        """Property shortcuts mirror the underlying metadata."""
        meta = self._make_meta(title="Example Title")
        cite = Citation(number=1, metadata=meta)
        assert cite.id == meta.id
        assert cite.url == "https://example.com"
        assert cite.title == "Example Title"

    def test_citation_to_markdown_reference(self):
        """Markdown rendering is [title](url)."""
        cite = Citation(number=1, metadata=self._make_meta())
        assert cite.to_markdown_reference() == "[Example](https://example.com)"

    def test_citation_to_numbered_reference(self):
        """Numbered rendering is [n] title - url."""
        cite = Citation(number=5, metadata=self._make_meta(title="Example Article"))
        assert cite.to_numbered_reference() == "[5] Example Article - https://example.com"

    def test_citation_to_inline_marker(self):
        """Inline marker rendering is [^n]."""
        cite = Citation(number=3, metadata=self._make_meta())
        assert cite.to_inline_marker() == "[^3]"

    def test_citation_to_footnote(self):
        """Footnote rendering is [^n]: title - url."""
        cite = Citation(number=2, metadata=self._make_meta(title="Example Article"))
        assert cite.to_footnote() == "[^2]: Example Article - https://example.com"

    def test_citation_with_context_and_text(self):
        """Context and cited text are stored verbatim."""
        cite = Citation(
            number=1,
            metadata=self._make_meta(),
            context="This is important context",
            cited_text="Important quote from the source",
        )
        assert cite.context == "This is important context"
        assert cite.cited_text == "Important quote from the source"

    def test_citation_from_dict(self):
        """from_dict builds the nested metadata model too."""
        payload = {
            "number": 1,
            "metadata": {
                "url": "https://example.com",
                "title": "Example",
                "author": "John Doe",
            },
            "context": "Test context",
        }
        cite = Citation.from_dict(payload)
        assert cite.number == 1
        assert cite.metadata.url == "https://example.com"
        assert cite.metadata.title == "Example"
        assert cite.metadata.author == "John Doe"
        assert cite.context == "Test context"

    def test_citation_to_dict(self):
        """to_dict nests the metadata as a plain dict."""
        cite = Citation(
            number=1,
            metadata=self._make_meta(author="John Doe"),
            context="Test context",
        )
        dumped = cite.to_dict()
        assert dumped["number"] == 1
        assert dumped["metadata"]["url"] == "https://example.com"
        assert dumped["metadata"]["author"] == "John Doe"
        assert dumped["context"] == "Test context"

    def test_citation_round_trip(self):
        """to_dict followed by from_dict preserves every field."""
        original = Citation(
            number=1,
            metadata=self._make_meta(author="John Doe", relevance_score=0.95),
            context="Test",
        )
        restored = Citation.from_dict(original.to_dict())
        assert restored.number == original.number
        assert restored.metadata.url == original.metadata.url
        assert restored.metadata.title == original.metadata.title
        assert restored.metadata.author == original.metadata.author
        assert restored.metadata.relevance_score == original.metadata.relevance_score

    def test_citation_model_dump(self):
        """Pydantic model_dump returns a nested plain dict."""
        cite = Citation(number=1, metadata=self._make_meta())
        dumped = cite.model_dump()
        assert isinstance(dumped, dict)
        assert dumped["number"] == 1
        assert dumped["metadata"]["url"] == "https://example.com"

    def test_citation_model_dump_json(self):
        """Pydantic model_dump_json returns parseable JSON."""
        cite = Citation(number=1, metadata=self._make_meta())
        raw = cite.model_dump_json()
        assert isinstance(raw, str)
        parsed = json.loads(raw)
        assert parsed["number"] == 1
        assert parsed["metadata"]["url"] == "https://example.com"

    def test_citation_pydantic_validation(self):
        """Number and metadata are both required by the model."""
        with pytest.raises(ValidationError):
            Citation()  # Missing required fields
        with pytest.raises(ValidationError):
            Citation(metadata=self._make_meta())  # Missing number
class TestCitationIntegration:
    """Integration tests for citation models."""

    def test_search_result_to_citation_workflow(self):
        """Full path: raw search result -> metadata -> citation -> reference."""
        search_result = {
            "url": "https://example.com/article",
            "title": "Great Article",
            "content": "This is a great article about testing",
            "score": 0.92,
        }
        meta = CitationMetadata.from_search_result(search_result, query="testing")
        cite = Citation(number=1, metadata=meta, context="Important source")
        assert cite.number == 1
        assert cite.url == "https://example.com/article"
        assert cite.title == "Great Article"
        assert cite.metadata.relevance_score == 0.92
        assert cite.to_markdown_reference() == "[Great Article](https://example.com/article)"

    def test_multiple_citations_with_different_formats(self):
        """Different reference renderings coexist across citations."""
        first = Citation(
            number=1,
            metadata=CitationMetadata(url="https://example.com/1", title="First Article"),
        )
        second = Citation(
            number=2,
            metadata=CitationMetadata(url="https://example.com/2", title="Second Article"),
        )
        assert first.to_markdown_reference() == "[First Article](https://example.com/1)"
        assert second.to_numbered_reference() == "[2] Second Article - https://example.com/2"

    def test_citation_json_serialization_roundtrip(self):
        """Dict -> model -> JSON -> model preserves every field."""
        original_data = {
            "number": 1,
            "metadata": {
                "url": "https://example.com",
                "title": "Example",
                "author": "John Doe",
                "relevance_score": 0.95,
            },
            "context": "Test context",
            "cited_text": "Important quote",
        }
        json_str = Citation.from_dict(original_data).model_dump_json()
        restored = Citation.model_validate_json(json_str)
        assert restored.number == original_data["number"]
        assert restored.metadata.url == original_data["metadata"]["url"]
        assert restored.metadata.relevance_score == original_data["metadata"]["relevance_score"]
        assert restored.context == original_data["context"]