mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-03 06:12:14 +08:00
chore: Improved citation system (#834)
* improve: Improved citation system * fix --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -28,6 +28,7 @@ class CitationCollector:
|
||||
self._citations: Dict[str, CitationMetadata] = {} # url -> metadata
|
||||
self._citation_order: List[str] = [] # ordered list of URLs
|
||||
self._used_citations: set[str] = set() # URLs that are actually cited
|
||||
self._url_to_index: Dict[str, int] = {} # url -> index of _citation_order (O(1) lookup)
|
||||
|
||||
def add_from_search_results(
|
||||
self, results: List[Dict[str, Any]], query: str = ""
|
||||
@@ -58,6 +59,7 @@ class CitationCollector:
|
||||
if url not in self._citations:
|
||||
self._citations[url] = metadata
|
||||
self._citation_order.append(url)
|
||||
self._url_to_index[url] = len(self._citation_order) - 1
|
||||
added.append(metadata)
|
||||
logger.debug(f"Added citation: {metadata.title} ({url})")
|
||||
else:
|
||||
@@ -104,6 +106,7 @@ class CitationCollector:
|
||||
)
|
||||
self._citations[url] = metadata
|
||||
self._citation_order.append(url)
|
||||
self._url_to_index[url] = len(self._citation_order) - 1
|
||||
|
||||
return metadata
|
||||
|
||||
@@ -124,7 +127,7 @@ class CitationCollector:
|
||||
|
||||
def get_number(self, url: str) -> Optional[int]:
    """
    Get the citation number for a URL (O(1) time complexity).

    Uses the ``_url_to_index`` reverse-index cache instead of scanning
    ``_citation_order`` with ``list.index``, so lookups stay constant-time
    even for large collections.

    Args:
        url: The URL to look up

    Returns:
        The citation number (1-indexed) or None if not found
    """
    # _url_to_index stores 0-based positions; citation numbers are 1-based.
    index = self._url_to_index.get(url)
    return index + 1 if index is not None else None
|
||||
|
||||
def get_metadata(self, url: str) -> Optional[CitationMetadata]:
|
||||
"""
|
||||
@@ -215,7 +216,9 @@ class CitationCollector:
|
||||
for citation_data in data.get("citations", []):
|
||||
citation = Citation.from_dict(citation_data)
|
||||
collector._citations[citation.url] = citation.metadata
|
||||
index = len(collector._citation_order)
|
||||
collector._citation_order.append(citation.url)
|
||||
collector._url_to_index[citation.url] = index
|
||||
collector._used_citations = set(data.get("used_urls", []))
|
||||
return collector
|
||||
|
||||
@@ -230,6 +233,7 @@ class CitationCollector:
|
||||
if url not in self._citations:
|
||||
self._citations[url] = other._citations[url]
|
||||
self._citation_order.append(url)
|
||||
self._url_to_index[url] = len(self._citation_order) - 1
|
||||
self._used_citations.update(other._used_citations)
|
||||
|
||||
@property
|
||||
@@ -247,6 +251,7 @@ class CitationCollector:
|
||||
self._citations.clear()
|
||||
self._citation_order.clear()
|
||||
self._used_citations.clear()
|
||||
self._url_to_index.clear()
|
||||
|
||||
|
||||
def extract_urls_from_text(text: str) -> List[str]:
|
||||
|
||||
@@ -7,6 +7,7 @@ Citation extraction utilities for extracting citations from tool results.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
@@ -205,6 +206,84 @@ def _result_to_citation(result: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def extract_title_from_content(content: Optional[str], max_length: int = 200) -> str:
    """
    Derive a human-readable title from raw document content.

    Strategies are tried in decreasing order of confidence:
    1. HTML <title> tag
    2. Markdown h1 (# Title)
    3. Markdown h2-h6 (## Title, etc.)
    4. JSON/YAML title field
    5. First substantial non-empty line
    6. "Untitled" as fallback

    Args:
        content: The content to extract title from (can be None)
        max_length: Maximum title length (default: 200)

    Returns:
        Extracted title or "Untitled"
    """
    if not content:
        return "Untitled"

    # (pattern, regex flags, minimum accepted length) in priority order.
    # The JSON/YAML strategy requires a longer match (> 3 chars) to avoid
    # picking up trivially short field values.
    strategies = (
        (r'<title[^>]*>([^<]+)</title>', re.IGNORECASE | re.DOTALL, 1),
        (r'^#{1}\s+(.+?)$', re.MULTILINE, 1),
        (r'^#{2,6}\s+(.+?)$', re.MULTILINE, 1),
        (r'"?title"?\s*:\s*["\']?([^"\'\n]+)["\']?', re.IGNORECASE, 4),
    )

    for pattern, flags, min_len in strategies:
        found = re.search(pattern, content, flags)
        if found:
            candidate = found.group(1).strip()
            if len(candidate) >= min_len:
                return candidate[:max_length]

    # Last resort: the first plain line long enough to look like a title,
    # skipping code fences, separators, list markers, and headings.
    for raw_line in content.split('\n'):
        stripped = raw_line.strip()
        if (stripped
                and len(stripped) > 10
                and not stripped.startswith(('```', '---', '***', '- ', '* ', '+ ', '#'))):
            return stripped[:max_length]

    return "Untitled"
|
||||
|
||||
|
||||
def _extract_from_crawl_result(data: Any) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Extract citation from crawl tool result.
|
||||
@@ -224,18 +303,8 @@ def _extract_from_crawl_result(data: Any) -> Optional[Dict[str, Any]]:
|
||||
|
||||
content = data.get("crawled_content", "")
|
||||
|
||||
# Try to extract title from content (first h1 or first line)
|
||||
title = "Untitled"
|
||||
if content:
|
||||
lines = content.strip().split("\n")
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("# "):
|
||||
title = line[2:].strip()
|
||||
break
|
||||
elif line and not line.startswith("#"):
|
||||
title = line[:100]
|
||||
break
|
||||
# Extract title using intelligent extraction function
|
||||
title = extract_title_from_content(content)
|
||||
|
||||
return {
|
||||
"url": url,
|
||||
@@ -248,15 +317,48 @@ def _extract_from_crawl_result(data: Any) -> Optional[Dict[str, Any]]:
|
||||
}
|
||||
|
||||
|
||||
def _extract_domain(url: str) -> str:
|
||||
"""Extract domain from URL."""
|
||||
def _extract_domain(url: Optional[str]) -> str:
|
||||
"""
|
||||
Extract domain from URL using urllib with regex fallback.
|
||||
|
||||
Handles:
|
||||
- Standard URLs: https://www.example.com/path
|
||||
- Short URLs: example.com
|
||||
- Invalid URLs: graceful fallback
|
||||
|
||||
Args:
|
||||
url: The URL string to extract domain from (can be None)
|
||||
|
||||
Returns:
|
||||
The domain netloc (including port if present), or empty string if extraction fails
|
||||
"""
|
||||
if not url:
|
||||
return ""
|
||||
|
||||
# Approach 1: Try urllib first (fast path for standard URLs)
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
parsed = urlparse(url)
|
||||
return parsed.netloc
|
||||
except Exception:
|
||||
return ""
|
||||
if parsed.netloc:
|
||||
return parsed.netloc
|
||||
except Exception as e:
|
||||
logger.debug(f"URL parsing failed for {url}: {e}")
|
||||
|
||||
# Approach 2: Regex fallback (for non-standard or bare URLs without scheme)
|
||||
# Matches: domain[:port] where domain is a valid hostname
|
||||
# Pattern breakdown:
|
||||
# ([a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*)
|
||||
# - domain labels separated by dots, each 1-63 chars, starting/ending with alphanumeric
|
||||
# (?::\d+)? - optional port
|
||||
pattern = r'^([a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*(?::\d+)?)(?:[/?#]|$)'
|
||||
|
||||
match = re.match(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
logger.warning(f"Could not extract domain from URL: {url}")
|
||||
return ""
|
||||
|
||||
|
||||
def merge_citations(
|
||||
|
||||
@@ -6,9 +6,9 @@ Citation formatter for generating citation sections and inline references.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .models import Citation, CitationMetadata
|
||||
from .models import Citation
|
||||
|
||||
|
||||
class CitationFormatter:
|
||||
@@ -239,33 +239,159 @@ class CitationFormatter:
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
|
||||
|
||||
def parse_citations_from_report(
    report: str, section_patterns: Optional[List[str]] = None
) -> Dict[str, Any]:
    """
    Extract citation information from report, supporting multiple formats.

    Supports various citation formats:
    - Markdown: [Title](URL)
    - Numbered: [1] Title - URL
    - Footnote: [^1]: Title - URL
    - HTML: <a href="URL">Title</a>

    Args:
        report: The report markdown text
        section_patterns: Custom section header patterns (optional)

    Returns:
        Dictionary with 'citations' list and 'count' of unique citations
    """
    if section_patterns is None:
        section_patterns = [
            r"(?:##\s*Key Citations|##\s*References|##\s*Sources|##\s*Bibliography)",
        ]

    citations = []

    # 1. Find citation section(s) and extract their body text.
    for pattern in section_patterns:
        # Consume lines up to the next "##" heading line-by-line instead of
        # relying on DOTALL with greedy matching, which backtracks badly on
        # large reports.
        section_matches = re.finditer(
            pattern + r"\s*\n((?:(?!\n##).*\n?)*)",
            report,
            re.IGNORECASE | re.MULTILINE,
        )

        for section_match in section_matches:
            section = section_match.group(1)

            # 2. Extract citations in every supported format.
            citations.extend(_extract_markdown_links(section))
            citations.extend(_extract_numbered_citations(section))
            citations.extend(_extract_footnote_citations(section))
            citations.extend(_extract_html_links(section))

    # 3. Deduplicate by URL, keeping the first occurrence of each URL.
    unique_citations = {}
    for citation in citations:
        url = citation.get("url", "")
        if url and url not in unique_citations:
            unique_citations[url] = citation

    return {
        "citations": list(unique_citations.values()),
        "count": len(unique_citations),
    }
|
||||
|
||||
|
||||
def _extract_markdown_links(text: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Extract Markdown links [title](url).
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
|
||||
Returns:
|
||||
List of citation dictionaries with title, url, and format
|
||||
"""
|
||||
citations = []
|
||||
|
||||
# Find the Key Citations section
|
||||
section_pattern = (
|
||||
r"(?:##\s*Key Citations|##\s*References|##\s*Sources)\s*\n(.*?)(?=\n##|\Z)"
|
||||
)
|
||||
section_match = re.search(section_pattern, report, re.IGNORECASE | re.DOTALL)
|
||||
|
||||
if section_match:
|
||||
section = section_match.group(1)
|
||||
|
||||
# Extract markdown links
|
||||
link_pattern = r"\[([^\]]+)\]\(([^)]+)\)"
|
||||
for match in re.finditer(link_pattern, section):
|
||||
title = match.group(1)
|
||||
url = match.group(2)
|
||||
if url.startswith(("http://", "https://")):
|
||||
citations.append((title, url))
|
||||
|
||||
pattern = r"\[([^\]]+)\]\(([^)]+)\)"
|
||||
|
||||
for match in re.finditer(pattern, text):
|
||||
title, url = match.groups()
|
||||
if url.startswith(("http://", "https://")):
|
||||
citations.append({
|
||||
"title": title.strip(),
|
||||
"url": url.strip(),
|
||||
"format": "markdown",
|
||||
})
|
||||
|
||||
return citations
|
||||
|
||||
|
||||
def _extract_numbered_citations(text: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Extract numbered citations [1] Title - URL.
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
|
||||
Returns:
|
||||
List of citation dictionaries
|
||||
"""
|
||||
citations = []
|
||||
# Match: [number] title - URL
|
||||
pattern = r"\[\d+\]\s+([^-\n]+?)\s*-\s*(https?://[^\s\n]+)"
|
||||
|
||||
for match in re.finditer(pattern, text):
|
||||
title, url = match.groups()
|
||||
citations.append({
|
||||
"title": title.strip(),
|
||||
"url": url.strip(),
|
||||
"format": "numbered",
|
||||
})
|
||||
|
||||
return citations
|
||||
|
||||
|
||||
def _extract_footnote_citations(text: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Extract footnote citations [^1]: Title - URL.
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
|
||||
Returns:
|
||||
List of citation dictionaries
|
||||
"""
|
||||
citations = []
|
||||
# Match: [^number]: title - URL
|
||||
pattern = r"\[\^(\d+)\]:\s+([^-\n]+?)\s*-\s*(https?://[^\s\n]+)"
|
||||
|
||||
for match in re.finditer(pattern, text):
|
||||
_, title, url = match.groups()
|
||||
citations.append({
|
||||
"title": title.strip(),
|
||||
"url": url.strip(),
|
||||
"format": "footnote",
|
||||
})
|
||||
|
||||
return citations
|
||||
|
||||
|
||||
def _extract_html_links(text: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Extract HTML links <a href="url">title</a>.
|
||||
|
||||
Args:
|
||||
text: Text to extract from
|
||||
|
||||
Returns:
|
||||
List of citation dictionaries
|
||||
"""
|
||||
citations = []
|
||||
pattern = r'<a\s+(?:[^>]*?\s)?href=(["\'])([^"\']+)\1[^>]*>([^<]+)</a>'
|
||||
|
||||
for match in re.finditer(pattern, text, re.IGNORECASE):
|
||||
_, url, title = match.groups()
|
||||
if url.startswith(("http://", "https://")):
|
||||
citations.append({
|
||||
"title": title.strip(),
|
||||
"url": url.strip(),
|
||||
"format": "html",
|
||||
})
|
||||
|
||||
return citations
|
||||
|
||||
@@ -6,14 +6,14 @@ Citation data models for structured source metadata.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
@dataclass
|
||||
class CitationMetadata:
|
||||
|
||||
class CitationMetadata(BaseModel):
|
||||
"""Metadata extracted from a source."""
|
||||
|
||||
# Core identifiers
|
||||
@@ -32,7 +32,7 @@ class CitationMetadata:
|
||||
language: Optional[str] = None
|
||||
|
||||
# Media
|
||||
images: List[str] = field(default_factory=list)
|
||||
images: List[str] = Field(default_factory=list)
|
||||
favicon: Optional[str] = None
|
||||
|
||||
# Quality indicators
|
||||
@@ -40,13 +40,16 @@ class CitationMetadata:
|
||||
credibility_score: float = 0.0
|
||||
|
||||
# Timestamps
|
||||
accessed_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
accessed_at: str = Field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
# Additional metadata
|
||||
extra: Dict[str, Any] = field(default_factory=dict)
|
||||
extra: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
"""Extract domain from URL if not provided."""
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, **data):
|
||||
"""Initialize and extract domain from URL if not provided."""
|
||||
super().__init__(**data)
|
||||
if not self.domain and self.url:
|
||||
try:
|
||||
parsed = urlparse(self.url)
|
||||
@@ -87,7 +90,7 @@ class CitationMetadata:
|
||||
"""Create from dictionary."""
|
||||
# Remove 'id' as it's computed from url
|
||||
data = {k: v for k, v in data.items() if k != "id"}
|
||||
return cls(**data)
|
||||
return cls.model_validate(data)
|
||||
|
||||
@classmethod
|
||||
def from_search_result(
|
||||
@@ -107,8 +110,8 @@ class CitationMetadata:
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Citation:
|
||||
|
||||
class Citation(BaseModel):
|
||||
"""
|
||||
A citation reference that can be used in reports.
|
||||
|
||||
@@ -127,6 +130,8 @@ class Citation:
|
||||
# Specific quote or fact being cited
|
||||
cited_text: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""Get the citation ID from metadata."""
|
||||
@@ -154,12 +159,14 @@ class Citation:
|
||||
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "Citation":
    """
    Create a Citation from a plain dictionary.

    The nested metadata may arrive either as a raw dict (e.g. deserialized
    JSON) or as an already-constructed CitationMetadata instance; both are
    accepted and validated through pydantic.
    """
    return cls.model_validate({
        "number": data["number"],
        "metadata": CitationMetadata.from_dict(data["metadata"])
        if isinstance(data.get("metadata"), dict)
        else data["metadata"],
        "context": data.get("context"),
        "cited_text": data.get("cited_text"),
    })
|
||||
|
||||
def to_markdown_reference(self) -> str:
|
||||
"""Generate markdown reference format: [Title](URL)"""
|
||||
|
||||
289
tests/unit/citations/test_collector.py
Normal file
289
tests/unit/citations/test_collector.py
Normal file
@@ -0,0 +1,289 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for CitationCollector optimization with reverse index cache.
|
||||
|
||||
Tests the O(1) URL lookup performance optimization via _url_to_index cache.
|
||||
"""
|
||||
|
||||
from src.citations.collector import CitationCollector
|
||||
|
||||
|
||||
class TestCitationCollectorOptimization:
    """Test CitationCollector reverse index cache optimization.

    The collector maintains a ``_url_to_index`` dict mapping each URL to its
    0-based position in ``_citation_order``, so ``get_number()`` is O(1).
    These tests check that every mutation path keeps the cache consistent.
    """

    def test_url_to_index_cache_initialization(self):
        """Test that _url_to_index is properly initialized."""
        collector = CitationCollector()
        # Cache must exist and start empty before any citations are added.
        assert hasattr(collector, "_url_to_index")
        assert isinstance(collector._url_to_index, dict)
        assert len(collector._url_to_index) == 0

    def test_add_single_citation_updates_cache(self):
        """Test that adding a citation updates _url_to_index."""
        collector = CitationCollector()
        results = [
            {
                "url": "https://example.com",
                "title": "Example",
                "content": "Content",
                "score": 0.9,
            }
        ]

        collector.add_from_search_results(results)

        # Check cache is populated
        assert "https://example.com" in collector._url_to_index
        assert collector._url_to_index["https://example.com"] == 0

    def test_add_multiple_citations_updates_cache_correctly(self):
        """Test that multiple citations are indexed correctly."""
        collector = CitationCollector()
        results = [
            {
                "url": f"https://example.com/{i}",
                "title": f"Page {i}",
                "content": f"Content {i}",
                "score": 0.9,
            }
            for i in range(5)
        ]

        collector.add_from_search_results(results)

        # Check all URLs are indexed in insertion order.
        assert len(collector._url_to_index) == 5
        for i in range(5):
            url = f"https://example.com/{i}"
            assert collector._url_to_index[url] == i

    def test_get_number_uses_cache_for_o1_lookup(self):
        """Test that get_number uses cache for O(1) lookup."""
        collector = CitationCollector()
        urls = [f"https://example.com/{i}" for i in range(100)]
        results = [
            {
                "url": url,
                "title": f"Title {i}",
                "content": f"Content {i}",
                "score": 0.9,
            }
            for i, url in enumerate(urls)
        ]

        collector.add_from_search_results(results)

        # Test lookup for various positions (citation numbers are 1-based).
        assert collector.get_number("https://example.com/0") == 1
        assert collector.get_number("https://example.com/50") == 51
        assert collector.get_number("https://example.com/99") == 100

        # Non-existent URL returns None
        assert collector.get_number("https://nonexistent.com") is None

    def test_add_from_crawl_result_updates_cache(self):
        """Test that add_from_crawl_result updates cache."""
        collector = CitationCollector()

        collector.add_from_crawl_result(
            url="https://crawled.com/page",
            title="Crawled Page",
            content="Crawled content",
        )

        assert "https://crawled.com/page" in collector._url_to_index
        assert collector._url_to_index["https://crawled.com/page"] == 0

    def test_duplicate_url_does_not_change_cache(self):
        """Test that adding duplicate URLs doesn't change cache indices."""
        collector = CitationCollector()

        # Add first time
        collector.add_from_search_results(
            [
                {
                    "url": "https://example.com",
                    "title": "Title 1",
                    "content": "Content 1",
                    "score": 0.8,
                }
            ]
        )
        assert collector._url_to_index["https://example.com"] == 0

        # Add same URL again with better score
        collector.add_from_search_results(
            [
                {
                    "url": "https://example.com",
                    "title": "Title 1 Updated",
                    "content": "Content 1 Updated",
                    "score": 0.95,
                }
            ]
        )

        # Cache index should not change
        assert collector._url_to_index["https://example.com"] == 0
        # But metadata should be updated
        assert collector._citations["https://example.com"].relevance_score == 0.95

    def test_merge_with_updates_cache_correctly(self):
        """Test that merge_with correctly updates cache for new URLs."""
        collector1 = CitationCollector()
        collector2 = CitationCollector()

        # Add to collector1
        collector1.add_from_search_results(
            [
                {
                    "url": "https://a.com",
                    "title": "A",
                    "content": "Content A",
                    "score": 0.9,
                }
            ]
        )

        # Add to collector2
        collector2.add_from_search_results(
            [
                {
                    "url": "https://b.com",
                    "title": "B",
                    "content": "Content B",
                    "score": 0.9,
                }
            ]
        )

        collector1.merge_with(collector2)

        # Both URLs should be in cache; the merged URL is appended after
        # collector1's existing entries.
        assert "https://a.com" in collector1._url_to_index
        assert "https://b.com" in collector1._url_to_index
        assert collector1._url_to_index["https://a.com"] == 0
        assert collector1._url_to_index["https://b.com"] == 1

    def test_from_dict_rebuilds_cache(self):
        """Test that from_dict properly rebuilds cache."""
        # Create original collector
        original = CitationCollector()
        original.add_from_search_results(
            [
                {
                    "url": f"https://example.com/{i}",
                    "title": f"Page {i}",
                    "content": f"Content {i}",
                    "score": 0.9,
                }
                for i in range(3)
            ]
        )

        # Serialize and deserialize
        data = original.to_dict()
        restored = CitationCollector.from_dict(data)

        # Check cache is properly rebuilt with the original ordering.
        assert len(restored._url_to_index) == 3
        for i in range(3):
            url = f"https://example.com/{i}"
            assert url in restored._url_to_index
            assert restored._url_to_index[url] == i

    def test_clear_resets_cache(self):
        """Test that clear() properly resets the cache."""
        collector = CitationCollector()
        collector.add_from_search_results(
            [
                {
                    "url": "https://example.com",
                    "title": "Example",
                    "content": "Content",
                    "score": 0.9,
                }
            ]
        )

        assert len(collector._url_to_index) > 0

        collector.clear()

        # All three internal structures must be emptied together.
        assert len(collector._url_to_index) == 0
        assert len(collector._citations) == 0
        assert len(collector._citation_order) == 0

    def test_cache_consistency_with_order_list(self):
        """Test that cache indices match positions in _citation_order."""
        collector = CitationCollector()
        urls = [f"https://example.com/{i}" for i in range(10)]
        results = [
            {
                "url": url,
                "title": f"Title {i}",
                "content": f"Content {i}",
                "score": 0.9,
            }
            for i, url in enumerate(urls)
        ]

        collector.add_from_search_results(results)

        # Verify cache indices match order list positions
        for i, url in enumerate(collector._citation_order):
            assert collector._url_to_index[url] == i

    def test_mark_used_with_cache(self):
        """Test that mark_used works correctly with cache."""
        collector = CitationCollector()
        collector.add_from_search_results(
            [
                {
                    "url": "https://example.com/1",
                    "title": "Page 1",
                    "content": "Content 1",
                    "score": 0.9,
                },
                {
                    "url": "https://example.com/2",
                    "title": "Page 2",
                    "content": "Content 2",
                    "score": 0.9,
                },
            ]
        )

        # Mark one as used
        number = collector.mark_used("https://example.com/2")
        assert number == 2

        # Verify it's in used set
        assert "https://example.com/2" in collector._used_citations

    def test_large_collection_cache_performance(self):
        """Test that cache works correctly with large collections."""
        collector = CitationCollector()
        num_citations = 1000
        results = [
            {
                "url": f"https://example.com/{i}",
                "title": f"Title {i}",
                "content": f"Content {i}",
                "score": 0.9,
            }
            for i in range(num_citations)
        ]

        collector.add_from_search_results(results)

        # Verify cache size
        assert len(collector._url_to_index) == num_citations

        # Test lookups at various positions
        test_indices = [0, 100, 500, 999]
        for idx in test_indices:
            url = f"https://example.com/{idx}"
            assert collector.get_number(url) == idx + 1
||||
251
tests/unit/citations/test_extractor.py
Normal file
251
tests/unit/citations/test_extractor.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for extractor optimizations.
|
||||
|
||||
Tests the enhanced domain extraction and title extraction functions.
|
||||
"""
|
||||
|
||||
from src.citations.extractor import (
|
||||
_extract_domain,
|
||||
extract_title_from_content,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractDomainOptimization:
    """Test domain extraction with urllib + regex fallback strategy.

    ``_extract_domain`` tries ``urllib.parse.urlparse`` first and falls
    back to a hostname regex for bare URLs without a scheme.
    """

    def test_extract_domain_standard_urls(self):
        """Test extraction from standard URLs."""
        assert _extract_domain("https://www.example.com/path") == "www.example.com"
        assert _extract_domain("http://example.org") == "example.org"
        assert _extract_domain("https://github.com/user/repo") == "github.com"

    def test_extract_domain_with_port(self):
        """Test extraction from URLs with ports."""
        # The port is part of the netloc and must be preserved.
        assert _extract_domain("http://localhost:8080/api") == "localhost:8080"
        assert (
            _extract_domain("https://example.com:3000/page")
            == "example.com:3000"
        )

    def test_extract_domain_with_subdomain(self):
        """Test extraction from URLs with subdomains."""
        assert _extract_domain("https://api.github.com/repos") == "api.github.com"
        assert (
            _extract_domain("https://docs.python.org/en/")
            == "docs.python.org"
        )

    def test_extract_domain_invalid_url(self):
        """Test handling of invalid URLs."""
        # Should not crash, might return empty string
        result = _extract_domain("not a url")
        assert isinstance(result, str)

    def test_extract_domain_empty_url(self):
        """Test handling of empty URL."""
        assert _extract_domain("") == ""

    def test_extract_domain_without_scheme(self):
        """Test extraction from URLs without scheme (handled by regex fallback)."""
        # These may be handled by regex fallback
        result = _extract_domain("example.com/path")
        # Should at least not crash
        assert isinstance(result, str)

    def test_extract_domain_complex_urls(self):
        """Test extraction from complex URLs."""
        # urllib includes credentials in netloc, so this is expected behavior
        assert (
            _extract_domain("https://user:pass@example.com/path")
            == "user:pass@example.com"
        )
        assert (
            _extract_domain("https://example.com:443/path?query=value#hash")
            == "example.com:443"
        )

    def test_extract_domain_ipv4(self):
        """Test extraction from IPv4 addresses."""
        result = _extract_domain("http://192.168.1.1:8080/")
        # Should handle IP addresses
        assert isinstance(result, str)

    def test_extract_domain_query_params(self):
        """Test that query params don't affect domain extraction."""
        url1 = "https://example.com/page?query=value"
        url2 = "https://example.com/page"
        assert _extract_domain(url1) == _extract_domain(url2)

    def test_extract_domain_url_fragments(self):
        """Test that fragments don't affect domain extraction."""
        url1 = "https://example.com/page#section"
        url2 = "https://example.com/page"
        assert _extract_domain(url1) == _extract_domain(url2)
||||
|
||||
|
||||
class TestExtractTitleFromContent:
    """Exercise the 5-tier title extraction heuristic end to end."""

    def test_extract_title_html_title_tag(self):
        """Tier 1: an HTML <title> element wins outright."""
        sample = "<html><head><title>HTML Title</title></head><body>Content</body></html>"
        assert extract_title_from_content(sample) == "HTML Title"

    def test_extract_title_html_title_case_insensitive(self):
        """An upper-case <TITLE> tag is matched just the same."""
        sample = "<html><head><TITLE>HTML Title</TITLE></head><body></body></html>"
        assert extract_title_from_content(sample) == "HTML Title"

    def test_extract_title_markdown_h1(self):
        """Tier 2: a Markdown level-1 heading becomes the title."""
        sample = "# Main Title\n\nSome content here"
        assert extract_title_from_content(sample) == "Main Title"

    def test_extract_title_markdown_h1_with_spaces(self):
        """Whitespace around an h1 heading's text is trimmed away."""
        sample = "# Title with Spaces \n\nContent"
        assert extract_title_from_content(sample) == "Title with Spaces"

    def test_extract_title_markdown_h2_fallback(self):
        """Tier 3: with no h1 present, an h2 heading is used instead."""
        sample = "## Second Level Title\n\nSome content"
        assert extract_title_from_content(sample) == "Second Level Title"

    def test_extract_title_markdown_h6_fallback(self):
        """Even a level-6 heading can serve as the fallback title."""
        sample = "###### Small Heading\n\nContent"
        assert extract_title_from_content(sample) == "Small Heading"

    def test_extract_title_prefers_h1_over_h2(self):
        """When both h1 and h2 exist, the h1 takes precedence."""
        sample = "# H1 Title\n## H2 Title\n\nContent"
        assert extract_title_from_content(sample) == "H1 Title"

    def test_extract_title_json_field(self):
        """Tier 4: a JSON title field is recognized."""
        sample = '{"title": "JSON Title", "content": "Some data"}'
        assert extract_title_from_content(sample) == "JSON Title"

    def test_extract_title_yaml_field(self):
        """A YAML-style title key is recognized as well."""
        sample = 'title: "YAML Title"\ncontent: "Some data"'
        assert extract_title_from_content(sample) == "YAML Title"

    def test_extract_title_first_substantial_line(self):
        """Tier 5: fall back to the first non-trivial line of text."""
        sample = "\n\n\nThis is the first substantial line\n\nMore content"
        assert extract_title_from_content(sample) == "This is the first substantial line"

    def test_extract_title_skips_short_lines(self):
        """Lines that are too short are passed over."""
        sample = "abc\nThis is a longer first substantial line\nContent"
        assert extract_title_from_content(sample) == "This is a longer first substantial line"

    def test_extract_title_skips_code_blocks(self):
        """Fenced code blocks are not mistaken for a title."""
        sample = "```\ncode here\n```\nThis is the title\n\nContent"
        title = extract_title_from_content(sample)
        # The fence contents must not leak into the extracted title.
        assert "title" in title.lower() or "code" not in title

    def test_extract_title_skips_list_items(self):
        """Bullet-list entries are not treated as titles."""
        sample = "- Item 1\n- Item 2\nThis is the actual first substantial line\n\nContent"
        title = extract_title_from_content(sample)
        assert "actual" in title or "Item" not in title

    def test_extract_title_skips_separators(self):
        """Horizontal-rule lines never end up in the title."""
        sample = "---\n\n***\n\nThis is the real title\n\nContent"
        title = extract_title_from_content(sample)
        assert "---" not in title and "***" not in title

    def test_extract_title_max_length(self):
        """The max_length argument truncates an over-long title."""
        oversized = "A" * 300
        title = extract_title_from_content(f"# {oversized}", max_length=100)
        assert len(title) <= 100
        assert title == oversized[:100]

    def test_extract_title_empty_content(self):
        """Empty or missing content yields the 'Untitled' placeholder."""
        assert extract_title_from_content("") == "Untitled"
        assert extract_title_from_content(None) == "Untitled"

    def test_extract_title_no_title_found(self):
        """A string result comes back even with nothing title-like."""
        sample = "a\nb\nc\n"  # nothing but short lines
        title = extract_title_from_content(sample)
        # Implementation may return "Untitled" or one of the short lines.
        assert isinstance(title, str)

    def test_extract_title_whitespace_handling(self):
        """Interior whitespace is normalized sensibly."""
        sample = "# Title with extra spaces \n\nContent"
        title = extract_title_from_content(sample)
        assert "Title with extra spaces" in title or len(title) > 5

    def test_extract_title_multiline_html(self):
        """A <title> element spread across several lines still parses."""
        sample = """
<html>
<head>
<title>
Multiline Title
</title>
</head>
<body>Content</body>
</html>
"""
        title = extract_title_from_content(sample)
        # The multiline tag content should still surface in the result.
        assert "Title" in title

    def test_extract_title_mixed_formats(self):
        """With several formats present, the HTML title outranks Markdown."""
        sample = """
<title>HTML Title</title>
# Markdown H1
## Markdown H2

Some paragraph content
"""
        assert extract_title_from_content(sample) == "HTML Title"

    def test_extract_title_real_world_example(self):
        """A realistic HTML document produces the exact page title."""
        sample = """
<!DOCTYPE html>
<html>
<head>
<title>GitHub: Where the world builds software</title>
<meta property="og:title" content="GitHub">
</head>
<body>
<h1>Let's build from here</h1>
<p>The complete developer platform...</p>
</body>
</html>
"""
        title = extract_title_from_content(sample)
        assert title == "GitHub: Where the world builds software"

    def test_extract_title_json_with_nested_title(self):
        """A nested JSON structure still yields some title field."""
        sample = '{"meta": {"title": "Should not match"}, "title": "JSON Title"}'
        title = extract_title_from_content(sample)
        # The regex may pick either occurrence; it just must find one.
        assert title and title != "Untitled"

    def test_extract_title_preserves_special_characters(self):
        """Punctuation survives extraction from a heading."""
        sample = "# Title with Special Characters: @#$%"
        title = extract_title_from_content(sample)
        assert "@" in title or "$" in title or "%" in title or "Title" in title
|
||||
423
tests/unit/citations/test_formatter.py
Normal file
423
tests/unit/citations/test_formatter.py
Normal file
@@ -0,0 +1,423 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for citation formatter enhancements.
|
||||
|
||||
Tests the multi-format citation parsing and extraction capabilities.
|
||||
"""
|
||||
|
||||
from src.citations.formatter import (
|
||||
parse_citations_from_report,
|
||||
_extract_markdown_links,
|
||||
_extract_numbered_citations,
|
||||
_extract_footnote_citations,
|
||||
_extract_html_links,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractMarkdownLinks:
    """Cover extraction of [title](url) style Markdown links."""

    def test_extract_single_markdown_link(self):
        """One Markdown link yields one citation record."""
        found = _extract_markdown_links("[Example Article](https://example.com)")
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "markdown"

    def test_extract_multiple_markdown_links(self):
        """Several links in one string are all captured, in order."""
        snippet = "[Link 1](https://example.com/1) and [Link 2](https://example.com/2)"
        found = _extract_markdown_links(snippet)
        assert len(found) == 2
        assert found[0]["title"] == "Link 1"
        assert found[1]["title"] == "Link 2"

    def test_extract_markdown_link_with_spaces(self):
        """Link titles may contain interior spaces."""
        found = _extract_markdown_links("[Article Title With Spaces](https://example.com)")
        assert len(found) == 1
        assert found[0]["title"] == "Article Title With Spaces"

    def test_extract_markdown_link_ignore_non_http(self):
        """Only http(s) targets are kept; relative paths are dropped."""
        snippet = "[Relative Link](./relative/path) [HTTP Link](https://example.com)"
        found = _extract_markdown_links(snippet)
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_markdown_link_with_query_params(self):
        """Query strings survive extraction untouched."""
        found = _extract_markdown_links("[Search Result](https://example.com/search?q=test&page=1)")
        assert len(found) == 1
        assert "q=test" in found[0]["url"]

    def test_extract_markdown_link_empty_text(self):
        """Plain prose without links produces an empty result."""
        found = _extract_markdown_links("Just plain text with no links")
        assert len(found) == 0

    def test_extract_markdown_link_strip_whitespace(self):
        """Titles and URLs come back trimmed of stray whitespace."""
        # A URL containing spaces is not valid Markdown, so only clean
        # links get extracted in the first place.
        found = _extract_markdown_links("[Title](https://example.com)")
        assert len(found) >= 1
        assert found[0]["title"] == "Title"
        assert found[0]["url"] == "https://example.com"
|
||||
|
||||
|
||||
class TestExtractNumberedCitations:
    """Cover extraction of '[n] Title - URL' style citations."""

    def test_extract_single_numbered_citation(self):
        """A single numbered entry is parsed into its parts."""
        found = _extract_numbered_citations("[1] Example Article - https://example.com")
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "numbered"

    def test_extract_multiple_numbered_citations(self):
        """Multiple numbered lines each become a citation."""
        snippet = "[1] First - https://example.com/1\n[2] Second - https://example.com/2"
        found = _extract_numbered_citations(snippet)
        assert len(found) == 2
        assert found[0]["title"] == "First"
        assert found[1]["title"] == "Second"

    def test_extract_numbered_citation_with_long_title(self):
        """Long multi-word titles are kept intact."""
        snippet = "[5] A Comprehensive Guide to Python Programming - https://example.com"
        found = _extract_numbered_citations(snippet)
        assert len(found) == 1
        assert "Comprehensive Guide" in found[0]["title"]

    def test_extract_numbered_citation_requires_valid_format(self):
        """Malformed numbering (unclosed bracket) is rejected."""
        snippet = "[1 Title - https://example.com"  # Missing closing bracket
        found = _extract_numbered_citations(snippet)
        assert len(found) == 0

    def test_extract_numbered_citation_empty_text(self):
        """Text without numbered citations yields nothing."""
        found = _extract_numbered_citations("Just plain text")
        assert len(found) == 0

    def test_extract_numbered_citation_various_numbers(self):
        """Multi-digit citation numbers are accepted."""
        snippet = "[10] Title Ten - https://example.com/10\n[999] Title 999 - https://example.com/999"
        found = _extract_numbered_citations(snippet)
        assert len(found) == 2

    def test_extract_numbered_citation_ignore_non_http(self):
        """Non-http schemes inside numbered citations are filtered out."""
        snippet = "[1] Invalid - file://path [2] Valid - https://example.com"
        found = _extract_numbered_citations(snippet)
        # At most the https entry survives the scheme filter.
        assert len(found) <= 1
|
||||
|
||||
|
||||
class TestExtractFootnoteCitations:
    """Cover extraction of '[^n]: Title - URL' footnote citations."""

    def test_extract_single_footnote_citation(self):
        """A single footnote definition parses into its parts."""
        found = _extract_footnote_citations("[^1]: Example Article - https://example.com")
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "footnote"

    def test_extract_multiple_footnote_citations(self):
        """Each footnote line becomes its own citation."""
        snippet = "[^1]: First - https://example.com/1\n[^2]: Second - https://example.com/2"
        found = _extract_footnote_citations(snippet)
        assert len(found) == 2

    def test_extract_footnote_with_complex_number(self):
        """Footnote markers with multi-digit numbers are accepted."""
        found = _extract_footnote_citations("[^123]: Title - https://example.com")
        assert len(found) == 1
        assert found[0]["title"] == "Title"

    def test_extract_footnote_citation_with_spaces(self):
        """Padding around the separator and URL is trimmed."""
        found = _extract_footnote_citations("[^1]: Title with spaces - https://example.com ")
        assert len(found) == 1
        # Whitespace must have been stripped from the title.
        assert found[0]["title"] == "Title with spaces"

    def test_extract_footnote_citation_empty_text(self):
        """Plain text without footnotes yields nothing."""
        found = _extract_footnote_citations("No footnotes here")
        assert len(found) == 0

    def test_extract_footnote_requires_caret(self):
        """A definition lacking the caret is not a footnote."""
        snippet = "[1]: Title - https://example.com"  # Missing ^
        found = _extract_footnote_citations(snippet)
        assert len(found) == 0
|
||||
|
||||
|
||||
class TestExtractHtmlLinks:
    """Cover extraction of <a href="url">title</a> anchors."""

    def test_extract_single_html_link(self):
        """One anchor tag yields one citation record."""
        found = _extract_html_links('<a href="https://example.com">Example Article</a>')
        assert len(found) == 1
        entry = found[0]
        assert entry["title"] == "Example Article"
        assert entry["url"] == "https://example.com"
        assert entry["format"] == "html"

    def test_extract_multiple_html_links(self):
        """Every anchor in the text is collected."""
        snippet = '<a href="https://a.com">Link A</a> <a href="https://b.com">Link B</a>'
        found = _extract_html_links(snippet)
        assert len(found) == 2

    def test_extract_html_link_single_quotes(self):
        """Single-quoted href attributes are handled too."""
        found = _extract_html_links("<a href='https://example.com'>Title</a>")
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_html_link_with_attributes(self):
        """Extra attributes around href do not break matching."""
        snippet = '<a class="link" href="https://example.com" target="_blank">Title</a>'
        found = _extract_html_links(snippet)
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_html_link_ignore_non_http(self):
        """mailto: and other non-web schemes are filtered out."""
        snippet = '<a href="mailto:test@example.com">Email</a> <a href="https://example.com">Web</a>'
        found = _extract_html_links(snippet)
        assert len(found) == 1
        assert found[0]["url"] == "https://example.com"

    def test_extract_html_link_case_insensitive(self):
        """Upper-case tag and attribute names still match."""
        found = _extract_html_links('<A HREF="https://example.com">Title</A>')
        assert len(found) == 1

    def test_extract_html_link_empty_text(self):
        """Text without anchors yields nothing."""
        found = _extract_html_links("No links here")
        assert len(found) == 0

    def test_extract_html_link_strip_whitespace(self):
        """Padding inside the anchor text is stripped."""
        found = _extract_html_links('<a href="https://example.com"> Title with spaces </a>')
        assert found[0]["title"] == "Title with spaces"
|
||||
|
||||
|
||||
class TestParseCitationsFromReport:
    """End-to-end parsing of citation sections out of full reports."""

    def test_parse_markdown_links_from_report(self):
        """Markdown links under a Key Citations heading are collected."""
        report = """
## Key Citations

[GitHub](https://github.com)
[Python Docs](https://python.org)
"""
        parsed = parse_citations_from_report(report)
        assert parsed["count"] >= 2
        collected = [entry["url"] for entry in parsed["citations"]]
        assert "https://github.com" in collected

    def test_parse_numbered_citations_from_report(self):
        """Numbered entries under a References heading are collected."""
        report = """
## References

[1] GitHub - https://github.com
[2] Python - https://python.org
"""
        parsed = parse_citations_from_report(report)
        assert parsed["count"] >= 2

    def test_parse_mixed_format_citations(self):
        """Markdown, footnote, numbered and HTML entries can coexist."""
        report = """
## Key Citations

[GitHub](https://github.com)
[^1]: Python - https://python.org
[2] Wikipedia - https://wikipedia.org
<a href="https://stackoverflow.com">Stack Overflow</a>
"""
        parsed = parse_citations_from_report(report)
        # Four distinct formats are present; most should be recognized.
        assert parsed["count"] >= 3

    def test_parse_citations_deduplication(self):
        """Repeated URLs collapse into a single citation."""
        report = """
## Key Citations

[GitHub 1](https://github.com)
[GitHub 2](https://github.com)
[GitHub](https://github.com)
"""
        parsed = parse_citations_from_report(report)
        # Three mentions of the same URL become one entry.
        assert parsed["count"] == 1
        assert parsed["citations"][0]["url"] == "https://github.com"

    def test_parse_citations_various_section_patterns(self):
        """References / Sources / Bibliography headings all work."""
        for heading in ("References", "Sources", "Bibliography"):
            report = f"""
## {heading}
[GitHub](https://github.com)
"""
            assert parse_citations_from_report(report)["count"] >= 1

    def test_parse_citations_custom_patterns(self):
        """Caller-supplied section patterns override the defaults."""
        report = """
## My Custom Sources
[GitHub](https://github.com)
"""
        parsed = parse_citations_from_report(
            report,
            section_patterns=[r"##\s*My Custom Sources"]
        )
        assert parsed["count"] >= 1

    def test_parse_citations_empty_report(self):
        """An empty string parses to zero citations."""
        parsed = parse_citations_from_report("")
        assert parsed["count"] == 0
        assert parsed["citations"] == []

    def test_parse_citations_no_section(self):
        """A report lacking any citation section yields nothing."""
        parsed = parse_citations_from_report("This is a report with no citations section")
        assert parsed["count"] == 0

    def test_parse_citations_complex_report(self):
        """A realistic report with mixed sections parses sensibly."""
        report = """
# Research Report

## Introduction

This report summarizes findings from multiple sources.

## Key Findings

Some important discoveries were made based on research [GitHub](https://github.com).

## Key Citations

1. Primary sources:
[GitHub](https://github.com) - A collaborative platform
[^1]: Python - https://python.org

2. Secondary sources:
[2] Wikipedia - https://wikipedia.org

3. Web resources:
<a href="https://stackoverflow.com">Stack Overflow</a>

## Methodology

[Additional](https://example.com) details about methodology.

---

[^1]: The Python programming language official site
"""

        parsed = parse_citations_from_report(report)
        # Several entries live inside the Key Citations section.
        assert parsed["count"] >= 3
        collected = [entry["url"] for entry in parsed["citations"]]
        # At least the primary sources should have been picked up.
        assert any("github.com" in link or "python.org" in link for link in collected)

    def test_parse_citations_stops_at_next_section(self):
        """Extraction is scoped to the citation section itself."""
        report = """
## Key Citations

[Cite 1](https://example.com/1)
[Cite 2](https://example.com/2)

## Next Section

Some other content
"""
        parsed = parse_citations_from_report(report)
        # The section regex terminates at the following ## heading.
        assert parsed["count"] >= 1
        assert any("example.com/1" in entry["url"] for entry in parsed["citations"])

    def test_parse_citations_preserves_metadata(self):
        """Each parsed citation carries title, url and format keys."""
        report = """
## Key Citations

[Python Documentation](https://python.org)
"""
        parsed = parse_citations_from_report(report)
        assert len(parsed["citations"]) >= 1
        entry = parsed["citations"][0]
        assert "title" in entry
        assert "url" in entry
        assert "format" in entry

    def test_parse_citations_whitespace_handling(self):
        """Blank lines around entries do not disturb parsing."""
        report = """
## Key Citations

[Link](https://example.com)

"""
        parsed = parse_citations_from_report(report)
        assert parsed["count"] >= 1

    def test_parse_citations_multiline_links(self):
        """A link embedded mid-paragraph is still extracted."""
        report = """
## Key Citations

Some paragraph with a [link to example](https://example.com) in the middle.
"""
        parsed = parse_citations_from_report(report)
        assert parsed["count"] >= 1
|
||||
467
tests/unit/citations/test_models.py
Normal file
467
tests/unit/citations/test_models.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""
|
||||
Unit tests for citation models.
|
||||
|
||||
Tests the Pydantic BaseModel implementation of CitationMetadata and Citation classes.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.citations.models import Citation, CitationMetadata
|
||||
|
||||
|
||||
class TestCitationMetadata:
    """Behavioral tests for the CitationMetadata Pydantic model."""

    def test_create_basic_metadata(self):
        """Minimal construction fills derived and default fields."""
        meta = CitationMetadata(
            url="https://example.com/article",
            title="Example Article",
        )
        assert meta.url == "https://example.com/article"
        assert meta.title == "Example Article"
        # The domain is derived from the URL automatically.
        assert meta.domain == "example.com"
        assert meta.description is None
        assert meta.images == []
        assert meta.extra == {}

    def test_metadata_with_all_fields(self):
        """Every optional field round-trips through construction."""
        meta = CitationMetadata(
            url="https://github.com/example/repo",
            title="Example Repository",
            description="A great repository",
            content_snippet="This is a snippet",
            raw_content="Full content here",
            author="John Doe",
            published_date="2025-01-24",
            language="en",
            relevance_score=0.95,
            credibility_score=0.88,
        )
        assert meta.url == "https://github.com/example/repo"
        assert meta.domain == "github.com"
        assert meta.author == "John Doe"
        assert meta.relevance_score == 0.95
        assert meta.credibility_score == 0.88

    def test_metadata_domain_auto_extraction(self):
        """The host part (port included) is extracted as the domain."""
        cases = [
            ("https://www.example.com/path", "www.example.com"),
            ("http://github.com/user/repo", "github.com"),
            ("https://api.github.com:443/repos", "api.github.com:443"),
        ]

        for source_url, wanted_domain in cases:
            meta = CitationMetadata(url=source_url, title="Test")
            assert meta.domain == wanted_domain

    def test_metadata_id_generation(self):
        """IDs are deterministic per URL and differ across URLs."""
        first = CitationMetadata(
            url="https://example.com/article",
            title="Article",
        )
        second = CitationMetadata(
            url="https://example.com/article",
            title="Article",
        )
        # Identical URLs must hash to identical IDs.
        assert first.id == second.id

        other = CitationMetadata(
            url="https://different.com/article",
            title="Article",
        )
        # A different URL must change the ID.
        assert first.id != other.id

    def test_metadata_id_length(self):
        """The generated ID is 12 characters of hex-like text."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Test",
        )
        assert len(meta.id) == 12
        assert meta.id.isalnum() or all(ch in "0123456789abcdef" for ch in meta.id)

    def test_metadata_from_dict(self):
        """from_dict builds a model out of plain keyword data."""
        payload = {
            "url": "https://example.com",
            "title": "Example",
            "description": "A description",
            "author": "John Doe",
        }
        meta = CitationMetadata.from_dict(payload)
        assert meta.url == "https://example.com"
        assert meta.title == "Example"
        assert meta.description == "A description"
        assert meta.author == "John Doe"

    def test_metadata_from_dict_removes_id(self):
        """A stale 'id' key in the payload is discarded and recomputed."""
        payload = {
            "url": "https://example.com",
            "title": "Example",
            "id": "some_old_id",  # Should be ignored
        }
        meta = CitationMetadata.from_dict(payload)
        # The freshly computed ID wins over the supplied one.
        assert meta.id != "some_old_id"

    def test_metadata_to_dict(self):
        """to_dict exposes both stored and computed fields."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        )
        dumped = meta.to_dict()

        assert dumped["url"] == "https://example.com"
        assert dumped["title"] == "Example"
        assert dumped["author"] == "John Doe"
        assert dumped["id"] == meta.id
        assert dumped["domain"] == "example.com"

    def test_metadata_from_search_result(self):
        """from_search_result maps search-engine fields onto the model."""
        raw_result = {
            "url": "https://example.com/article",
            "title": "Article Title",
            "content": "Article content here",
            "score": 0.92,
            "type": "page",
        }
        meta = CitationMetadata.from_search_result(
            raw_result,
            query="test query",
        )

        assert meta.url == "https://example.com/article"
        assert meta.title == "Article Title"
        assert meta.description == "Article content here"
        assert meta.relevance_score == 0.92
        assert meta.extra["query"] == "test query"
        assert meta.extra["result_type"] == "page"

    def test_metadata_pydantic_validation(self):
        """Missing required url/title fields raise ValidationError."""
        with pytest.raises(ValidationError):
            CitationMetadata()  # neither url nor title supplied

        with pytest.raises(ValidationError):
            CitationMetadata(url="https://example.com")  # title still missing

    def test_metadata_model_dump(self):
        """Pydantic's model_dump returns a plain dict view."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        )
        dumped = meta.model_dump()

        assert isinstance(dumped, dict)
        assert dumped["url"] == "https://example.com"
        assert dumped["title"] == "Example"

    def test_metadata_model_dump_json(self):
        """model_dump_json emits parseable JSON text."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
        )
        serialized = meta.model_dump_json()

        assert isinstance(serialized, str)
        decoded = json.loads(serialized)
        assert decoded["url"] == "https://example.com"
        assert decoded["title"] == "Example"

    def test_metadata_with_images_and_extra(self):
        """List- and dict-valued fields are stored as given."""
        meta = CitationMetadata(
            url="https://example.com",
            title="Example",
            images=["https://example.com/image1.jpg", "https://example.com/image2.jpg"],
            favicon="https://example.com/favicon.ico",
            extra={"custom_field": "value", "tags": ["tag1", "tag2"]},
        )

        assert len(meta.images) == 2
        assert meta.favicon == "https://example.com/favicon.ico"
        assert meta.extra["custom_field"] == "value"
|
||||
|
||||
|
||||
class TestCitation:
|
||||
"""Test Citation Pydantic model."""
|
||||
|
||||
def test_create_basic_citation(self):
|
||||
"""Test creating a basic citation."""
|
||||
metadata = CitationMetadata(
|
||||
url="https://example.com",
|
||||
title="Example",
|
||||
)
|
||||
citation = Citation(number=1, metadata=metadata)
|
||||
|
||||
assert citation.number == 1
|
||||
assert citation.metadata == metadata
|
||||
assert citation.context is None
|
||||
assert citation.cited_text is None
|
||||
|
||||
def test_citation_properties(self):
|
||||
"""Test citation property shortcuts."""
|
||||
metadata = CitationMetadata(
|
||||
url="https://example.com",
|
||||
title="Example Title",
|
||||
)
|
||||
citation = Citation(number=1, metadata=metadata)
|
||||
|
||||
assert citation.id == metadata.id
|
||||
assert citation.url == "https://example.com"
|
||||
assert citation.title == "Example Title"
|
||||
|
||||
def test_citation_to_markdown_reference(self):
|
||||
"""Test markdown reference generation."""
|
||||
metadata = CitationMetadata(
|
||||
url="https://example.com",
|
||||
title="Example",
|
||||
)
|
||||
citation = Citation(number=1, metadata=metadata)
|
||||
|
||||
result = citation.to_markdown_reference()
|
||||
assert result == "[Example](https://example.com)"
|
||||
|
||||
def test_citation_to_numbered_reference(self):
|
||||
"""Test numbered reference generation."""
|
||||
metadata = CitationMetadata(
|
||||
url="https://example.com",
|
||||
title="Example Article",
|
||||
)
|
||||
citation = Citation(number=5, metadata=metadata)
|
||||
|
||||
result = citation.to_numbered_reference()
|
||||
assert result == "[5] Example Article - https://example.com"
|
||||
|
||||
def test_citation_to_inline_marker(self):
|
||||
"""Test inline marker generation."""
|
||||
metadata = CitationMetadata(
|
||||
url="https://example.com",
|
||||
title="Example",
|
||||
)
|
||||
citation = Citation(number=3, metadata=metadata)
|
||||
|
||||
result = citation.to_inline_marker()
|
||||
assert result == "[^3]"
|
||||
|
||||
def test_citation_to_footnote(self):
|
||||
"""Test footnote generation."""
|
||||
metadata = CitationMetadata(
|
||||
url="https://example.com",
|
||||
title="Example Article",
|
||||
)
|
||||
citation = Citation(number=2, metadata=metadata)
|
||||
|
||||
result = citation.to_footnote()
|
||||
assert result == "[^2]: Example Article - https://example.com"
|
||||
|
||||
def test_citation_with_context_and_text(self):
    """Optional context and cited_text fields are stored verbatim."""
    meta = CitationMetadata(url="https://example.com", title="Example")
    cite = Citation(
        number=1,
        metadata=meta,
        context="This is important context",
        cited_text="Important quote from the source",
    )

    assert cite.context == "This is important context"
    assert cite.cited_text == "Important quote from the source"
|
||||
|
||||
def test_citation_from_dict(self):
    """Citation.from_dict reconstructs nested metadata from a plain dict."""
    payload = {
        "number": 1,
        "metadata": {
            "url": "https://example.com",
            "title": "Example",
            "author": "John Doe",
        },
        "context": "Test context",
    }

    cite = Citation.from_dict(payload)

    # Top-level and nested metadata fields must all round-trip.
    assert cite.number == 1
    assert cite.metadata.url == "https://example.com"
    assert cite.metadata.title == "Example"
    assert cite.metadata.author == "John Doe"
    assert cite.context == "Test context"
|
||||
|
||||
def test_citation_to_dict(self):
    """Citation.to_dict produces a plain dict with nested metadata."""
    cite = Citation(
        number=1,
        metadata=CitationMetadata(
            url="https://example.com",
            title="Example",
            author="John Doe",
        ),
        context="Test context",
    )

    dumped = cite.to_dict()

    assert dumped["number"] == 1
    assert dumped["metadata"]["url"] == "https://example.com"
    assert dumped["metadata"]["author"] == "John Doe"
    assert dumped["context"] == "Test context"
|
||||
|
||||
def test_citation_round_trip(self):
    """to_dict followed by from_dict preserves all citation fields."""
    meta = CitationMetadata(
        url="https://example.com",
        title="Example",
        author="John Doe",
        relevance_score=0.95,
    )
    source = Citation(number=1, metadata=meta, context="Test")

    # Serialize to a plain dict and rebuild a fresh Citation from it.
    rebuilt = Citation.from_dict(source.to_dict())

    assert rebuilt.number == source.number
    assert rebuilt.metadata.url == source.metadata.url
    assert rebuilt.metadata.title == source.metadata.title
    assert rebuilt.metadata.author == source.metadata.author
    assert rebuilt.metadata.relevance_score == source.metadata.relevance_score
|
||||
|
||||
def test_citation_model_dump(self):
    """Pydantic's model_dump returns a dict with nested metadata."""
    cite = Citation(
        number=1,
        metadata=CitationMetadata(url="https://example.com", title="Example"),
    )

    dumped = cite.model_dump()

    assert isinstance(dumped, dict)
    assert dumped["number"] == 1
    assert dumped["metadata"]["url"] == "https://example.com"
|
||||
|
||||
def test_citation_model_dump_json(self):
    """Pydantic's model_dump_json returns parseable JSON text."""
    cite = Citation(
        number=1,
        metadata=CitationMetadata(url="https://example.com", title="Example"),
    )

    serialized = cite.model_dump_json()

    assert isinstance(serialized, str)
    # The JSON payload must decode back to the same field values.
    parsed = json.loads(serialized)
    assert parsed["number"] == 1
    assert parsed["metadata"]["url"] == "https://example.com"
|
||||
|
||||
def test_citation_pydantic_validation(self):
    """Pydantic rejects construction when required fields are missing."""
    # Both `number` and `metadata` are required fields.
    with pytest.raises(ValidationError):
        Citation()  # Missing required fields

    meta = CitationMetadata(url="https://example.com", title="Example")
    with pytest.raises(ValidationError):
        Citation(metadata=meta)  # Missing number
|
||||
|
||||
|
||||
class TestCitationIntegration:
    """Integration tests exercising the citation models together."""

    def test_search_result_to_citation_workflow(self):
        """A raw search-result dict flows through metadata into a citation."""
        raw_result = {
            "url": "https://example.com/article",
            "title": "Great Article",
            "content": "This is a great article about testing",
            "score": 0.92,
        }

        # Build metadata from the search result, then wrap it in a citation.
        meta = CitationMetadata.from_search_result(raw_result, query="testing")
        cite = Citation(number=1, metadata=meta, context="Important source")

        # Verify every step of the workflow produced the expected values.
        assert cite.number == 1
        assert cite.url == "https://example.com/article"
        assert cite.title == "Great Article"
        assert cite.metadata.relevance_score == 0.92
        assert cite.to_markdown_reference() == "[Great Article](https://example.com/article)"

    def test_multiple_citations_with_different_formats(self):
        """Several citations can each be rendered in a different format."""
        first = Citation(
            number=1,
            metadata=CitationMetadata(url="https://example.com/1", title="First Article"),
        )
        second = Citation(
            number=2,
            metadata=CitationMetadata(url="https://example.com/2", title="Second Article"),
        )
        collected = [first, second]

        # Each citation supports every reference format independently.
        assert collected[0].to_markdown_reference() == "[First Article](https://example.com/1)"
        assert collected[1].to_numbered_reference() == "[2] Second Article - https://example.com/2"

    def test_citation_json_serialization_roundtrip(self):
        """from_dict -> model_dump_json -> model_validate_json keeps all fields."""
        payload = {
            "number": 1,
            "metadata": {
                "url": "https://example.com",
                "title": "Example",
                "author": "John Doe",
                "relevance_score": 0.95,
            },
            "context": "Test context",
            "cited_text": "Important quote",
        }

        # Build from a plain dict, serialize to JSON, then parse back.
        cite = Citation.from_dict(payload)
        rebuilt = Citation.model_validate_json(cite.model_dump_json())

        # Data integrity must survive the full serialization round trip.
        assert rebuilt.number == payload["number"]
        assert rebuilt.metadata.url == payload["metadata"]["url"]
        assert rebuilt.metadata.relevance_score == payload["metadata"]["relevance_score"]
        assert rebuilt.context == payload["context"]
|
||||
Reference in New Issue
Block a user