Phase 7: ground-truth export (JSONL + stats) + CLI tool

- GET /api/v1/ground-truth/export  streaming JSONL (approved_only,
  since, until, has_corrections, limit)
- GET /api/v1/ground-truth/stats   total / approved / corrections
  counts + top-N most-corrected field paths
- python -m ocr_sprint.tools.export_ground_truth  operator CLI with
  the same filters + optional --print-stats
- Ground-truth sample reconstructs the pipeline's original output by
  replaying job_corrections in reverse
- docs/ground-truth-format.md    schema + fine-tuning guidance
- 17 new tests (service replay, endpoint filters, CLI)
- 201 total tests passing, ruff / mypy --strict clean

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
Devin AI
2026-04-25 20:24:40 +00:00
parent 9457fa3c55
commit 6003d96a94
11 changed files with 1148 additions and 1 deletions

112
docs/ground-truth-format.md Normal file
View File

@@ -0,0 +1,112 @@
# Ground-truth export format (Phase 7)
The service exposes the HITL corpus as [JSONL](https://jsonlines.org/) —
one training sample per line — via the HTTP endpoint
`GET /api/v1/ground-truth/export` and the equivalent CLI:
```bash
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl
```
Both paths read the same database and emit byte-identical JSONL, so a
cron-scheduled dump and an ad-hoc curl download are interchangeable.
## Sample schema
Each line is a single JSON object with the following shape:
```jsonc
{
"job_id": "c5da6747-...",
"filename": "sprint-042.pdf",
"source_kind": "pdf",
"approved": true,
"reviewed_by": "reviewer-a", // free-form; comes from X-User-Id
"reviewed_at": "2025-06-01T10:15:00Z",
"created_at": "2025-05-28T08:02:17Z",
// The pipeline's original pre-HITL output, reconstructed by replaying
// the audit trail backwards. `null` for jobs that never produced a
// result (e.g. hard-failed on OCR).
"initial_result": {
"header": { "nomor_sprint": "Sprin/1/I/2025", "perihal": null, ... },
"personel": [ { "pangkat": "AIPDA", "nrp": "77060000", ... } ],
...
},
// The reviewer-approved answer (current value of jobs.result).
"final_result": { ...same shape as initial_result... },
// Every correction event, in chronological order.
"corrections": [
{
"field_path": "header.perihal",
"old_value": null,
"new_value": "Penyelidikan kasus pencurian",
"corrected_by": "reviewer-a",
"reason": "LLM missed it",
"corrected_at": "2025-05-30T14:00:00Z"
}
],
"review_flags": ["llm_fallback"],
"confidence": 0.78
}
```
## Recommended filters
* `approved_only=true` (default) — **do not** train on unreviewed
samples; they can still contain OCR mistakes.
* `has_corrections=true` — for a "hard examples" set where the pipeline
was originally wrong.
* `has_corrections=false` — for a "sanity" set where the pipeline was
already right. Good for regression tests after fine-tuning.
* `since` / `until` — build incremental snapshots without re-processing
the full history.
## When is the dataset big enough to fine-tune?
Rough operational checklist (rules of thumb — adjust based on your own
error analysis):
| Bucket | Minimum rows | Notes |
|---------------------------------|--------------|---------------------------------------------------------------|
| LoRA on header extraction (LLM) | ~200500 | Per-field error signal must be > random noise. |
| Per-satuan prompt tuning | ~50 / satuan | Helps when formats differ sharply between Polda/Polres units. |
| PP-Structure table fine-tune | ~1 000+ | Layout models are data-hungry; hold off until HITL is steady. |
Use `GET /api/v1/ground-truth/stats` to check coverage:
```json
{
"total_jobs": 842,
"approved_jobs": 613,
"total_corrections": 1 204,
"jobs_with_corrections": 431,
"top_corrected_fields": [
{ "field_path": "header.perihal", "count": 289 },
{ "field_path": "personel[0].nrp", "count": 51 },
...
]
}
```
Fields at the top of `top_corrected_fields` are the highest-leverage
targets for prompt tweaks, regex upgrades, or (eventually) fine-tuning.
## Fine-tuning outside this repo
The export is deliberately framework-agnostic. Suggested follow-ups on
dedicated GPU hardware:
* [**Unsloth**](https://github.com/unslothai/unsloth) — LoRA on
Qwen2.5 / Llama 3.1 with 24 × speedups on a single GPU.
* [**Axolotl**](https://github.com/axolotl-ai-cloud/axolotl) — more
batteries-included; good for multi-GPU runs.
Typical prompt-completion conversion: feed `initial_result` (or the raw
OCR text, if your pipeline keeps it) as the "input" and `final_result`
as the "output". The `corrections` list is only needed if you want to
build an error-class analysis — the model itself trains on the final
answer.

View File

@@ -0,0 +1,102 @@
"""Ground-truth export + statistics endpoints (Phase 7).
Two endpoints, both auth'd by the existing ``X-API-Key`` dependency:
* ``GET /ground-truth/export`` — streams JSONL of approved (or filtered)
samples for downstream fine-tuning pipelines.
* ``GET /ground-truth/stats`` — returns aggregate counts + top-corrected
field paths so operators know when/where fine-tuning will pay off.
"""
from __future__ import annotations
from collections.abc import Iterator
from datetime import datetime
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from ocr_sprint.api.deps.auth import require_api_key
from ocr_sprint.api.deps.db import get_session
from ocr_sprint.ground_truth import (
GroundTruthFilters,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
from ocr_sprint.schemas.ground_truth import GroundTruthStats
router = APIRouter(
prefix="/ground-truth",
tags=["ground-truth"],
dependencies=[Depends(require_api_key)],
)
@router.get(
"/export",
response_class=StreamingResponse,
responses={
200: {
"content": {"application/x-ndjson": {}},
"description": "Newline-delimited JSON stream of training samples.",
}
},
)
def export_ground_truth(
session: Annotated[Session, Depends(get_session)],
since: Annotated[
datetime | None,
Query(description="Only include jobs created at or after this ISO timestamp."),
] = None,
until: Annotated[
datetime | None,
Query(description="Only include jobs created at or before this ISO timestamp."),
] = None,
approved_only: Annotated[
bool,
Query(description="Only export approved jobs (default true)."),
] = True,
has_corrections: Annotated[
bool | None,
Query(
description=(
"Optional: true = only jobs that had at least one correction, "
"false = only pristine (no-correction) jobs."
)
),
] = None,
limit: Annotated[
int | None,
Query(ge=1, le=100_000, description="Maximum rows to emit."),
] = None,
) -> StreamingResponse:
filters = GroundTruthFilters(
since=since,
until=until,
approved_only=approved_only,
has_corrections=has_corrections,
limit=limit,
)
def _stream() -> Iterator[bytes]:
for sample in iter_ground_truth_samples(session, filters):
yield serialize_sample_to_jsonl(sample).encode("utf-8")
return StreamingResponse(_stream(), media_type="application/x-ndjson")
@router.get(
"/stats",
response_model=GroundTruthStats,
)
def get_stats(
session: Annotated[Session, Depends(get_session)],
top_n: Annotated[
int,
Query(ge=1, le=100, description="How many top-corrected field paths to return."),
] = 10,
) -> GroundTruthStats:
return ground_truth_stats(session, top_n=top_n)

View File

@@ -0,0 +1,22 @@
"""Phase 7 — ground-truth export service.
Consumes the ``jobs`` + ``job_corrections`` tables to produce JSONL
training samples + aggregate statistics. No ML dependencies here: the
actual fine-tuning runs in a separate project on dedicated hardware.
"""
from ocr_sprint.ground_truth.service import (
GroundTruthFilters,
build_initial_result,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
__all__ = [
"GroundTruthFilters",
"build_initial_result",
"ground_truth_stats",
"iter_ground_truth_samples",
"serialize_sample_to_jsonl",
]

View File

@@ -0,0 +1,248 @@
"""Core ground-truth export logic.
The service reads from ``jobs`` and ``job_corrections`` and produces
``GroundTruthSample`` rows. The HTTP layer and the CLI both go through
this module, so the export logic stays in exactly one place.
Design notes
------------
* **Reverse-replay for ``initial_result``.** Audit rows store both
``old_value`` and ``new_value``. We start from the current
``final_result`` and undo each correction in reverse chronological
order to recover the pipeline's original pre-HITL output. This is
what fine-tuning needs: the raw mistake next to the reviewer fix.
* **No deep copying in the caller.** The service always returns freshly
computed dicts so mutating one sample can't leak into the next.
* **Generator by design.** Large exports shouldn't load the whole table
into memory; the HTTP route streams JSONL line-by-line.
"""
from __future__ import annotations
import copy
import re
from collections.abc import Iterator
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from ocr_sprint.db.models import JobCorrectionRow, JobRow
from ocr_sprint.schemas.document import DocumentStatus
from ocr_sprint.schemas.ground_truth import (
FieldCorrectionCount,
GroundTruthCorrection,
GroundTruthSample,
GroundTruthStats,
)
@dataclass(frozen=True)
class GroundTruthFilters:
"""HTTP / CLI filters for ``iter_ground_truth_samples``.
All fields are optional; ``approved_only=True`` is the default because
only approved samples are actually safe to train on.
"""
since: datetime | None = None
until: datetime | None = None
approved_only: bool = True
has_corrections: bool | None = None
limit: int | None = None
def build_initial_result(
*,
final_result: dict[str, Any] | None,
corrections: list[JobCorrectionRow],
) -> dict[str, Any] | None:
"""Replay ``corrections`` backwards over ``final_result`` to recover
the pre-HITL version. Returns ``None`` if there's nothing to
reconstruct.
"""
if final_result is None:
return None
restored = copy.deepcopy(final_result)
for event in reversed(corrections):
_assign_path(restored, event.field_path, event.old_value)
return restored
def _assign_path(data: dict[str, Any], path: str, value: Any) -> None:
"""Set ``data[path] = value`` using the same dotted syntax the HITL
repository understands. Silently ignores paths that no longer resolve
(the dict shape may have drifted since the event was recorded; in
that case we just skip the replay step rather than raising).
"""
segments = _parse(path)
if not segments:
return
cursor: Any = data
for name, idx in segments[:-1]:
if not isinstance(cursor, dict) or name not in cursor:
return
cursor = cursor[name]
if idx is not None:
if not isinstance(cursor, list) or idx >= len(cursor):
return
cursor = cursor[idx]
name, idx = segments[-1]
if idx is not None:
if not isinstance(cursor, dict) or name not in cursor:
return
container = cursor[name]
if not isinstance(container, list) or idx >= len(container):
return
container[idx] = value
return
if not isinstance(cursor, dict):
return
cursor[name] = value
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
def _parse(path: str) -> list[tuple[str, int | None]]:
"""Forgiving dotted-path parser — returns ``[]`` on malformed input
instead of raising, since audit-log replay must not blow up on a
stale path. The strict variant used by the HITL PATCH route lives
in ``ocr_sprint.db.repositories``.
"""
if not path or path.startswith(".") or path.endswith("."):
return []
out: list[tuple[str, int | None]] = []
for part in path.split("."):
match: re.Match[str] | None = _SEGMENT_RE.match(part)
if match is None:
return []
out.append((match.group(1), int(match.group(2)) if match.group(2) else None))
return out
def iter_ground_truth_samples(
session: Session,
filters: GroundTruthFilters,
) -> Iterator[GroundTruthSample]:
"""Yield ``GroundTruthSample`` rows matching ``filters``.
Rows are ordered by ``created_at ASC`` so repeated exports are
deterministic and downstream diffing works.
"""
stmt = select(JobRow).order_by(JobRow.created_at, JobRow.job_id)
if filters.approved_only:
stmt = stmt.where(JobRow.approved.is_(True))
if filters.since is not None:
stmt = stmt.where(JobRow.created_at >= filters.since)
if filters.until is not None:
stmt = stmt.where(JobRow.created_at <= filters.until)
remaining = filters.limit
for job_row in session.scalars(stmt):
if remaining is not None and remaining <= 0:
return
correction_rows = list(
session.scalars(
select(JobCorrectionRow)
.where(JobCorrectionRow.job_id == job_row.job_id)
.order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
)
)
if filters.has_corrections is True and not correction_rows:
continue
if filters.has_corrections is False and correction_rows:
continue
initial = build_initial_result(final_result=job_row.result, corrections=correction_rows)
sample = GroundTruthSample(
job_id=job_row.job_id,
filename=job_row.filename,
source_kind=job_row.source_kind,
approved=bool(job_row.approved),
reviewed_by=job_row.reviewed_by,
reviewed_at=job_row.reviewed_at,
created_at=job_row.created_at,
initial_result=initial,
final_result=copy.deepcopy(job_row.result) if job_row.result else None,
corrections=[
GroundTruthCorrection(
field_path=c.field_path,
old_value=c.old_value,
new_value=c.new_value,
corrected_by=c.corrected_by,
reason=c.reason,
corrected_at=c.corrected_at,
)
for c in correction_rows
],
review_flags=list(job_row.review_flags or []),
confidence=job_row.confidence,
)
if remaining is not None:
remaining -= 1
yield sample
def ground_truth_stats(session: Session, *, top_n: int = 10) -> GroundTruthStats:
"""Compute aggregate statistics over the jobs + corrections tables."""
by_status: dict[str, int] = {
row[0]: int(row[1])
for row in session.execute(
select(JobRow.status, func.count(JobRow.job_id)).group_by(JobRow.status)
).all()
}
total_jobs = int(sum(by_status.values()))
approved_jobs = int(
session.execute(
select(func.count(JobRow.job_id)).where(JobRow.approved.is_(True))
).scalar_one()
)
total_corrections = int(session.execute(select(func.count(JobCorrectionRow.id))).scalar_one())
jobs_with_corrections = int(
session.execute(select(func.count(func.distinct(JobCorrectionRow.job_id)))).scalar_one()
)
# Field-path histogram. A single SQL ``GROUP BY`` is cheaper than
# pulling every row over the wire, but we still cap at ``top_n`` so
# a pathological dataset can't drown the response.
top_rows = session.execute(
select(
JobCorrectionRow.field_path,
func.count(JobCorrectionRow.id).label("n"),
)
.group_by(JobCorrectionRow.field_path)
.order_by(func.count(JobCorrectionRow.id).desc(), JobCorrectionRow.field_path)
.limit(top_n)
).all()
return GroundTruthStats(
total_jobs=total_jobs,
completed_jobs=by_status.get(DocumentStatus.COMPLETED.value, 0),
needs_review_jobs=by_status.get(DocumentStatus.NEEDS_REVIEW.value, 0),
failed_jobs=by_status.get(DocumentStatus.FAILED.value, 0),
approved_jobs=approved_jobs,
total_corrections=total_corrections,
jobs_with_corrections=jobs_with_corrections,
top_corrected_fields=[
FieldCorrectionCount(field_path=row[0], count=int(row[1])) for row in top_rows
],
)
def serialize_sample_to_jsonl(sample: GroundTruthSample) -> str:
"""Serialize one sample as a newline-terminated JSON string.
Kept standalone so the HTTP and CLI layers share the exact same
representation — diffs between a curl download and a CLI dump stay
byte-identical.
"""
return sample.model_dump_json() + "\n"

View File

@@ -7,7 +7,7 @@ from fastapi import FastAPI
from ocr_sprint import __version__
from ocr_sprint.api.errors import register_error_handlers
from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint
from ocr_sprint.api.routes import documents, health
from ocr_sprint.api.routes import documents, ground_truth, health
from ocr_sprint.config import get_settings
from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables)
from ocr_sprint.db.base import Base, get_engine
@@ -43,6 +43,7 @@ def create_app() -> FastAPI:
app.add_middleware(MetricsMiddleware)
app.include_router(health.router, prefix="/api/v1")
app.include_router(documents.router, prefix="/api/v1")
app.include_router(ground_truth.router, prefix="/api/v1")
app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False)
return app

View File

@@ -0,0 +1,76 @@
"""Schemas for the Phase 7 ground-truth export.
Each ``GroundTruthSample`` represents one training-ready example:
* ``initial_*`` snapshots the pipeline's original (pre-HITL) output,
reconstructed by replaying the audit trail in reverse.
* ``final_*`` is the current ``result`` on the ``jobs`` row — the
reviewer-approved answer.
* ``corrections`` is the raw audit trail so downstream fine-tuning can
see *what* was changed, *why* (free-text reason), and by whom.
JSONL is emitted — one sample per line — so the file can be mmapped,
streamed, or piped straight into an HF ``datasets.load_dataset("json",
...)`` call.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class GroundTruthCorrection(BaseModel):
"""One row of the ``job_corrections`` audit trail, as exported."""
field_path: str
old_value: Any | None = None
new_value: Any | None = None
corrected_by: str | None = None
reason: str | None = None
corrected_at: datetime
class GroundTruthSample(BaseModel):
"""One training sample written as a single JSONL line."""
model_config = ConfigDict(populate_by_name=True)
job_id: UUID
filename: str
source_kind: str
approved: bool = False
reviewed_by: str | None = None
reviewed_at: datetime | None = None
created_at: datetime
# ``initial_*`` is the pipeline's pre-HITL answer, reconstructed from
# the audit trail. ``final_*`` is the reviewer-approved version.
initial_result: dict[str, Any] | None = None
final_result: dict[str, Any] | None = None
corrections: list[GroundTruthCorrection] = Field(default_factory=list)
review_flags: list[str] = Field(default_factory=list)
confidence: float | None = None
class FieldCorrectionCount(BaseModel):
field_path: str
count: int
class GroundTruthStats(BaseModel):
"""High-level dataset health report surfaced by ``GET /ground-truth/stats``."""
total_jobs: int
completed_jobs: int
needs_review_jobs: int
failed_jobs: int
approved_jobs: int
total_corrections: int
jobs_with_corrections: int
# Most-corrected field paths (descending). Operators use this to
# prioritise which fields to target with prompt tweaks or fine-tune
# data collection first.
top_corrected_fields: list[FieldCorrectionCount] = Field(default_factory=list)

View File

@@ -0,0 +1 @@
"""Operator CLI utilities. Invoked via ``python -m ocr_sprint.tools.<name>``."""

View File

@@ -0,0 +1,137 @@
"""CLI: ``python -m ocr_sprint.tools.export_ground_truth``.
Dumps the HITL ground-truth corpus to a local JSONL file. The HTTP
endpoint is fine for small queries, but weekly / monthly snapshots are
easier to schedule on the host via cron than via curl.
Usage
-----
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl \
--since 2025-01-01 --has-corrections
The database URL is read from the same ``DATABASE_URL`` env var the API
uses, so this script works out of the box on any box that can run the
service.
"""
from __future__ import annotations
import argparse
import sys
from datetime import datetime
from pathlib import Path
from typing import Protocol
from ocr_sprint.db.base import session_scope
from ocr_sprint.ground_truth import (
GroundTruthFilters,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
def _parse_iso(raw: str | None) -> datetime | None:
if raw is None:
return None
# ``fromisoformat`` supports dates too, so ``--since 2025-01-01``
# behaves like ``2025-01-01T00:00:00``.
return datetime.fromisoformat(raw)
def _build_parser() -> argparse.ArgumentParser:
p = argparse.ArgumentParser(
prog="ocr_sprint.tools.export_ground_truth",
description="Dump HITL ground-truth corpus to JSONL.",
)
p.add_argument(
"--out",
required=True,
type=Path,
help="Output JSONL file. Use '-' to write to stdout.",
)
p.add_argument("--since", type=str, default=None, help="ISO date/datetime lower bound.")
p.add_argument("--until", type=str, default=None, help="ISO date/datetime upper bound.")
p.add_argument(
"--include-unapproved",
action="store_true",
help=(
"Include jobs that haven't been approved yet. Not recommended "
"for training data — use only for debugging exports."
),
)
group = p.add_mutually_exclusive_group()
group.add_argument(
"--has-corrections",
action="store_true",
help="Only include jobs that had at least one correction.",
)
group.add_argument(
"--no-corrections",
action="store_true",
help="Only include pristine jobs (no corrections). Useful for sanity sets.",
)
p.add_argument("--limit", type=int, default=None, help="Maximum samples to emit.")
p.add_argument(
"--print-stats",
action="store_true",
help="Also print dataset summary to stderr after the dump.",
)
return p
def main(argv: list[str] | None = None) -> int:
args = _build_parser().parse_args(argv)
has_corrections: bool | None
if args.has_corrections:
has_corrections = True
elif args.no_corrections:
has_corrections = False
else:
has_corrections = None
filters = GroundTruthFilters(
since=_parse_iso(args.since),
until=_parse_iso(args.until),
approved_only=not args.include_unapproved,
has_corrections=has_corrections,
limit=args.limit,
)
out: Path = args.out
count = 0
if str(out) == "-":
_write_stream(sys.stdout.buffer, filters)
else:
out.parent.mkdir(parents=True, exist_ok=True)
with out.open("wb") as fh:
count = _write_stream(fh, filters)
print(f"wrote {count} sample(s) to {out}", file=sys.stderr)
if args.print_stats:
with session_scope() as session:
stats = ground_truth_stats(session)
print(stats.model_dump_json(indent=2), file=sys.stderr)
return 0
class _BinaryWriter(Protocol):
def write(self, data: bytes) -> int: ...
def _write_stream(fh: _BinaryWriter, filters: GroundTruthFilters) -> int:
"""Write JSONL to ``fh``. Returns the number of samples written."""
count = 0
with session_scope() as session:
for sample in iter_ground_truth_samples(session, filters):
fh.write(serialize_sample_to_jsonl(sample).encode("utf-8"))
count += 1
return count
if __name__ == "__main__":
sys.exit(main())

View File

@@ -0,0 +1,158 @@
"""HTTP tests for the ground-truth export endpoints."""
from __future__ import annotations
import json
from datetime import date
import pytest
from fastapi.testclient import TestClient
from ocr_sprint.main import create_app
from ocr_sprint.pipeline import orchestrator as orch_module
from ocr_sprint.pipeline.orchestrator import PipelineOutput
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
from ocr_sprint.schemas.extraction import (
ExtractionResult,
HeaderFields,
PersonnelEntry,
)
@pytest.fixture
def client() -> TestClient:
return TestClient(create_app())
@pytest.fixture
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
result = ExtractionResult(
header=HeaderFields(
nomor_sprint="Sprin/1/I/2025",
tanggal=date(2025, 1, 1),
satuan_penerbit="POLRES TEST",
),
personel=[
PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"),
],
confidence=0.9,
)
output = PipelineOutput(
source_kind=SourceKind.PDF,
status=DocumentStatus.COMPLETED,
confidence=0.9,
result=result,
)
def _fake_run(_content: bytes) -> PipelineOutput:
return output
monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
from ocr_sprint.api.routes import documents as docs_module
monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
from ocr_sprint.worker import tasks as tasks_module
monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
return output
def _create_and_approve(client: TestClient, *, correction_value: str | None = None) -> str:
post = client.post(
"/api/v1/documents?sync=true",
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
)
assert post.status_code == 200, post.text
jid = str(post.json()["job_id"])
if correction_value is not None:
patched = client.patch(
f"/api/v1/documents/{jid}",
json={"corrections": [{"path": "header.perihal", "value": correction_value}]},
)
assert patched.status_code == 200
approved = client.post(f"/api/v1/documents/{jid}/approve")
assert approved.status_code == 200
return jid
def test_stats_empty_dataset(client: TestClient) -> None:
resp = client.get("/api/v1/ground-truth/stats")
assert resp.status_code == 200
body = resp.json()
assert body["total_jobs"] == 0
assert body["approved_jobs"] == 0
assert body["total_corrections"] == 0
assert body["top_corrected_fields"] == []
def test_stats_rolls_up_counts(client: TestClient, fake_pipeline: PipelineOutput) -> None:
_create_and_approve(client, correction_value="Penyelidikan-1")
_create_and_approve(client, correction_value="Penyelidikan-2")
_create_and_approve(client, correction_value=None) # pristine
resp = client.get("/api/v1/ground-truth/stats")
assert resp.status_code == 200
body = resp.json()
assert body["total_jobs"] == 3
assert body["approved_jobs"] == 3
assert body["total_corrections"] == 2
assert body["jobs_with_corrections"] == 2
assert body["top_corrected_fields"][0]["field_path"] == "header.perihal"
assert body["top_corrected_fields"][0]["count"] == 2
def test_export_streams_jsonl(client: TestClient, fake_pipeline: PipelineOutput) -> None:
_create_and_approve(client, correction_value="Penyelidikan")
_create_and_approve(client, correction_value=None)
resp = client.get("/api/v1/ground-truth/export")
assert resp.status_code == 200
assert resp.headers["content-type"].startswith("application/x-ndjson")
lines = [line for line in resp.text.splitlines() if line.strip()]
assert len(lines) == 2
parsed = [json.loads(line) for line in lines]
for sample in parsed:
assert sample["approved"] is True
assert "initial_result" in sample
assert "final_result" in sample
def test_export_approved_only_default(client: TestClient, fake_pipeline: PipelineOutput) -> None:
"""Unapproved jobs shouldn't appear in the default export."""
# One approved, one just completed (no approve call).
_create_and_approve(client, correction_value=None)
client.post(
"/api/v1/documents?sync=true",
files={"file": ("y.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
)
resp = client.get("/api/v1/ground-truth/export")
lines = [line for line in resp.text.splitlines() if line.strip()]
assert len(lines) == 1
# Toggle approved_only=false to include both.
resp = client.get("/api/v1/ground-truth/export?approved_only=false")
lines = [line for line in resp.text.splitlines() if line.strip()]
assert len(lines) == 2
def test_export_has_corrections_filter(client: TestClient, fake_pipeline: PipelineOutput) -> None:
_create_and_approve(client, correction_value="Penyelidikan")
_create_and_approve(client, correction_value=None)
resp = client.get("/api/v1/ground-truth/export?has_corrections=true")
lines = [line for line in resp.text.splitlines() if line.strip()]
assert len(lines) == 1
assert json.loads(lines[0])["corrections"][0]["new_value"] == "Penyelidikan"
resp = client.get("/api/v1/ground-truth/export?has_corrections=false")
lines = [line for line in resp.text.splitlines() if line.strip()]
assert len(lines) == 1
assert json.loads(lines[0])["corrections"] == []
def test_export_respects_limit(client: TestClient, fake_pipeline: PipelineOutput) -> None:
for _ in range(5):
_create_and_approve(client, correction_value=None)
resp = client.get("/api/v1/ground-truth/export?limit=2")
lines = [line for line in resp.text.splitlines() if line.strip()]
assert len(lines) == 2

View File

@@ -0,0 +1,80 @@
"""Smoke tests for the ``ocr_sprint.tools.export_ground_truth`` CLI."""
from __future__ import annotations
import json
from pathlib import Path
from uuid import uuid4
import pytest
from ocr_sprint.db.base import Base, get_engine, session_scope
from ocr_sprint.db.repositories import JobRepository
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
from ocr_sprint.tools.export_ground_truth import main as export_main
@pytest.fixture
def db_ready() -> None:
Base.metadata.create_all(bind=get_engine())
def _seed_two_approved_jobs() -> None:
for _ in range(2):
jid = uuid4()
with session_scope() as session:
JobRepository(session).create(
job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k"
)
with session_scope() as session:
JobRepository(session).mark_completed(
jid,
status=DocumentStatus.COMPLETED,
confidence=0.9,
result={"header": {"nomor_sprint": "SPR/1/2025"}},
review_flags=[],
)
with session_scope() as session:
JobRepository(session).approve(jid, reviewed_by="rev")
def test_cli_writes_expected_number_of_lines(
db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
_seed_two_approved_jobs()
out = tmp_path / "corpus.jsonl"
exit_code = export_main(["--out", str(out)])
assert exit_code == 0
lines = [line for line in out.read_text().splitlines() if line.strip()]
assert len(lines) == 2
for line in lines:
parsed = json.loads(line)
assert parsed["approved"] is True
stderr = capsys.readouterr().err
assert "wrote 2 sample(s)" in stderr
def test_cli_respects_limit(db_ready: None, tmp_path: Path) -> None:
_seed_two_approved_jobs()
out = tmp_path / "corpus.jsonl"
exit_code = export_main(["--out", str(out), "--limit", "1"])
assert exit_code == 0
lines = [line for line in out.read_text().splitlines() if line.strip()]
assert len(lines) == 1
def test_cli_print_stats_emits_json_to_stderr(
db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
_seed_two_approved_jobs()
out = tmp_path / "corpus.jsonl"
exit_code = export_main(["--out", str(out), "--print-stats"])
assert exit_code == 0
stderr = capsys.readouterr().err
# Validate the JSON prologue (after the "wrote N" line).
json_start = stderr.index("{")
stats = json.loads(stderr[json_start:])
assert stats["total_jobs"] == 2
assert stats["approved_jobs"] == 2

View File

@@ -0,0 +1,210 @@
"""Unit tests for the ground-truth export service."""
from __future__ import annotations
import json
from datetime import datetime, timezone
from uuid import UUID, uuid4
import pytest
from ocr_sprint.db.base import Base, get_engine, session_scope
from ocr_sprint.db.repositories import JobRepository
from ocr_sprint.ground_truth import (
GroundTruthFilters,
build_initial_result,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
@pytest.fixture
def db_ready() -> None:
Base.metadata.create_all(bind=get_engine())
def _seed_approved_job_with_corrections(
*,
final_result: dict[str, object] | None = None,
corrections: list[tuple[str, object, object]] | None = None,
approved: bool = True,
) -> UUID:
jid = uuid4()
with session_scope() as session:
repo = JobRepository(session)
repo.create(job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k")
with session_scope() as session:
JobRepository(session).mark_completed(
jid,
status=DocumentStatus.NEEDS_REVIEW,
confidence=0.8,
result=final_result
or {
"header": {"nomor_sprint": "SPR/1/2025", "satuan_penerbit": "POLRES X"},
"personel": [{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}],
},
review_flags=[],
)
if corrections:
# Corrections are tuples of (path, value, old_value_for_audit).
# We translate them into real PATCH calls so the audit rows look
# exactly like they would in production.
with session_scope() as session:
JobRepository(session).apply_corrections(
jid,
corrections=[(p, new, None) for (p, new, _old) in corrections],
corrected_by="reviewer-a",
)
if approved:
with session_scope() as session:
JobRepository(session).approve(jid, reviewed_by="reviewer-a")
return jid
def test_build_initial_result_reverses_corrections() -> None:
"""With known old/new pairs the replay must return the pre-HITL dict."""
final = {
"header": {"nomor_sprint": "SPR/1/2025", "perihal": "Penyelidikan"},
}
class _Fake:
def __init__(self, path: str, old: object, new: object) -> None:
self.field_path = path
self.old_value = old
self.new_value = new
corrections = [
_Fake("header.perihal", None, "Penyelidikan"),
_Fake("header.nomor_sprint", "SPR/OLD/2025", "SPR/1/2025"),
]
restored = build_initial_result(final_result=final, corrections=corrections) # type: ignore[arg-type]
assert restored == {
"header": {"nomor_sprint": "SPR/OLD/2025", "perihal": None},
}
# ``final`` must not have been mutated.
assert final["header"]["nomor_sprint"] == "SPR/1/2025"
def test_build_initial_result_ignores_unresolvable_paths() -> None:
"""A legacy path that no longer resolves should be skipped, not raise."""
final = {"header": {"nomor_sprint": "SPR/1"}}
class _Fake:
def __init__(self, path: str, old: object, new: object) -> None:
self.field_path = path
self.old_value = old
self.new_value = new
corrections = [
_Fake("personel[99].nrp", "stale", "new"), # container doesn't exist
_Fake("bogus..path", "x", "y"),
]
restored = build_initial_result(final_result=final, corrections=corrections) # type: ignore[arg-type]
assert restored == final
def test_iter_samples_reconstructs_initial_and_final(db_ready: None) -> None:
jid = _seed_approved_job_with_corrections(
corrections=[("header.perihal", "Penyelidikan", None)],
)
with session_scope() as session:
samples = list(iter_ground_truth_samples(session, GroundTruthFilters()))
assert len(samples) == 1
s = samples[0]
assert s.job_id == jid
assert s.approved is True
assert s.reviewed_by == "reviewer-a"
assert s.final_result is not None
assert s.final_result["header"]["perihal"] == "Penyelidikan"
# Initial reconstruction undoes the correction.
assert s.initial_result is not None
assert s.initial_result["header"].get("perihal") is None
assert len(s.corrections) == 1
assert s.corrections[0].field_path == "header.perihal"
def test_iter_samples_respects_approved_only_default(db_ready: None) -> None:
"""Unapproved jobs must be excluded unless approved_only=False."""
_seed_approved_job_with_corrections(approved=False)
with session_scope() as session:
samples = list(iter_ground_truth_samples(session, GroundTruthFilters()))
assert samples == []
with session_scope() as session:
samples = list(iter_ground_truth_samples(session, GroundTruthFilters(approved_only=False)))
assert len(samples) == 1
def test_iter_samples_respects_has_corrections_filter(db_ready: None) -> None:
_seed_approved_job_with_corrections(corrections=None) # pristine
_seed_approved_job_with_corrections(corrections=[("header.perihal", "fill", None)])
with session_scope() as session:
with_ = list(iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=True)))
without = list(
iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=False))
)
assert len(with_) == 1
assert len(without) == 1
assert with_[0].job_id != without[0].job_id
def test_iter_samples_respects_since_until_and_limit(db_ready: None) -> None:
_seed_approved_job_with_corrections()
_seed_approved_job_with_corrections()
_seed_approved_job_with_corrections()
# Since far in the future → empty.
future = datetime(2999, 1, 1, tzinfo=timezone.utc)
with session_scope() as session:
out = list(iter_ground_truth_samples(session, GroundTruthFilters(since=future)))
assert out == []
# Limit caps the number emitted.
with session_scope() as session:
out = list(iter_ground_truth_samples(session, GroundTruthFilters(limit=2)))
assert len(out) == 2
def test_stats_counts_rollup_and_top_fields(db_ready: None) -> None:
# 1 approved job with 2 corrections on ``header.perihal``.
_seed_approved_job_with_corrections(
corrections=[
("header.perihal", "first", None),
]
)
jid = _seed_approved_job_with_corrections(
corrections=[
("header.perihal", "second", None),
("personel[0].nrp", "77060001", None),
],
)
assert jid # silence unused
with session_scope() as session:
stats = ground_truth_stats(session)
assert stats.total_jobs == 2
assert stats.approved_jobs == 2
assert stats.total_corrections == 3
assert stats.jobs_with_corrections == 2
# ``header.perihal`` is the most-corrected field (2 > 1).
assert stats.top_corrected_fields[0].field_path == "header.perihal"
assert stats.top_corrected_fields[0].count == 2
assert {f.field_path for f in stats.top_corrected_fields} == {
"header.perihal",
"personel[0].nrp",
}
def test_serialize_is_valid_jsonl(db_ready: None) -> None:
_seed_approved_job_with_corrections(corrections=[("header.perihal", "X", None)])
with session_scope() as session:
for sample in iter_ground_truth_samples(session, GroundTruthFilters()):
line = serialize_sample_to_jsonl(sample)
assert line.endswith("\n")
# Strip trailing newline and make sure the result is a single
# JSON object (not a list — JSONL must be one object per line).
parsed = json.loads(line)
assert isinstance(parsed, dict)
assert parsed["approved"] is True