Phase 7: ground-truth export (JSONL + stats) + CLI tool
- GET /api/v1/ground-truth/export: streaming JSONL (filters: approved_only, since, until, has_corrections, limit)
- GET /api/v1/ground-truth/stats: total / approved / corrections counts + top-N most-corrected field paths
- python -m ocr_sprint.tools.export_ground_truth: operator CLI with the same filters + optional --print-stats
- Ground-truth sample reconstructs the pipeline's original output by replaying job_corrections in reverse
- docs/ground-truth-format.md: schema + fine-tuning guidance
- 17 new tests (service replay, endpoint filters, CLI)
- 201 total tests passing, ruff / mypy --strict clean

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
158
tests/unit/test_api_ground_truth.py
Normal file
158
tests/unit/test_api_ground_truth.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""HTTP tests for the ground-truth export endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ocr_sprint.main import create_app
|
||||
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||
from ocr_sprint.pipeline.orchestrator import PipelineOutput
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
from ocr_sprint.schemas.extraction import (
|
||||
ExtractionResult,
|
||||
HeaderFields,
|
||||
PersonnelEntry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` wrapped around a freshly created application."""
    app = create_app()
    return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
    """Swap ``run_pipeline`` for a stub that always returns one canned output.

    The stub is installed on the orchestrator module and on the two modules
    that expose ``run_pipeline`` under their own namespace (the documents
    route and the worker tasks). The canned ``PipelineOutput`` is returned to
    the requesting test.
    """
    canned_header = HeaderFields(
        nomor_sprint="Sprin/1/I/2025",
        tanggal=date(2025, 1, 1),
        satuan_penerbit="POLRES TEST",
    )
    canned_member = PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA")
    canned_output = PipelineOutput(
        source_kind=SourceKind.PDF,
        status=DocumentStatus.COMPLETED,
        confidence=0.9,
        result=ExtractionResult(
            header=canned_header,
            personel=[canned_member],
            confidence=0.9,
        ),
    )

    def _fake_run(_content: bytes) -> PipelineOutput:
        # The uploaded bytes are ignored; every call yields the same object.
        return canned_output

    # The import/patch order below is deliberately kept as-is: each module is
    # imported right before its own attribute is replaced.
    monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)

    from ocr_sprint.api.routes import documents as docs_module

    monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)

    from ocr_sprint.worker import tasks as tasks_module

    monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
    return canned_output
|
||||
|
||||
|
||||
def _create_and_approve(client: TestClient, *, correction_value: str | None = None) -> str:
    """Upload a fake PDF synchronously, optionally correct it, then approve it.

    When ``correction_value`` is given, a single correction on
    ``header.perihal`` is PATCHed before the approve call. Every HTTP step is
    asserted to answer 200. Returns the job id as a string.
    """
    upload = client.post(
        "/api/v1/documents?sync=true",
        files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
    )
    assert upload.status_code == 200, upload.text
    job_id = str(upload.json()["job_id"])

    if correction_value is not None:
        payload = {"corrections": [{"path": "header.perihal", "value": correction_value}]}
        patch_response = client.patch(f"/api/v1/documents/{job_id}", json=payload)
        assert patch_response.status_code == 200

    approve_response = client.post(f"/api/v1/documents/{job_id}/approve")
    assert approve_response.status_code == 200
    return job_id
|
||||
|
||||
|
||||
def test_stats_empty_dataset(client: TestClient) -> None:
    """With no jobs created, the stats endpoint reports zero counters and no top fields."""
    response = client.get("/api/v1/ground-truth/stats")
    assert response.status_code == 200

    payload = response.json()
    for counter in ("total_jobs", "approved_jobs", "total_corrections"):
        assert payload[counter] == 0
    assert payload["top_corrected_fields"] == []
|
||||
|
||||
|
||||
def test_stats_rolls_up_counts(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """Two corrected jobs plus one pristine job roll up into the stats counters."""
    # The trailing ``None`` creates the pristine (uncorrected) job.
    for value in ("Penyelidikan-1", "Penyelidikan-2", None):
        _create_and_approve(client, correction_value=value)

    response = client.get("/api/v1/ground-truth/stats")
    assert response.status_code == 200

    payload = response.json()
    expected_counters = {
        "total_jobs": 3,
        "approved_jobs": 3,
        "total_corrections": 2,
        "jobs_with_corrections": 2,
    }
    for key, expected in expected_counters.items():
        assert payload[key] == expected

    top_field = payload["top_corrected_fields"][0]
    assert top_field["field_path"] == "header.perihal"
    assert top_field["count"] == 2
|
||||
|
||||
|
||||
def test_export_streams_jsonl(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """The export answers with NDJSON: one approved sample per non-blank line."""
    _create_and_approve(client, correction_value="Penyelidikan")
    _create_and_approve(client, correction_value=None)

    response = client.get("/api/v1/ground-truth/export")
    assert response.status_code == 200
    assert response.headers["content-type"].startswith("application/x-ndjson")

    rows = [row for row in response.text.splitlines() if row.strip()]
    assert len(rows) == 2

    for record in (json.loads(row) for row in rows):
        assert record["approved"] is True
        # Both the reconstructed and the reviewed result must be present.
        assert "initial_result" in record
        assert "final_result" in record
|
||||
|
||||
|
||||
def test_export_approved_only_default(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """Unapproved jobs shouldn't appear in the default export."""
    # One approved, one just completed (no approve call).
    _create_and_approve(client, correction_value=None)
    unapproved = client.post(
        "/api/v1/documents?sync=true",
        files={"file": ("y.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
    )
    # Check the setup upload explicitly: if it fails, the line-count
    # assertions below would otherwise fail with a misleading message.
    assert unapproved.status_code == 200, unapproved.text

    resp = client.get("/api/v1/ground-truth/export")
    assert resp.status_code == 200, resp.text
    lines = [line for line in resp.text.splitlines() if line.strip()]
    assert len(lines) == 1

    # Toggle approved_only=false to include both.
    resp = client.get("/api/v1/ground-truth/export?approved_only=false")
    assert resp.status_code == 200, resp.text
    lines = [line for line in resp.text.splitlines() if line.strip()]
    assert len(lines) == 2
|
||||
|
||||
|
||||
def test_export_has_corrections_filter(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """``has_corrections`` picks either the corrected or the pristine job."""
    _create_and_approve(client, correction_value="Penyelidikan")
    _create_and_approve(client, correction_value=None)

    corrected = client.get("/api/v1/ground-truth/export?has_corrections=true")
    corrected_rows = [row for row in corrected.text.splitlines() if row.strip()]
    assert len(corrected_rows) == 1
    first_correction = json.loads(corrected_rows[0])["corrections"][0]
    assert first_correction["new_value"] == "Penyelidikan"

    pristine = client.get("/api/v1/ground-truth/export?has_corrections=false")
    pristine_rows = [row for row in pristine.text.splitlines() if row.strip()]
    assert len(pristine_rows) == 1
    assert json.loads(pristine_rows[0])["corrections"] == []
|
||||
|
||||
|
||||
def test_export_respects_limit(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """``limit=2`` caps the export at two lines even though five jobs are approved."""
    for _job_number in range(5):
        _create_and_approve(client, correction_value=None)

    response = client.get("/api/v1/ground-truth/export?limit=2")
    non_blank_rows = [row for row in response.text.splitlines() if row.strip()]
    assert len(non_blank_rows) == 2
|
||||
80
tests/unit/test_cli_export_ground_truth.py
Normal file
80
tests/unit/test_cli_export_ground_truth.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Smoke tests for the ``ocr_sprint.tools.export_ground_truth`` CLI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||
from ocr_sprint.db.repositories import JobRepository
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
from ocr_sprint.tools.export_ground_truth import main as export_main
|
||||
|
||||
|
||||
@pytest.fixture
def db_ready() -> None:
    """Create every table registered on ``Base.metadata`` using the configured engine."""
    engine = get_engine()
    Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def _seed_two_approved_jobs() -> None:
    """Insert two jobs, taking each through create, mark-completed and approve.

    Every repository call runs in its own ``session_scope`` block.
    """
    for _job_number in range(2):
        job_id = uuid4()

        with session_scope() as create_session:
            JobRepository(create_session).create(
                job_id=job_id,
                filename="x.pdf",
                source_kind=SourceKind.PDF,
                blob_key="k",
            )

        with session_scope() as complete_session:
            completed_repo = JobRepository(complete_session)
            completed_repo.mark_completed(
                job_id,
                status=DocumentStatus.COMPLETED,
                confidence=0.9,
                result={"header": {"nomor_sprint": "SPR/1/2025"}},
                review_flags=[],
            )

        with session_scope() as approve_session:
            JobRepository(approve_session).approve(job_id, reviewed_by="rev")
|
||||
|
||||
|
||||
def test_cli_writes_expected_number_of_lines(
    db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """The CLI writes one approved JSON line per seeded job and reports the count on stderr."""
    _seed_two_approved_jobs()
    target = tmp_path / "corpus.jsonl"

    assert export_main(["--out", str(target)]) == 0

    rows = [row for row in target.read_text().splitlines() if row.strip()]
    assert len(rows) == 2
    for row in rows:
        assert json.loads(row)["approved"] is True

    captured = capsys.readouterr()
    assert "wrote 2 sample(s)" in captured.err
|
||||
|
||||
|
||||
def test_cli_respects_limit(db_ready: None, tmp_path: Path) -> None:
    """``--limit 1`` leaves a single line in the output file although two jobs exist."""
    _seed_two_approved_jobs()
    target = tmp_path / "corpus.jsonl"

    status = export_main(["--out", str(target), "--limit", "1"])
    assert status == 0

    rows = [row for row in target.read_text().splitlines() if row.strip()]
    assert len(rows) == 1
|
||||
|
||||
|
||||
def test_cli_print_stats_emits_json_to_stderr(
    db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """``--print-stats`` makes the CLI emit a JSON stats object on stderr."""
    _seed_two_approved_jobs()
    target = tmp_path / "corpus.jsonl"

    status = export_main(["--out", str(target), "--print-stats"])
    assert status == 0

    err_text = capsys.readouterr().err
    # The JSON object is taken from the first "{" to the end of stderr;
    # whatever precedes it (the "wrote N" line) is skipped.
    first_brace = err_text.index("{")
    payload = json.loads(err_text[first_brace:])
    assert payload["total_jobs"] == 2
    assert payload["approved_jobs"] == 2
|
||||
210
tests/unit/test_ground_truth_service.py
Normal file
210
tests/unit/test_ground_truth_service.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Unit tests for the ground-truth export service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||
from ocr_sprint.db.repositories import JobRepository
|
||||
from ocr_sprint.ground_truth import (
|
||||
GroundTruthFilters,
|
||||
build_initial_result,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
|
||||
|
||||
@pytest.fixture
def db_ready() -> None:
    """Create every table registered on ``Base.metadata`` using the configured engine."""
    engine = get_engine()
    Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def _seed_approved_job_with_corrections(
    *,
    final_result: dict[str, object] | None = None,
    corrections: list[tuple[str, object, object]] | None = None,
    approved: bool = True,
) -> UUID:
    """Seed one job through the repository layer and return its id.

    The job is created, marked completed with status ``NEEDS_REVIEW`` and
    confidence 0.8, optionally corrected, and optionally approved. Each
    repository call runs in its own ``session_scope`` block.

    Args:
        final_result: Result dict stored by ``mark_completed``. ``None``
            selects a small default document; any other value, including an
            empty dict, is stored as given.
        corrections: ``(path, new_value, old_value)`` tuples. Only the path
            and the new value are forwarded to ``apply_corrections``.
        approved: Whether to call ``approve`` at the end.

    Returns:
        The generated job id.
    """
    # Test ``is None`` rather than truthiness so that an explicit empty dict
    # is not silently swapped for the default document.
    if final_result is None:
        final_result = {
            "header": {"nomor_sprint": "SPR/1/2025", "satuan_penerbit": "POLRES X"},
            "personel": [{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}],
        }

    jid = uuid4()
    with session_scope() as session:
        repo = JobRepository(session)
        repo.create(job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k")
    with session_scope() as session:
        JobRepository(session).mark_completed(
            jid,
            status=DocumentStatus.NEEDS_REVIEW,
            confidence=0.8,
            result=final_result,
            review_flags=[],
        )
    if corrections:
        # The corrections go straight through the repository (no HTTP call).
        # The caller-supplied old value is dropped and ``None`` is passed in
        # the third slot instead.
        # NOTE(review): presumably the repository records the old value
        # itself - confirm against ``JobRepository.apply_corrections``.
        with session_scope() as session:
            JobRepository(session).apply_corrections(
                jid,
                corrections=[(p, new, None) for (p, new, _old) in corrections],
                corrected_by="reviewer-a",
            )
    if approved:
        with session_scope() as session:
            JobRepository(session).approve(jid, reviewed_by="reviewer-a")
    return jid
|
||||
|
||||
|
||||
def test_build_initial_result_reverses_corrections() -> None:
    """With known old/new pairs the replay must return the pre-HITL dict."""
    final = {
        "header": {"nomor_sprint": "SPR/1/2025", "perihal": "Penyelidikan"},
    }

    class _Fake:
        """Minimal stand-in exposing the three attributes set in ``__init__``."""

        def __init__(self, path: str, old: object, new: object) -> None:
            self.field_path = path
            self.old_value = old
            self.new_value = new

    corrections = [
        _Fake("header.perihal", None, "Penyelidikan"),
        _Fake("header.nomor_sprint", "SPR/OLD/2025", "SPR/1/2025"),
    ]
    restored = build_initial_result(final_result=final, corrections=corrections)  # type: ignore[arg-type]
    assert restored == {
        "header": {"nomor_sprint": "SPR/OLD/2025", "perihal": None},
    }
    # ``final`` must not have been mutated. Compare the whole dict so that
    # both corrected fields are covered, not just ``nomor_sprint``.
    assert final == {
        "header": {"nomor_sprint": "SPR/1/2025", "perihal": "Penyelidikan"},
    }
|
||||
|
||||
|
||||
def test_build_initial_result_ignores_unresolvable_paths() -> None:
    """A legacy path that no longer resolves should be skipped, not raise."""
    final = {"header": {"nomor_sprint": "SPR/1"}}

    class _Fake:
        """Minimal stand-in exposing the three attributes set in ``__init__``."""

        def __init__(self, path: str, old: object, new: object) -> None:
            self.field_path = path
            self.old_value = old
            self.new_value = new

    corrections = [
        _Fake("personel[99].nrp", "stale", "new"),  # container doesn't exist
        _Fake("bogus..path", "x", "y"),
    ]
    restored = build_initial_result(final_result=final, corrections=corrections)  # type: ignore[arg-type]
    # Compare against a literal instead of ``final``: if the function mutated
    # its input and returned it, ``restored == final`` would still hold.
    assert restored == {"header": {"nomor_sprint": "SPR/1"}}
    # The input itself must be left untouched as well.
    assert final == {"header": {"nomor_sprint": "SPR/1"}}
|
||||
|
||||
|
||||
def test_iter_samples_reconstructs_initial_and_final(db_ready: None) -> None:
    """A corrected, approved job yields one sample carrying both result versions."""
    seeded_id = _seed_approved_job_with_corrections(
        corrections=[("header.perihal", "Penyelidikan", None)],
    )
    with session_scope() as session:
        found = list(iter_ground_truth_samples(session, GroundTruthFilters()))
        # Assertions stay inside the session scope so the sample is read
        # while the session is still open.
        assert len(found) == 1
        sample = found[0]

        assert sample.job_id == seeded_id
        assert sample.approved is True
        assert sample.reviewed_by == "reviewer-a"

        # The final result carries the reviewer's value ...
        assert sample.final_result is not None
        assert sample.final_result["header"]["perihal"] == "Penyelidikan"

        # ... while the reconstructed initial result has it undone.
        assert sample.initial_result is not None
        assert sample.initial_result["header"].get("perihal") is None

        assert len(sample.corrections) == 1
        assert sample.corrections[0].field_path == "header.perihal"
|
||||
|
||||
|
||||
def test_iter_samples_respects_approved_only_default(db_ready: None) -> None:
    """Unapproved jobs must be excluded unless approved_only=False."""
    _seed_approved_job_with_corrections(approved=False)

    default_filters = GroundTruthFilters()
    with session_scope() as session:
        with_default = list(iter_ground_truth_samples(session, default_filters))
    assert with_default == []

    inclusive_filters = GroundTruthFilters(approved_only=False)
    with session_scope() as session:
        with_unapproved = list(iter_ground_truth_samples(session, inclusive_filters))
    assert len(with_unapproved) == 1
|
||||
|
||||
|
||||
def test_iter_samples_respects_has_corrections_filter(db_ready: None) -> None:
    """``has_corrections`` splits a pristine job and a corrected job into disjoint results."""
    _seed_approved_job_with_corrections(corrections=None)  # pristine
    _seed_approved_job_with_corrections(corrections=[("header.perihal", "fill", None)])

    only_corrected = GroundTruthFilters(has_corrections=True)
    only_pristine = GroundTruthFilters(has_corrections=False)
    with session_scope() as session:
        corrected = list(iter_ground_truth_samples(session, only_corrected))
        pristine = list(iter_ground_truth_samples(session, only_pristine))
        # Assertions stay inside the session scope so the samples are read
        # while the session is still open.
        assert len(corrected) == 1
        assert len(pristine) == 1
        assert corrected[0].job_id != pristine[0].job_id
|
||||
|
||||
|
||||
def test_iter_samples_respects_since_until_and_limit(db_ready: None) -> None:
    """A far-future ``since`` yields nothing and ``limit`` caps the sample count.

    NOTE(review): despite the test name, no ``until`` bound is exercised here.
    """
    for _job_number in range(3):
        _seed_approved_job_with_corrections()

    # Since far in the future → empty.
    far_future = datetime(2999, 1, 1, tzinfo=timezone.utc)
    since_filters = GroundTruthFilters(since=far_future)
    with session_scope() as session:
        after_future = list(iter_ground_truth_samples(session, since_filters))
    assert after_future == []

    # Limit caps the number emitted.
    limit_filters = GroundTruthFilters(limit=2)
    with session_scope() as session:
        capped = list(iter_ground_truth_samples(session, limit_filters))
    assert len(capped) == 2
|
||||
|
||||
|
||||
def test_stats_counts_rollup_and_top_fields(db_ready: None) -> None:
    """Stats count jobs and corrections and rank ``header.perihal`` first."""
    # First approved job: a single correction on ``header.perihal``.
    _seed_approved_job_with_corrections(
        corrections=[
            ("header.perihal", "first", None),
        ]
    )
    # Second approved job: another ``header.perihal`` correction plus one on
    # ``personel[0].nrp`` - three corrections in total across both jobs.
    _seed_approved_job_with_corrections(
        corrections=[
            ("header.perihal", "second", None),
            ("personel[0].nrp", "77060001", None),
        ],
    )

    with session_scope() as session:
        stats = ground_truth_stats(session)

        # Assertions stay inside the session scope so the stats object is
        # read while the session is still open.
        assert stats.total_jobs == 2
        assert stats.approved_jobs == 2
        assert stats.total_corrections == 3
        assert stats.jobs_with_corrections == 2
        # ``header.perihal`` is the most-corrected field (2 > 1).
        assert stats.top_corrected_fields[0].field_path == "header.perihal"
        assert stats.top_corrected_fields[0].count == 2
        assert {f.field_path for f in stats.top_corrected_fields} == {
            "header.perihal",
            "personel[0].nrp",
        }
|
||||
|
||||
|
||||
def test_serialize_is_valid_jsonl(db_ready: None) -> None:
    """Each serialized sample is one newline-terminated line holding one JSON object."""
    _seed_approved_job_with_corrections(corrections=[("header.perihal", "X", None)])
    with session_scope() as session:
        # Serialize while the session is open, then assert on plain strings.
        lines = [
            serialize_sample_to_jsonl(sample)
            for sample in iter_ground_truth_samples(session, GroundTruthFilters())
        ]

    # Guard against a vacuous pass: with no samples the loop below would
    # never execute a single assertion.
    assert lines, "expected at least one ground-truth sample to serialize"

    for line in lines:
        assert line.endswith("\n")
        # JSONL means one object per line, so the only newline allowed is
        # the terminating one.
        assert "\n" not in line[:-1]
        # ``json.loads`` tolerates the trailing newline; the payload must be
        # a single JSON object, not a list.
        parsed = json.loads(line)
        assert isinstance(parsed, dict)
        assert parsed["approved"] is True
|
||||
Reference in New Issue
Block a user