Phase 7: ground-truth export (JSONL + stats) + CLI tool

- GET /api/v1/ground-truth/export  streaming JSONL (approved_only,
  since, until, has_corrections, limit)
- GET /api/v1/ground-truth/stats   total / approved / corrections
  counts + top-N most-corrected field paths
- python -m ocr_sprint.tools.export_ground_truth  operator CLI with
  the same filters + optional --print-stats
- Ground-truth sample reconstructs the pipeline's original output by
  replaying job_corrections in reverse
- docs/ground-truth-format.md    schema + fine-tuning guidance
- 17 new tests (service replay, endpoint filters, CLI)
- 201 total tests passing, ruff / mypy --strict clean

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
Devin AI
2026-04-25 20:24:40 +00:00
parent 9457fa3c55
commit 6003d96a94
11 changed files with 1148 additions and 1 deletions

View File

@@ -0,0 +1,102 @@
"""Ground-truth export + statistics endpoints (Phase 7).
Two endpoints, both auth'd by the existing ``X-API-Key`` dependency:
* ``GET /ground-truth/export`` — streams JSONL of approved (or filtered)
samples for downstream fine-tuning pipelines.
* ``GET /ground-truth/stats`` — returns aggregate counts + top-corrected
field paths so operators know when/where fine-tuning will pay off.
"""
from __future__ import annotations
from collections.abc import Iterator
from datetime import datetime
from typing import Annotated
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from ocr_sprint.api.deps.auth import require_api_key
from ocr_sprint.api.deps.db import get_session
from ocr_sprint.ground_truth import (
GroundTruthFilters,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
from ocr_sprint.schemas.ground_truth import GroundTruthStats
# Router for the ground-truth endpoints. ``require_api_key`` is attached at
# router level, so it applies to every route registered below without
# per-route wiring.
router = APIRouter(
    prefix="/ground-truth",
    tags=["ground-truth"],
    dependencies=[Depends(require_api_key)],
)
@router.get(
    "/export",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"application/x-ndjson": {}},
            "description": "Newline-delimited JSON stream of training samples.",
        }
    },
)
def export_ground_truth(
    session: Annotated[Session, Depends(get_session)],
    since: Annotated[
        datetime | None,
        Query(description="Only include jobs created at or after this ISO timestamp."),
    ] = None,
    until: Annotated[
        datetime | None,
        Query(description="Only include jobs created at or before this ISO timestamp."),
    ] = None,
    approved_only: Annotated[
        bool,
        Query(description="Only export approved jobs (default true)."),
    ] = True,
    has_corrections: Annotated[
        bool | None,
        Query(
            description=(
                "Optional: true = only jobs that had at least one correction, "
                "false = only pristine (no-correction) jobs."
            )
        ),
    ] = None,
    limit: Annotated[
        int | None,
        Query(ge=1, le=100_000, description="Maximum rows to emit."),
    ] = None,
) -> StreamingResponse:
    # Stream the samples selected by the query parameters as JSONL, one
    # UTF-8 encoded line per sample. Documented with comments rather than a
    # docstring on purpose: FastAPI publishes a route's docstring as the
    # OpenAPI description, so adding one would alter the generated schema.

    # Bundle the raw query parameters into the filter object the service
    # layer expects.
    selection = GroundTruthFilters(
        since=since,
        until=until,
        approved_only=approved_only,
        has_corrections=has_corrections,
        limit=limit,
    )

    # Lazy pipeline: nothing is read from the database until the response
    # starts consuming this generator.
    # NOTE(review): ``session`` is therefore used while the body is being
    # sent -- confirm ``get_session`` keeps it open until streaming ends.
    encoded_lines: Iterator[bytes] = (
        serialize_sample_to_jsonl(item).encode("utf-8")
        for item in iter_ground_truth_samples(session, selection)
    )
    return StreamingResponse(encoded_lines, media_type="application/x-ndjson")
@router.get("/stats", response_model=GroundTruthStats)
def get_stats(
    session: Annotated[Session, Depends(get_session)],
    top_n: Annotated[
        int,
        Query(ge=1, le=100, description="How many top-corrected field paths to return."),
    ] = 10,
) -> GroundTruthStats:
    # Thin HTTP wrapper: all aggregation happens in ``ground_truth_stats``.
    # ``top_n`` is validated to 1..100 by the ``Query`` constraint above and
    # defaults to 10. Comments are used instead of a docstring because
    # FastAPI would publish a docstring as the OpenAPI description.
    report = ground_truth_stats(session, top_n=top_n)
    return report

View File

@@ -0,0 +1,22 @@
"""Phase 7 — ground-truth export service.
Consumes the ``jobs`` + ``job_corrections`` tables to produce JSONL
training samples + aggregate statistics. No ML dependencies here: the
actual fine-tuning runs in a separate project on dedicated hardware.
"""
from ocr_sprint.ground_truth.service import (
GroundTruthFilters,
build_initial_result,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
# Public surface of the package: exactly the names re-exported from
# ``ocr_sprint.ground_truth.service`` by the import above.
__all__ = [
    "GroundTruthFilters",
    "build_initial_result",
    "ground_truth_stats",
    "iter_ground_truth_samples",
    "serialize_sample_to_jsonl",
]

View File

@@ -0,0 +1,248 @@
"""Core ground-truth export logic.
The service reads from ``jobs`` and ``job_corrections`` and produces
``GroundTruthSample`` rows. The HTTP layer and the CLI both go through
this module, so the export logic stays in exactly one place.
Design notes
------------
* **Reverse-replay for ``initial_result``.** Audit rows store both
``old_value`` and ``new_value``. We start from the current
``final_result`` and undo each correction in reverse chronological
order to recover the pipeline's original pre-HITL output. This is
what fine-tuning needs: the raw mistake next to the reviewer fix.
* **No deep copying in the caller.** The service always returns freshly
computed dicts so mutating one sample can't leak into the next.
* **Generator by design.** Large exports shouldn't load the whole table
into memory; the HTTP route streams JSONL line-by-line.
"""
from __future__ import annotations
import copy
import re
from collections.abc import Iterator
from dataclasses import dataclass
from datetime import datetime
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from ocr_sprint.db.models import JobCorrectionRow, JobRow
from ocr_sprint.schemas.document import DocumentStatus
from ocr_sprint.schemas.ground_truth import (
FieldCorrectionCount,
GroundTruthCorrection,
GroundTruthSample,
GroundTruthStats,
)
@dataclass(frozen=True)
class GroundTruthFilters:
    """Immutable bundle of filters accepted by ``iter_ground_truth_samples``.

    Shared by the HTTP route and the CLI. Every field has a default, so
    ``GroundTruthFilters()`` is valid; ``approved_only`` defaults to
    ``True`` because only approved samples are actually safe to train on.
    """

    # Lower bound on the job's ``created_at``; ``None`` means unbounded.
    since: datetime | None = None
    # Upper bound on the job's ``created_at``; ``None`` means unbounded.
    until: datetime | None = None
    # When true, only approved jobs are exported.
    approved_only: bool = True
    # ``True`` = only jobs with corrections, ``False`` = only jobs without,
    # ``None`` = do not filter on corrections at all.
    has_corrections: bool | None = None
    # Maximum number of samples to emit; ``None`` means no cap.
    limit: int | None = None
def build_initial_result(
    *,
    final_result: dict[str, Any] | None,
    corrections: list[JobCorrectionRow],
) -> dict[str, Any] | None:
    """Replay ``corrections`` backwards over ``final_result`` to recover
    the pre-HITL version.

    ``corrections`` is expected oldest-first; it is walked newest-first and
    each event's ``old_value`` is written back at its ``field_path``.
    Returns ``None`` if ``final_result`` is ``None`` (nothing to
    reconstruct); otherwise returns a new dict. Neither ``final_result``
    nor the correction rows are mutated.
    """
    if final_result is None:
        return None
    restored = copy.deepcopy(final_result)
    for event in reversed(corrections):
        # Write a private copy of ``old_value``. Without it the restored
        # dict would alias the audit row's object, and undoing an *earlier*
        # correction at a nested path (e.g. ``items[0].name`` after
        # ``items[0]``) would silently rewrite that row's ``old_value``.
        _assign_path(restored, event.field_path, copy.deepcopy(event.old_value))
    return restored
def _assign_path(data: dict[str, Any], path: str, value: Any) -> None:
    """Write ``value`` into ``data`` at the dotted ``path``, in place.

    ``path`` uses the ``name`` / ``name[idx]`` segment syntax understood by
    ``_parse``. The call is a silent no-op when the path is malformed or
    when any step no longer resolves (missing key, wrong container type,
    index past the end of a list): the dict shape may have drifted since
    the event was recorded, and replay must not raise. A plain trailing
    ``name`` segment is assigned even if the key does not exist yet.
    """
    steps = _parse(path)
    if not steps:
        return
    *parents, (leaf_name, leaf_index) = steps

    # Descend through every segment except the last one.
    node: Any = data
    for key, position in parents:
        if not (isinstance(node, dict) and key in node):
            return
        node = node[key]
        if position is None:
            continue
        if not (isinstance(node, list) and position < len(node)):
            return
        node = node[position]

    # The last segment always addresses a key of a dict.
    if not isinstance(node, dict):
        return
    if leaf_index is None:
        node[leaf_name] = value
        return

    # ``name[idx]`` leaf: the key must already hold a long-enough list.
    if leaf_name not in node:
        return
    target = node[leaf_name]
    if isinstance(target, list) and leaf_index < len(target):
        target[leaf_index] = value
# One path segment: an identifier, optionally followed by ``[<digits>]``.
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")


def _parse(path: str) -> list[tuple[str, int | None]]:
    """Split a dotted path into ``(name, index-or-None)`` pairs.

    This is the forgiving parser: any malformed input (empty string,
    leading or trailing dot, empty or non-matching segment) yields ``[]``
    instead of raising, since audit-log replay must not blow up on a stale
    path. The strict variant used by the HITL PATCH route lives in
    ``ocr_sprint.db.repositories``.
    """
    if not path or path[0] == "." or path[-1] == ".":
        return []
    segments: list[tuple[str, int | None]] = []
    for token in path.split("."):
        hit = _SEGMENT_RE.match(token)
        if hit is None:
            return []
        key, raw_index = hit.groups()
        segments.append((key, int(raw_index) if raw_index else None))
    return segments
def iter_ground_truth_samples(
    session: Session,
    filters: GroundTruthFilters,
) -> Iterator[GroundTruthSample]:
    """Yield ``GroundTruthSample`` rows matching ``filters``.

    Rows are ordered by ``created_at`` ascending with ``job_id`` as the
    tie-breaker, so repeated exports are deterministic and downstream
    diffing works. ``since`` / ``until`` are inclusive bounds on
    ``created_at``. ``has_corrections`` is evaluated per job after its
    audit rows have been loaded, and ``limit`` caps the number of samples
    *yielded* (jobs skipped by ``has_corrections`` do not count).
    """
    stmt = select(JobRow).order_by(JobRow.created_at, JobRow.job_id)
    if filters.approved_only:
        stmt = stmt.where(JobRow.approved.is_(True))
    if filters.since is not None:
        stmt = stmt.where(JobRow.created_at >= filters.since)
    if filters.until is not None:
        stmt = stmt.where(JobRow.created_at <= filters.until)

    remaining = filters.limit
    for job_row in session.scalars(stmt):
        if remaining is not None and remaining <= 0:
            return

        # Audit trail for this job, oldest first (``id`` breaks ties).
        correction_rows = list(
            session.scalars(
                select(JobCorrectionRow)
                .where(JobCorrectionRow.job_id == job_row.job_id)
                .order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
            )
        )
        if (
            filters.has_corrections is not None
            and filters.has_corrections != bool(correction_rows)
        ):
            continue

        initial = build_initial_result(final_result=job_row.result, corrections=correction_rows)
        # Test against ``None`` rather than truthiness: an empty ``{}``
        # result must export as ``{}`` in *both* fields, the same way
        # ``build_initial_result`` treats it, not as ``final_result=None``.
        final = copy.deepcopy(job_row.result) if job_row.result is not None else None

        sample = GroundTruthSample(
            job_id=job_row.job_id,
            filename=job_row.filename,
            source_kind=job_row.source_kind,
            approved=bool(job_row.approved),
            reviewed_by=job_row.reviewed_by,
            reviewed_at=job_row.reviewed_at,
            created_at=job_row.created_at,
            initial_result=initial,
            final_result=final,
            corrections=[
                GroundTruthCorrection(
                    field_path=c.field_path,
                    old_value=c.old_value,
                    new_value=c.new_value,
                    corrected_by=c.corrected_by,
                    reason=c.reason,
                    corrected_at=c.corrected_at,
                )
                for c in correction_rows
            ],
            review_flags=list(job_row.review_flags or []),
            confidence=job_row.confidence,
        )
        if remaining is not None:
            remaining -= 1
        yield sample
def ground_truth_stats(session: Session, *, top_n: int = 10) -> GroundTruthStats:
    """Build the aggregate report over the jobs + corrections tables.

    Five queries are issued: job counts grouped by status, the count of
    approved jobs, the total number of correction rows, the number of
    distinct jobs that have corrections, and the ``top_n`` most-corrected
    field paths (count descending, then ``field_path`` for stable ties).
    """
    # Jobs per status; ``total_jobs`` is the sum over all statuses.
    status_query = select(JobRow.status, func.count(JobRow.job_id)).group_by(JobRow.status)
    by_status: dict[str, int] = {}
    for row in session.execute(status_query).all():
        by_status[row[0]] = int(row[1])
    total_jobs = int(sum(by_status.values()))

    approved_query = select(func.count(JobRow.job_id)).where(JobRow.approved.is_(True))
    approved_jobs = int(session.execute(approved_query).scalar_one())

    corrections_query = select(func.count(JobCorrectionRow.id))
    total_corrections = int(session.execute(corrections_query).scalar_one())

    corrected_jobs_query = select(func.count(func.distinct(JobCorrectionRow.job_id)))
    jobs_with_corrections = int(session.execute(corrected_jobs_query).scalar_one())

    # Field-path histogram. Grouping in SQL is cheaper than pulling every
    # row over the wire, and the ``top_n`` cap keeps a pathological dataset
    # from drowning the response.
    hits = func.count(JobCorrectionRow.id)
    histogram_query = (
        select(JobCorrectionRow.field_path, hits.label("n"))
        .group_by(JobCorrectionRow.field_path)
        .order_by(func.count(JobCorrectionRow.id).desc(), JobCorrectionRow.field_path)
        .limit(top_n)
    )
    top_corrected_fields = []
    for row in session.execute(histogram_query).all():
        top_corrected_fields.append(FieldCorrectionCount(field_path=row[0], count=int(row[1])))

    return GroundTruthStats(
        total_jobs=total_jobs,
        completed_jobs=by_status.get(DocumentStatus.COMPLETED.value, 0),
        needs_review_jobs=by_status.get(DocumentStatus.NEEDS_REVIEW.value, 0),
        failed_jobs=by_status.get(DocumentStatus.FAILED.value, 0),
        approved_jobs=approved_jobs,
        total_corrections=total_corrections,
        jobs_with_corrections=jobs_with_corrections,
        top_corrected_fields=top_corrected_fields,
    )
def serialize_sample_to_jsonl(sample: GroundTruthSample) -> str:
    """Return ``sample`` as one JSON document followed by a newline.

    This is the single place where the line format is defined; the HTTP
    route and the CLI both call it, so a curl download and a CLI dump stay
    byte-identical.
    """
    line = sample.model_dump_json()
    line += "\n"
    return line

View File

@@ -7,7 +7,7 @@ from fastapi import FastAPI
from ocr_sprint import __version__
from ocr_sprint.api.errors import register_error_handlers
from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint
from ocr_sprint.api.routes import documents, health
from ocr_sprint.api.routes import documents, ground_truth, health
from ocr_sprint.config import get_settings
from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables)
from ocr_sprint.db.base import Base, get_engine
@@ -43,6 +43,7 @@ def create_app() -> FastAPI:
app.add_middleware(MetricsMiddleware)
app.include_router(health.router, prefix="/api/v1")
app.include_router(documents.router, prefix="/api/v1")
app.include_router(ground_truth.router, prefix="/api/v1")
app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False)
return app

View File

@@ -0,0 +1,76 @@
"""Schemas for the Phase 7 ground-truth export.
Each ``GroundTruthSample`` represents one training-ready example:
* ``initial_*`` snapshots the pipeline's original (pre-HITL) output,
reconstructed by replaying the audit trail in reverse.
* ``final_*`` is the current ``result`` on the ``jobs`` row — the
reviewer-approved answer.
* ``corrections`` is the raw audit trail so downstream fine-tuning can
see *what* was changed, *why* (free-text reason), and by whom.
JSONL is emitted — one sample per line — so the file can be mmapped,
streamed, or piped straight into an HF ``datasets.load_dataset("json",
...)`` call.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field
class GroundTruthCorrection(BaseModel):
    """One row of the ``job_corrections`` audit trail, as exported."""

    # Path of the corrected field inside the job's result.
    field_path: str
    # Field content before and after the correction; both are optional.
    old_value: Any | None = None
    new_value: Any | None = None
    # Who made the correction and their optional free-text reason.
    corrected_by: str | None = None
    reason: str | None = None
    # When the correction was recorded (required).
    corrected_at: datetime
class GroundTruthSample(BaseModel):
    """One training sample written as a single JSONL line."""

    model_config = ConfigDict(populate_by_name=True)

    # Identity and provenance of the job this sample was built from.
    job_id: UUID
    filename: str
    source_kind: str
    # Review state; ``approved`` defaults to False.
    approved: bool = False
    reviewed_by: str | None = None
    reviewed_at: datetime | None = None
    created_at: datetime
    # ``initial_result`` is the pipeline's pre-HITL answer, reconstructed
    # from the audit trail. ``final_result`` is the reviewer-approved version.
    initial_result: dict[str, Any] | None = None
    final_result: dict[str, Any] | None = None
    # Raw audit trail for this job; empty when nothing was corrected.
    corrections: list[GroundTruthCorrection] = Field(default_factory=list)
    # Review flags and confidence score carried over from the job.
    review_flags: list[str] = Field(default_factory=list)
    confidence: float | None = None
class FieldCorrectionCount(BaseModel):
    # One histogram bucket: a field path and how many corrections were
    # recorded against it. Documented with comments rather than a docstring
    # because Pydantic would emit a class docstring in the JSON schema.
    field_path: str
    count: int
class GroundTruthStats(BaseModel):
    """High-level dataset health report surfaced by ``GET /ground-truth/stats``."""

    # Job counts: overall, then broken down by status.
    total_jobs: int
    completed_jobs: int
    needs_review_jobs: int
    failed_jobs: int
    # Jobs marked as approved.
    approved_jobs: int
    # Number of correction rows, and number of distinct jobs having any.
    total_corrections: int
    jobs_with_corrections: int
    # Most-corrected field paths (descending). Operators use this to
    # prioritise which fields to target with prompt tweaks or fine-tune
    # data collection first.
    top_corrected_fields: list[FieldCorrectionCount] = Field(default_factory=list)

View File

@@ -0,0 +1 @@
"""Operator CLI utilities. Invoked via ``python -m ocr_sprint.tools.<name>``."""

View File

@@ -0,0 +1,137 @@
"""CLI: ``python -m ocr_sprint.tools.export_ground_truth``.
Dumps the HITL ground-truth corpus to a local JSONL file. The HTTP
endpoint is fine for small queries, but weekly / monthly snapshots are
easier to schedule on the host via cron than via curl.
Usage
-----
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl \
--since 2025-01-01 --has-corrections
The database URL is read from the same ``DATABASE_URL`` env var the API
uses, so this script works out of the box on any host that can run the
service.
"""
from __future__ import annotations
import argparse
import sys
from datetime import datetime
from pathlib import Path
from typing import Protocol
from ocr_sprint.db.base import session_scope
from ocr_sprint.ground_truth import (
GroundTruthFilters,
ground_truth_stats,
iter_ground_truth_samples,
serialize_sample_to_jsonl,
)
def _parse_iso(raw: str | None) -> datetime | None:
    """Parse an ISO-8601 date or datetime given on the command line.

    ``None`` (flag not supplied) passes through unchanged. A bare date such
    as ``2025-01-01`` behaves like ``2025-01-01T00:00:00``. A trailing
    ``Z`` (UTC designator) is rewritten to ``+00:00`` first, because
    ``datetime.fromisoformat`` only accepts it natively from Python 3.11.
    Raises ``ValueError`` when the text is not a valid ISO timestamp.
    """
    if raw is None:
        return None
    text = raw
    if text.endswith("Z"):
        # Same instant, spelled in the offset form every supported Python
        # version can parse.
        text = f"{text[:-1]}+00:00"
    return datetime.fromisoformat(text)
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the export CLI.

    ``--out`` is the only required option. ``--has-corrections`` and
    ``--no-corrections`` are mutually exclusive. ``--since`` / ``--until``
    are kept as raw strings here; converting them is left to the caller.
    """
    parser = argparse.ArgumentParser(
        prog="ocr_sprint.tools.export_ground_truth",
        description="Dump HITL ground-truth corpus to JSONL.",
    )

    # Destination.
    parser.add_argument(
        "--out",
        required=True,
        type=Path,
        help="Output JSONL file. Use '-' to write to stdout.",
    )

    # Time window.
    parser.add_argument(
        "--since",
        type=str,
        default=None,
        help="ISO date/datetime lower bound.",
    )
    parser.add_argument(
        "--until",
        type=str,
        default=None,
        help="ISO date/datetime upper bound.",
    )

    # Approval filter (approved-only unless this flag is given).
    parser.add_argument(
        "--include-unapproved",
        action="store_true",
        help=(
            "Include jobs that haven't been approved yet. Not recommended "
            "for training data — use only for debugging exports."
        ),
    )

    # Correction filter: at most one of the two flags may be supplied.
    corrections_mode = parser.add_mutually_exclusive_group()
    corrections_mode.add_argument(
        "--has-corrections",
        action="store_true",
        help="Only include jobs that had at least one correction.",
    )
    corrections_mode.add_argument(
        "--no-corrections",
        action="store_true",
        help="Only include pristine jobs (no corrections). Useful for sanity sets.",
    )

    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="Maximum samples to emit.",
    )
    parser.add_argument(
        "--print-stats",
        action="store_true",
        help="Also print dataset summary to stderr after the dump.",
    )
    return parser
def main(argv: list[str] | None = None) -> int:
    """Run the export CLI and return the process exit code (always 0).

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour) when
    ``None``. Samples go to the ``--out`` file, or to stdout when ``--out``
    is ``-``. The one-line summary and the optional ``--print-stats``
    report are written to stderr so they never mix with JSONL on stdout.
    Invalid ``--since`` / ``--until`` values surface as the ``ValueError``
    raised by ``_parse_iso``.
    """
    args = _build_parser().parse_args(argv)

    # Collapse the two mutually exclusive flags into one tri-state filter.
    has_corrections: bool | None
    if args.has_corrections:
        has_corrections = True
    elif args.no_corrections:
        has_corrections = False
    else:
        has_corrections = None

    filters = GroundTruthFilters(
        since=_parse_iso(args.since),
        until=_parse_iso(args.until),
        approved_only=not args.include_unapproved,
        has_corrections=has_corrections,
        limit=args.limit,
    )

    out: Path = args.out
    if str(out) == "-":
        # Keep the returned count here as well; dropping it would leave
        # the summary below unable to report how many samples were written.
        count = _write_stream(sys.stdout.buffer, filters)
    else:
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("wb") as fh:
            count = _write_stream(fh, filters)
    print(f"wrote {count} sample(s) to {out}", file=sys.stderr)

    if args.print_stats:
        with session_scope() as session:
            stats = ground_truth_stats(session)
            print(stats.model_dump_json(indent=2), file=sys.stderr)
    return 0
class _BinaryWriter(Protocol):
    """Structural type for any sink exposing a bytes-accepting ``write``."""

    def write(self, data: bytes) -> int:
        ...
def _write_stream(fh: _BinaryWriter, filters: GroundTruthFilters) -> int:
    """Write every sample matching ``filters`` to ``fh`` as UTF-8 JSONL.

    A session is opened via ``session_scope`` for the duration of the dump.
    Returns how many samples were written (0 when nothing matched).
    """
    written = 0
    with session_scope() as db:
        # ``enumerate`` from 1 leaves ``written`` equal to the running total.
        for written, record in enumerate(iter_ground_truth_samples(db, filters), start=1):
            fh.write(serialize_sample_to_jsonl(record).encode("utf-8"))
    return written
# Script entry point: propagate ``main``'s return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())