"""Unit tests for the ground-truth export service.""" from __future__ import annotations import json from datetime import datetime, timezone from uuid import UUID, uuid4 import pytest from ocr_sprint.db.base import Base, get_engine, session_scope from ocr_sprint.db.repositories import JobRepository from ocr_sprint.ground_truth import ( GroundTruthFilters, build_initial_result, ground_truth_stats, iter_ground_truth_samples, serialize_sample_to_jsonl, ) from ocr_sprint.schemas.document import DocumentStatus, SourceKind @pytest.fixture def db_ready() -> None: Base.metadata.create_all(bind=get_engine()) def _seed_approved_job_with_corrections( *, final_result: dict[str, object] | None = None, corrections: list[tuple[str, object, object]] | None = None, approved: bool = True, ) -> UUID: jid = uuid4() with session_scope() as session: repo = JobRepository(session) repo.create(job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k") with session_scope() as session: JobRepository(session).mark_completed( jid, status=DocumentStatus.NEEDS_REVIEW, confidence=0.8, result=final_result or { "header": {"nomor_sprint": "SPR/1/2025", "satuan_penerbit": "POLRES X"}, "personel": [{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}], }, review_flags=[], ) if corrections: # Corrections are tuples of (path, value, old_value_for_audit). # We translate them into real PATCH calls so the audit rows look # exactly like they would in production. with session_scope() as session: JobRepository(session).apply_corrections( jid, corrections=[(p, new, None) for (p, new, _old) in corrections], corrected_by="reviewer-a", ) if approved: with session_scope() as session: JobRepository(session).approve(jid, reviewed_by="reviewer-a") return jid def test_build_initial_result_reverses_corrections() -> None: """With known old/new pairs the replay must return the pre-HITL dict.""" final = { "header": {"nomor_sprint": "SPR/1/2025", "perihal": "Penyelidikan"}, } class _Fake: def __init__(self, path: str, old: object, new: object) -> None: self.field_path = path self.old_value = old self.new_value = new corrections = [ _Fake("header.perihal", None, "Penyelidikan"), _Fake("header.nomor_sprint", "SPR/OLD/2025", "SPR/1/2025"), ] restored = build_initial_result(final_result=final, corrections=corrections) # type: ignore[arg-type] assert restored == { "header": {"nomor_sprint": "SPR/OLD/2025", "perihal": None}, } # ``final`` must not have been mutated. assert final["header"]["nomor_sprint"] == "SPR/1/2025" def test_build_initial_result_ignores_unresolvable_paths() -> None: """A legacy path that no longer resolves should be skipped, not raise.""" final = {"header": {"nomor_sprint": "SPR/1"}} class _Fake: def __init__(self, path: str, old: object, new: object) -> None: self.field_path = path self.old_value = old self.new_value = new corrections = [ _Fake("personel[99].nrp", "stale", "new"), # container doesn't exist _Fake("bogus..path", "x", "y"), ] restored = build_initial_result(final_result=final, corrections=corrections) # type: ignore[arg-type] assert restored == final def test_iter_samples_reconstructs_initial_and_final(db_ready: None) -> None: jid = _seed_approved_job_with_corrections( corrections=[("header.perihal", "Penyelidikan", None)], ) with session_scope() as session: samples = list(iter_ground_truth_samples(session, GroundTruthFilters())) assert len(samples) == 1 s = samples[0] assert s.job_id == jid assert s.approved is True assert s.reviewed_by == "reviewer-a" assert s.final_result is not None assert s.final_result["header"]["perihal"] == "Penyelidikan" # Initial reconstruction undoes the correction. assert s.initial_result is not None assert s.initial_result["header"].get("perihal") is None assert len(s.corrections) == 1 assert s.corrections[0].field_path == "header.perihal" def test_iter_samples_respects_approved_only_default(db_ready: None) -> None: """Unapproved jobs must be excluded unless approved_only=False.""" _seed_approved_job_with_corrections(approved=False) with session_scope() as session: samples = list(iter_ground_truth_samples(session, GroundTruthFilters())) assert samples == [] with session_scope() as session: samples = list(iter_ground_truth_samples(session, GroundTruthFilters(approved_only=False))) assert len(samples) == 1 def test_iter_samples_respects_has_corrections_filter(db_ready: None) -> None: _seed_approved_job_with_corrections(corrections=None) # pristine _seed_approved_job_with_corrections(corrections=[("header.perihal", "fill", None)]) with session_scope() as session: with_ = list(iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=True))) without = list( iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=False)) ) assert len(with_) == 1 assert len(without) == 1 assert with_[0].job_id != without[0].job_id def test_iter_samples_respects_since_until_and_limit(db_ready: None) -> None: _seed_approved_job_with_corrections() _seed_approved_job_with_corrections() _seed_approved_job_with_corrections() # Since far in the future → empty. future = datetime(2999, 1, 1, tzinfo=timezone.utc) with session_scope() as session: out = list(iter_ground_truth_samples(session, GroundTruthFilters(since=future))) assert out == [] # Limit caps the number emitted. with session_scope() as session: out = list(iter_ground_truth_samples(session, GroundTruthFilters(limit=2))) assert len(out) == 2 def test_stats_counts_rollup_and_top_fields(db_ready: None) -> None: # 1 approved job with 2 corrections on ``header.perihal``. _seed_approved_job_with_corrections( corrections=[ ("header.perihal", "first", None), ] ) jid = _seed_approved_job_with_corrections( corrections=[ ("header.perihal", "second", None), ("personel[0].nrp", "77060001", None), ], ) assert jid # silence unused with session_scope() as session: stats = ground_truth_stats(session) assert stats.total_jobs == 2 assert stats.approved_jobs == 2 assert stats.total_corrections == 3 assert stats.jobs_with_corrections == 2 # ``header.perihal`` is the most-corrected field (2 > 1). assert stats.top_corrected_fields[0].field_path == "header.perihal" assert stats.top_corrected_fields[0].count == 2 assert {f.field_path for f in stats.top_corrected_fields} == { "header.perihal", "personel[0].nrp", } def test_serialize_is_valid_jsonl(db_ready: None) -> None: _seed_approved_job_with_corrections(corrections=[("header.perihal", "X", None)]) with session_scope() as session: for sample in iter_ground_truth_samples(session, GroundTruthFilters()): line = serialize_sample_to_jsonl(sample) assert line.endswith("\n") # Strip trailing newline and make sure the result is a single # JSON object (not a list — JSONL must be one object per line). parsed = json.loads(line) assert isinstance(parsed, dict) assert parsed["approved"] is True