diff --git a/docs/ground-truth-format.md b/docs/ground-truth-format.md new file mode 100644 index 0000000..8898bcd --- /dev/null +++ b/docs/ground-truth-format.md @@ -0,0 +1,112 @@ +# Ground-truth export format (Phase 7) + +The service exposes the HITL corpus as [JSONL](https://jsonlines.org/) — +one training sample per line — via the HTTP endpoint +`GET /api/v1/ground-truth/export` and the equivalent CLI: + +```bash +python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl +``` + +Both paths read the same database and emit byte-identical JSONL, so a +cron-scheduled dump and an ad-hoc curl download are interchangeable. + +## Sample schema + +Each line is a single JSON object with the following shape: + +```jsonc +{ + "job_id": "c5da6747-...", + "filename": "sprint-042.pdf", + "source_kind": "pdf", + "approved": true, + "reviewed_by": "reviewer-a", // free-form; comes from X-User-Id + "reviewed_at": "2025-06-01T10:15:00Z", + "created_at": "2025-05-28T08:02:17Z", + + // The pipeline's original pre-HITL output, reconstructed by replaying + // the audit trail backwards. `null` for jobs that never produced a + // result (e.g. hard-failed on OCR). + "initial_result": { + "header": { "nomor_sprint": "Sprin/1/I/2025", "perihal": null, ... }, + "personel": [ { "pangkat": "AIPDA", "nrp": "77060000", ... } ], + ... + }, + + // The reviewer-approved answer (current value of jobs.result). + "final_result": { ...same shape as initial_result... }, + + // Every correction event, in chronological order. + "corrections": [ + { + "field_path": "header.perihal", + "old_value": null, + "new_value": "Penyelidikan kasus pencurian", + "corrected_by": "reviewer-a", + "reason": "LLM missed it", + "corrected_at": "2025-05-30T14:00:00Z" + } + ], + + "review_flags": ["llm_fallback"], + "confidence": 0.78 +} +``` + +## Recommended filters + +* `approved_only=true` (default) — **do not** train on unreviewed + samples; they can still contain OCR mistakes. +* `has_corrections=true` — for a "hard examples" set where the pipeline + was originally wrong. +* `has_corrections=false` — for a "sanity" set where the pipeline was + already right. Good for regression tests after fine-tuning. +* `since` / `until` — build incremental snapshots without re-processing + the full history. + +## When is the dataset big enough to fine-tune? + +Rough operational checklist (rules of thumb — adjust based on your own +error analysis): + +| Bucket | Minimum rows | Notes | +|---------------------------------|--------------|---------------------------------------------------------------| +| LoRA on header extraction (LLM) | ~200–500 | Per-field error signal must be > random noise. | +| Per-satuan prompt tuning | ~50 / satuan | Helps when formats differ sharply between Polda/Polres units. | +| PP-Structure table fine-tune | ~1 000+ | Layout models are data-hungry; hold off until HITL is steady. | + +Use `GET /api/v1/ground-truth/stats` to check coverage: + +```json +{ + "total_jobs": 842, + "approved_jobs": 613, + "total_corrections": 1 204, + "jobs_with_corrections": 431, + "top_corrected_fields": [ + { "field_path": "header.perihal", "count": 289 }, + { "field_path": "personel[0].nrp", "count": 51 }, + ... + ] +} +``` + +Fields at the top of `top_corrected_fields` are the highest-leverage +targets for prompt tweaks, regex upgrades, or (eventually) fine-tuning. + +## Fine-tuning outside this repo + +The export is deliberately framework-agnostic. Suggested follow-ups on +dedicated GPU hardware: + +* [**Unsloth**](https://github.com/unslothai/unsloth) — LoRA on + Qwen2.5 / Llama 3.1 with 2–4 × speedups on a single GPU. +* [**Axolotl**](https://github.com/axolotl-ai-cloud/axolotl) — more + batteries-included; good for multi-GPU runs. + +Typical prompt-completion conversion: feed `initial_result` (or the raw +OCR text, if your pipeline keeps it) as the "input" and `final_result` +as the "output". The `corrections` list is only needed if you want to +build an error-class analysis — the model itself trains on the final +answer. diff --git a/src/ocr_sprint/api/routes/ground_truth.py b/src/ocr_sprint/api/routes/ground_truth.py new file mode 100644 index 0000000..5f77fc3 --- /dev/null +++ b/src/ocr_sprint/api/routes/ground_truth.py @@ -0,0 +1,102 @@ +"""Ground-truth export + statistics endpoints (Phase 7). + +Two endpoints, both auth'd by the existing ``X-API-Key`` dependency: + +* ``GET /ground-truth/export`` — streams JSONL of approved (or filtered) + samples for downstream fine-tuning pipelines. +* ``GET /ground-truth/stats`` — returns aggregate counts + top-corrected + field paths so operators know when/where fine-tuning will pay off. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from datetime import datetime +from typing import Annotated + +from fastapi import APIRouter, Depends, Query +from fastapi.responses import StreamingResponse +from sqlalchemy.orm import Session + +from ocr_sprint.api.deps.auth import require_api_key +from ocr_sprint.api.deps.db import get_session +from ocr_sprint.ground_truth import ( + GroundTruthFilters, + ground_truth_stats, + iter_ground_truth_samples, + serialize_sample_to_jsonl, +) +from ocr_sprint.schemas.ground_truth import GroundTruthStats + +router = APIRouter( + prefix="/ground-truth", + tags=["ground-truth"], + dependencies=[Depends(require_api_key)], +) + + +@router.get( + "/export", + response_class=StreamingResponse, + responses={ + 200: { + "content": {"application/x-ndjson": {}}, + "description": "Newline-delimited JSON stream of training samples.", + } + }, +) +def export_ground_truth( + session: Annotated[Session, Depends(get_session)], + since: Annotated[ + datetime | None, + Query(description="Only include jobs created at or after this ISO timestamp."), + ] = None, + until: Annotated[ + datetime | None, + Query(description="Only include jobs created at or before this ISO timestamp."), + ] = None, + approved_only: Annotated[ + bool, + Query(description="Only export approved jobs (default true)."), + ] = True, + has_corrections: Annotated[ + bool | None, + Query( + description=( + "Optional: true = only jobs that had at least one correction, " + "false = only pristine (no-correction) jobs." + ) + ), + ] = None, + limit: Annotated[ + int | None, + Query(ge=1, le=100_000, description="Maximum rows to emit."), + ] = None, +) -> StreamingResponse: + filters = GroundTruthFilters( + since=since, + until=until, + approved_only=approved_only, + has_corrections=has_corrections, + limit=limit, + ) + + def _stream() -> Iterator[bytes]: + for sample in iter_ground_truth_samples(session, filters): + yield serialize_sample_to_jsonl(sample).encode("utf-8") + + return StreamingResponse(_stream(), media_type="application/x-ndjson") + + +@router.get( + "/stats", + response_model=GroundTruthStats, +) +def get_stats( + session: Annotated[Session, Depends(get_session)], + top_n: Annotated[ + int, + Query(ge=1, le=100, description="How many top-corrected field paths to return."), + ] = 10, +) -> GroundTruthStats: + return ground_truth_stats(session, top_n=top_n) diff --git a/src/ocr_sprint/ground_truth/__init__.py b/src/ocr_sprint/ground_truth/__init__.py new file mode 100644 index 0000000..688460a --- /dev/null +++ b/src/ocr_sprint/ground_truth/__init__.py @@ -0,0 +1,22 @@ +"""Phase 7 — ground-truth export service. + +Consumes the ``jobs`` + ``job_corrections`` tables to produce JSONL +training samples + aggregate statistics. No ML dependencies here: the +actual fine-tuning runs in a separate project on dedicated hardware. +""" + +from ocr_sprint.ground_truth.service import ( + GroundTruthFilters, + build_initial_result, + ground_truth_stats, + iter_ground_truth_samples, + serialize_sample_to_jsonl, +) + +__all__ = [ + "GroundTruthFilters", + "build_initial_result", + "ground_truth_stats", + "iter_ground_truth_samples", + "serialize_sample_to_jsonl", +] diff --git a/src/ocr_sprint/ground_truth/service.py b/src/ocr_sprint/ground_truth/service.py new file mode 100644 index 0000000..5d33eaf --- /dev/null +++ b/src/ocr_sprint/ground_truth/service.py @@ -0,0 +1,248 @@ +"""Core ground-truth export logic. + +The service reads from ``jobs`` and ``job_corrections`` and produces +``GroundTruthSample`` rows. The HTTP layer and the CLI both go through +this module, so the export logic stays in exactly one place. + +Design notes +------------ +* **Reverse-replay for ``initial_result``.** Audit rows store both + ``old_value`` and ``new_value``. We start from the current + ``final_result`` and undo each correction in reverse chronological + order to recover the pipeline's original pre-HITL output. This is + what fine-tuning needs: the raw mistake next to the reviewer fix. +* **No deep copying in the caller.** The service always returns freshly + computed dicts so mutating one sample can't leak into the next. +* **Generator by design.** Large exports shouldn't load the whole table + into memory; the HTTP route streams JSONL line-by-line. +""" + +from __future__ import annotations + +import copy +import re +from collections.abc import Iterator +from dataclasses import dataclass +from datetime import datetime +from typing import Any + +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from ocr_sprint.db.models import JobCorrectionRow, JobRow +from ocr_sprint.schemas.document import DocumentStatus +from ocr_sprint.schemas.ground_truth import ( + FieldCorrectionCount, + GroundTruthCorrection, + GroundTruthSample, + GroundTruthStats, +) + + +@dataclass(frozen=True) +class GroundTruthFilters: + """HTTP / CLI filters for ``iter_ground_truth_samples``. + + All fields are optional; ``approved_only=True`` is the default because + only approved samples are actually safe to train on. + """ + + since: datetime | None = None + until: datetime | None = None + approved_only: bool = True + has_corrections: bool | None = None + limit: int | None = None + + +def build_initial_result( + *, + final_result: dict[str, Any] | None, + corrections: list[JobCorrectionRow], +) -> dict[str, Any] | None: + """Replay ``corrections`` backwards over ``final_result`` to recover + the pre-HITL version. Returns ``None`` if there's nothing to + reconstruct. + """ + if final_result is None: + return None + restored = copy.deepcopy(final_result) + for event in reversed(corrections): + _assign_path(restored, event.field_path, event.old_value) + return restored + + +def _assign_path(data: dict[str, Any], path: str, value: Any) -> None: + """Set ``data[path] = value`` using the same dotted syntax the HITL + repository understands. Silently ignores paths that no longer resolve + (the dict shape may have drifted since the event was recorded; in + that case we just skip the replay step rather than raising). + """ + segments = _parse(path) + if not segments: + return + + cursor: Any = data + for name, idx in segments[:-1]: + if not isinstance(cursor, dict) or name not in cursor: + return + cursor = cursor[name] + if idx is not None: + if not isinstance(cursor, list) or idx >= len(cursor): + return + cursor = cursor[idx] + + name, idx = segments[-1] + if idx is not None: + if not isinstance(cursor, dict) or name not in cursor: + return + container = cursor[name] + if not isinstance(container, list) or idx >= len(container): + return + container[idx] = value + return + + if not isinstance(cursor, dict): + return + cursor[name] = value + + +_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$") + + +def _parse(path: str) -> list[tuple[str, int | None]]: + """Forgiving dotted-path parser — returns ``[]`` on malformed input + instead of raising, since audit-log replay must not blow up on a + stale path. The strict variant used by the HITL PATCH route lives + in ``ocr_sprint.db.repositories``. + """ + if not path or path.startswith(".") or path.endswith("."): + return [] + out: list[tuple[str, int | None]] = [] + for part in path.split("."): + match: re.Match[str] | None = _SEGMENT_RE.match(part) + if match is None: + return [] + out.append((match.group(1), int(match.group(2)) if match.group(2) else None)) + return out + + +def iter_ground_truth_samples( + session: Session, + filters: GroundTruthFilters, +) -> Iterator[GroundTruthSample]: + """Yield ``GroundTruthSample`` rows matching ``filters``. + + Rows are ordered by ``created_at ASC`` so repeated exports are + deterministic and downstream diffing works. + """ + stmt = select(JobRow).order_by(JobRow.created_at, JobRow.job_id) + if filters.approved_only: + stmt = stmt.where(JobRow.approved.is_(True)) + if filters.since is not None: + stmt = stmt.where(JobRow.created_at >= filters.since) + if filters.until is not None: + stmt = stmt.where(JobRow.created_at <= filters.until) + + remaining = filters.limit + for job_row in session.scalars(stmt): + if remaining is not None and remaining <= 0: + return + + correction_rows = list( + session.scalars( + select(JobCorrectionRow) + .where(JobCorrectionRow.job_id == job_row.job_id) + .order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id) + ) + ) + + if filters.has_corrections is True and not correction_rows: + continue + if filters.has_corrections is False and correction_rows: + continue + + initial = build_initial_result(final_result=job_row.result, corrections=correction_rows) + sample = GroundTruthSample( + job_id=job_row.job_id, + filename=job_row.filename, + source_kind=job_row.source_kind, + approved=bool(job_row.approved), + reviewed_by=job_row.reviewed_by, + reviewed_at=job_row.reviewed_at, + created_at=job_row.created_at, + initial_result=initial, + final_result=copy.deepcopy(job_row.result) if job_row.result else None, + corrections=[ + GroundTruthCorrection( + field_path=c.field_path, + old_value=c.old_value, + new_value=c.new_value, + corrected_by=c.corrected_by, + reason=c.reason, + corrected_at=c.corrected_at, + ) + for c in correction_rows + ], + review_flags=list(job_row.review_flags or []), + confidence=job_row.confidence, + ) + if remaining is not None: + remaining -= 1 + yield sample + + +def ground_truth_stats(session: Session, *, top_n: int = 10) -> GroundTruthStats: + """Compute aggregate statistics over the jobs + corrections tables.""" + by_status: dict[str, int] = { + row[0]: int(row[1]) + for row in session.execute( + select(JobRow.status, func.count(JobRow.job_id)).group_by(JobRow.status) + ).all() + } + total_jobs = int(sum(by_status.values())) + approved_jobs = int( + session.execute( + select(func.count(JobRow.job_id)).where(JobRow.approved.is_(True)) + ).scalar_one() + ) + + total_corrections = int(session.execute(select(func.count(JobCorrectionRow.id))).scalar_one()) + jobs_with_corrections = int( + session.execute(select(func.count(func.distinct(JobCorrectionRow.job_id)))).scalar_one() + ) + + # Field-path histogram. A single SQL ``GROUP BY`` is cheaper than + # pulling every row over the wire, but we still cap at ``top_n`` so + # a pathological dataset can't drown the response. + top_rows = session.execute( + select( + JobCorrectionRow.field_path, + func.count(JobCorrectionRow.id).label("n"), + ) + .group_by(JobCorrectionRow.field_path) + .order_by(func.count(JobCorrectionRow.id).desc(), JobCorrectionRow.field_path) + .limit(top_n) + ).all() + + return GroundTruthStats( + total_jobs=total_jobs, + completed_jobs=by_status.get(DocumentStatus.COMPLETED.value, 0), + needs_review_jobs=by_status.get(DocumentStatus.NEEDS_REVIEW.value, 0), + failed_jobs=by_status.get(DocumentStatus.FAILED.value, 0), + approved_jobs=approved_jobs, + total_corrections=total_corrections, + jobs_with_corrections=jobs_with_corrections, + top_corrected_fields=[ + FieldCorrectionCount(field_path=row[0], count=int(row[1])) for row in top_rows + ], + ) + + +def serialize_sample_to_jsonl(sample: GroundTruthSample) -> str: + """Serialize one sample as a newline-terminated JSON string. + + Kept standalone so the HTTP and CLI layers share the exact same + representation — diffs between a curl download and a CLI dump stay + byte-identical. + """ + return sample.model_dump_json() + "\n" diff --git a/src/ocr_sprint/main.py b/src/ocr_sprint/main.py index 05724cf..7218b62 100644 --- a/src/ocr_sprint/main.py +++ b/src/ocr_sprint/main.py @@ -7,7 +7,7 @@ from fastapi import FastAPI from ocr_sprint import __version__ from ocr_sprint.api.errors import register_error_handlers from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint -from ocr_sprint.api.routes import documents, health +from ocr_sprint.api.routes import documents, ground_truth, health from ocr_sprint.config import get_settings from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables) from ocr_sprint.db.base import Base, get_engine @@ -43,6 +43,7 @@ def create_app() -> FastAPI: app.add_middleware(MetricsMiddleware) app.include_router(health.router, prefix="/api/v1") app.include_router(documents.router, prefix="/api/v1") + app.include_router(ground_truth.router, prefix="/api/v1") app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False) return app diff --git a/src/ocr_sprint/schemas/ground_truth.py b/src/ocr_sprint/schemas/ground_truth.py new file mode 100644 index 0000000..01667bd --- /dev/null +++ b/src/ocr_sprint/schemas/ground_truth.py @@ -0,0 +1,76 @@ +"""Schemas for the Phase 7 ground-truth export. + +Each ``GroundTruthSample`` represents one training-ready example: + +* ``initial_*`` snapshots the pipeline's original (pre-HITL) output, + reconstructed by replaying the audit trail in reverse. +* ``final_*`` is the current ``result`` on the ``jobs`` row — the + reviewer-approved answer. +* ``corrections`` is the raw audit trail so downstream fine-tuning can + see *what* was changed, *why* (free-text reason), and by whom. + +JSONL is emitted — one sample per line — so the file can be mmapped, +streamed, or piped straight into an HF ``datasets.load_dataset("json", +...)`` call. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + + +class GroundTruthCorrection(BaseModel): + """One row of the ``job_corrections`` audit trail, as exported.""" + + field_path: str + old_value: Any | None = None + new_value: Any | None = None + corrected_by: str | None = None + reason: str | None = None + corrected_at: datetime + + +class GroundTruthSample(BaseModel): + """One training sample written as a single JSONL line.""" + + model_config = ConfigDict(populate_by_name=True) + + job_id: UUID + filename: str + source_kind: str + approved: bool = False + reviewed_by: str | None = None + reviewed_at: datetime | None = None + created_at: datetime + # ``initial_*`` is the pipeline's pre-HITL answer, reconstructed from + # the audit trail. ``final_*`` is the reviewer-approved version. + initial_result: dict[str, Any] | None = None + final_result: dict[str, Any] | None = None + corrections: list[GroundTruthCorrection] = Field(default_factory=list) + review_flags: list[str] = Field(default_factory=list) + confidence: float | None = None + + +class FieldCorrectionCount(BaseModel): + field_path: str + count: int + + +class GroundTruthStats(BaseModel): + """High-level dataset health report surfaced by ``GET /ground-truth/stats``.""" + + total_jobs: int + completed_jobs: int + needs_review_jobs: int + failed_jobs: int + approved_jobs: int + total_corrections: int + jobs_with_corrections: int + # Most-corrected field paths (descending). Operators use this to + # prioritise which fields to target with prompt tweaks or fine-tune + # data collection first. + top_corrected_fields: list[FieldCorrectionCount] = Field(default_factory=list) diff --git a/src/ocr_sprint/tools/__init__.py b/src/ocr_sprint/tools/__init__.py new file mode 100644 index 0000000..fbbc320 --- /dev/null +++ b/src/ocr_sprint/tools/__init__.py @@ -0,0 +1 @@ +"""Operator CLI utilities. Invoked via ``python -m ocr_sprint.tools.``.""" diff --git a/src/ocr_sprint/tools/export_ground_truth.py b/src/ocr_sprint/tools/export_ground_truth.py new file mode 100644 index 0000000..b89e470 --- /dev/null +++ b/src/ocr_sprint/tools/export_ground_truth.py @@ -0,0 +1,137 @@ +"""CLI: ``python -m ocr_sprint.tools.export_ground_truth``. + +Dumps the HITL ground-truth corpus to a local JSONL file. The HTTP +endpoint is fine for small queries, but weekly / monthly snapshots are +easier to schedule on the host via cron than via curl. + +Usage +----- + python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl + python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl \ + --since 2025-01-01 --has-corrections + +The database URL is read from the same ``DATABASE_URL`` env var the API +uses, so this script works out of the box on any box that can run the +service. +""" + +from __future__ import annotations + +import argparse +import sys +from datetime import datetime +from pathlib import Path +from typing import Protocol + +from ocr_sprint.db.base import session_scope +from ocr_sprint.ground_truth import ( + GroundTruthFilters, + ground_truth_stats, + iter_ground_truth_samples, + serialize_sample_to_jsonl, +) + + +def _parse_iso(raw: str | None) -> datetime | None: + if raw is None: + return None + # ``fromisoformat`` supports dates too, so ``--since 2025-01-01`` + # behaves like ``2025-01-01T00:00:00``. + return datetime.fromisoformat(raw) + + +def _build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser( + prog="ocr_sprint.tools.export_ground_truth", + description="Dump HITL ground-truth corpus to JSONL.", + ) + p.add_argument( + "--out", + required=True, + type=Path, + help="Output JSONL file. Use '-' to write to stdout.", + ) + p.add_argument("--since", type=str, default=None, help="ISO date/datetime lower bound.") + p.add_argument("--until", type=str, default=None, help="ISO date/datetime upper bound.") + p.add_argument( + "--include-unapproved", + action="store_true", + help=( + "Include jobs that haven't been approved yet. Not recommended " + "for training data — use only for debugging exports." + ), + ) + group = p.add_mutually_exclusive_group() + group.add_argument( + "--has-corrections", + action="store_true", + help="Only include jobs that had at least one correction.", + ) + group.add_argument( + "--no-corrections", + action="store_true", + help="Only include pristine jobs (no corrections). Useful for sanity sets.", + ) + p.add_argument("--limit", type=int, default=None, help="Maximum samples to emit.") + p.add_argument( + "--print-stats", + action="store_true", + help="Also print dataset summary to stderr after the dump.", + ) + return p + + +def main(argv: list[str] | None = None) -> int: + args = _build_parser().parse_args(argv) + + has_corrections: bool | None + if args.has_corrections: + has_corrections = True + elif args.no_corrections: + has_corrections = False + else: + has_corrections = None + + filters = GroundTruthFilters( + since=_parse_iso(args.since), + until=_parse_iso(args.until), + approved_only=not args.include_unapproved, + has_corrections=has_corrections, + limit=args.limit, + ) + + out: Path = args.out + count = 0 + if str(out) == "-": + _write_stream(sys.stdout.buffer, filters) + else: + out.parent.mkdir(parents=True, exist_ok=True) + with out.open("wb") as fh: + count = _write_stream(fh, filters) + + print(f"wrote {count} sample(s) to {out}", file=sys.stderr) + + if args.print_stats: + with session_scope() as session: + stats = ground_truth_stats(session) + print(stats.model_dump_json(indent=2), file=sys.stderr) + + return 0 + + +class _BinaryWriter(Protocol): + def write(self, data: bytes) -> int: ... + + +def _write_stream(fh: _BinaryWriter, filters: GroundTruthFilters) -> int: + """Write JSONL to ``fh``. Returns the number of samples written.""" + count = 0 + with session_scope() as session: + for sample in iter_ground_truth_samples(session, filters): + fh.write(serialize_sample_to_jsonl(sample).encode("utf-8")) + count += 1 + return count + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/unit/test_api_ground_truth.py b/tests/unit/test_api_ground_truth.py new file mode 100644 index 0000000..4168efd --- /dev/null +++ b/tests/unit/test_api_ground_truth.py @@ -0,0 +1,158 @@ +"""HTTP tests for the ground-truth export endpoints.""" + +from __future__ import annotations + +import json +from datetime import date + +import pytest +from fastapi.testclient import TestClient + +from ocr_sprint.main import create_app +from ocr_sprint.pipeline import orchestrator as orch_module +from ocr_sprint.pipeline.orchestrator import PipelineOutput +from ocr_sprint.schemas.document import DocumentStatus, SourceKind +from ocr_sprint.schemas.extraction import ( + ExtractionResult, + HeaderFields, + PersonnelEntry, +) + + +@pytest.fixture +def client() -> TestClient: + return TestClient(create_app()) + + +@pytest.fixture +def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput: + result = ExtractionResult( + header=HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="POLRES TEST", + ), + personel=[ + PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"), + ], + confidence=0.9, + ) + output = PipelineOutput( + source_kind=SourceKind.PDF, + status=DocumentStatus.COMPLETED, + confidence=0.9, + result=result, + ) + + def _fake_run(_content: bytes) -> PipelineOutput: + return output + + monkeypatch.setattr(orch_module, "run_pipeline", _fake_run) + from ocr_sprint.api.routes import documents as docs_module + + monkeypatch.setattr(docs_module, "run_pipeline", _fake_run) + from ocr_sprint.worker import tasks as tasks_module + + monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run) + return output + + +def _create_and_approve(client: TestClient, *, correction_value: str | None = None) -> str: + post = client.post( + "/api/v1/documents?sync=true", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 200, post.text + jid = str(post.json()["job_id"]) + if correction_value is not None: + patched = client.patch( + f"/api/v1/documents/{jid}", + json={"corrections": [{"path": "header.perihal", "value": correction_value}]}, + ) + assert patched.status_code == 200 + approved = client.post(f"/api/v1/documents/{jid}/approve") + assert approved.status_code == 200 + return jid + + +def test_stats_empty_dataset(client: TestClient) -> None: + resp = client.get("/api/v1/ground-truth/stats") + assert resp.status_code == 200 + body = resp.json() + assert body["total_jobs"] == 0 + assert body["approved_jobs"] == 0 + assert body["total_corrections"] == 0 + assert body["top_corrected_fields"] == [] + + +def test_stats_rolls_up_counts(client: TestClient, fake_pipeline: PipelineOutput) -> None: + _create_and_approve(client, correction_value="Penyelidikan-1") + _create_and_approve(client, correction_value="Penyelidikan-2") + _create_and_approve(client, correction_value=None) # pristine + + resp = client.get("/api/v1/ground-truth/stats") + assert resp.status_code == 200 + body = resp.json() + assert body["total_jobs"] == 3 + assert body["approved_jobs"] == 3 + assert body["total_corrections"] == 2 + assert body["jobs_with_corrections"] == 2 + assert body["top_corrected_fields"][0]["field_path"] == "header.perihal" + assert body["top_corrected_fields"][0]["count"] == 2 + + +def test_export_streams_jsonl(client: TestClient, fake_pipeline: PipelineOutput) -> None: + _create_and_approve(client, correction_value="Penyelidikan") + _create_and_approve(client, correction_value=None) + + resp = client.get("/api/v1/ground-truth/export") + assert resp.status_code == 200 + assert resp.headers["content-type"].startswith("application/x-ndjson") + lines = [line for line in resp.text.splitlines() if line.strip()] + assert len(lines) == 2 + parsed = [json.loads(line) for line in lines] + for sample in parsed: + assert sample["approved"] is True + assert "initial_result" in sample + assert "final_result" in sample + + +def test_export_approved_only_default(client: TestClient, fake_pipeline: PipelineOutput) -> None: + """Unapproved jobs shouldn't appear in the default export.""" + # One approved, one just completed (no approve call). + _create_and_approve(client, correction_value=None) + client.post( + "/api/v1/documents?sync=true", + files={"file": ("y.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + resp = client.get("/api/v1/ground-truth/export") + lines = [line for line in resp.text.splitlines() if line.strip()] + assert len(lines) == 1 + + # Toggle approved_only=false to include both. + resp = client.get("/api/v1/ground-truth/export?approved_only=false") + lines = [line for line in resp.text.splitlines() if line.strip()] + assert len(lines) == 2 + + +def test_export_has_corrections_filter(client: TestClient, fake_pipeline: PipelineOutput) -> None: + _create_and_approve(client, correction_value="Penyelidikan") + _create_and_approve(client, correction_value=None) + + resp = client.get("/api/v1/ground-truth/export?has_corrections=true") + lines = [line for line in resp.text.splitlines() if line.strip()] + assert len(lines) == 1 + assert json.loads(lines[0])["corrections"][0]["new_value"] == "Penyelidikan" + + resp = client.get("/api/v1/ground-truth/export?has_corrections=false") + lines = [line for line in resp.text.splitlines() if line.strip()] + assert len(lines) == 1 + assert json.loads(lines[0])["corrections"] == [] + + +def test_export_respects_limit(client: TestClient, fake_pipeline: PipelineOutput) -> None: + for _ in range(5): + _create_and_approve(client, correction_value=None) + resp = client.get("/api/v1/ground-truth/export?limit=2") + lines = [line for line in resp.text.splitlines() if line.strip()] + assert len(lines) == 2 diff --git a/tests/unit/test_cli_export_ground_truth.py b/tests/unit/test_cli_export_ground_truth.py new file mode 100644 index 0000000..1cece54 --- /dev/null +++ b/tests/unit/test_cli_export_ground_truth.py @@ -0,0 +1,80 @@ +"""Smoke tests for the ``ocr_sprint.tools.export_ground_truth`` CLI.""" + +from __future__ import annotations + +import json +from pathlib import Path +from uuid import uuid4 + +import pytest + +from ocr_sprint.db.base import Base, get_engine, session_scope +from ocr_sprint.db.repositories import JobRepository +from ocr_sprint.schemas.document import DocumentStatus, SourceKind +from ocr_sprint.tools.export_ground_truth import main as export_main + + +@pytest.fixture +def db_ready() -> None: + Base.metadata.create_all(bind=get_engine()) + + +def _seed_two_approved_jobs() -> None: + for _ in range(2): + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k" + ) + with session_scope() as session: + JobRepository(session).mark_completed( + jid, + status=DocumentStatus.COMPLETED, + confidence=0.9, + result={"header": {"nomor_sprint": "SPR/1/2025"}}, + review_flags=[], + ) + with session_scope() as session: + JobRepository(session).approve(jid, reviewed_by="rev") + + +def test_cli_writes_expected_number_of_lines( + db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str] +) -> None: + _seed_two_approved_jobs() + out = tmp_path / "corpus.jsonl" + + exit_code = export_main(["--out", str(out)]) + assert exit_code == 0 + lines = [line for line in out.read_text().splitlines() if line.strip()] + assert len(lines) == 2 + for line in lines: + parsed = json.loads(line) + assert parsed["approved"] is True + + stderr = capsys.readouterr().err + assert "wrote 2 sample(s)" in stderr + + +def test_cli_respects_limit(db_ready: None, tmp_path: Path) -> None: + _seed_two_approved_jobs() + out = tmp_path / "corpus.jsonl" + exit_code = export_main(["--out", str(out), "--limit", "1"]) + assert exit_code == 0 + lines = [line for line in out.read_text().splitlines() if line.strip()] + assert len(lines) == 1 + + +def test_cli_print_stats_emits_json_to_stderr( + db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str] +) -> None: + _seed_two_approved_jobs() + out = tmp_path / "corpus.jsonl" + exit_code = export_main(["--out", str(out), "--print-stats"]) + assert exit_code == 0 + stderr = capsys.readouterr().err + # Validate the JSON prologue (after the "wrote N" line). + json_start = stderr.index("{") + stats = json.loads(stderr[json_start:]) + assert stats["total_jobs"] == 2 + assert stats["approved_jobs"] == 2 diff --git a/tests/unit/test_ground_truth_service.py b/tests/unit/test_ground_truth_service.py new file mode 100644 index 0000000..694390f --- /dev/null +++ b/tests/unit/test_ground_truth_service.py @@ -0,0 +1,210 @@ +"""Unit tests for the ground-truth export service.""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from uuid import UUID, uuid4 + +import pytest + +from ocr_sprint.db.base import Base, get_engine, session_scope +from ocr_sprint.db.repositories import JobRepository +from ocr_sprint.ground_truth import ( + GroundTruthFilters, + build_initial_result, + ground_truth_stats, + iter_ground_truth_samples, + serialize_sample_to_jsonl, +) +from ocr_sprint.schemas.document import DocumentStatus, SourceKind + + +@pytest.fixture +def db_ready() -> None: + Base.metadata.create_all(bind=get_engine()) + + +def _seed_approved_job_with_corrections( + *, + final_result: dict[str, object] | None = None, + corrections: list[tuple[str, object, object]] | None = None, + approved: bool = True, +) -> UUID: + jid = uuid4() + with session_scope() as session: + repo = JobRepository(session) + repo.create(job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k") + with session_scope() as session: + JobRepository(session).mark_completed( + jid, + status=DocumentStatus.NEEDS_REVIEW, + confidence=0.8, + result=final_result + or { + "header": {"nomor_sprint": "SPR/1/2025", "satuan_penerbit": "POLRES X"}, + "personel": [{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}], + }, + review_flags=[], + ) + if corrections: + # Corrections are tuples of (path, value, old_value_for_audit). + # We translate them into real PATCH calls so the audit rows look + # exactly like they would in production. + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[(p, new, None) for (p, new, _old) in corrections], + corrected_by="reviewer-a", + ) + if approved: + with session_scope() as session: + JobRepository(session).approve(jid, reviewed_by="reviewer-a") + return jid + + +def test_build_initial_result_reverses_corrections() -> None: + """With known old/new pairs the replay must return the pre-HITL dict.""" + final = { + "header": {"nomor_sprint": "SPR/1/2025", "perihal": "Penyelidikan"}, + } + + class _Fake: + def __init__(self, path: str, old: object, new: object) -> None: + self.field_path = path + self.old_value = old + self.new_value = new + + corrections = [ + _Fake("header.perihal", None, "Penyelidikan"), + _Fake("header.nomor_sprint", "SPR/OLD/2025", "SPR/1/2025"), + ] + restored = build_initial_result(final_result=final, corrections=corrections) # type: ignore[arg-type] + assert restored == { + "header": {"nomor_sprint": "SPR/OLD/2025", "perihal": None}, + } + # ``final`` must not have been mutated. + assert final["header"]["nomor_sprint"] == "SPR/1/2025" + + +def test_build_initial_result_ignores_unresolvable_paths() -> None: + """A legacy path that no longer resolves should be skipped, not raise.""" + final = {"header": {"nomor_sprint": "SPR/1"}} + + class _Fake: + def __init__(self, path: str, old: object, new: object) -> None: + self.field_path = path + self.old_value = old + self.new_value = new + + corrections = [ + _Fake("personel[99].nrp", "stale", "new"), # container doesn't exist + _Fake("bogus..path", "x", "y"), + ] + restored = build_initial_result(final_result=final, corrections=corrections) # type: ignore[arg-type] + assert restored == final + + +def test_iter_samples_reconstructs_initial_and_final(db_ready: None) -> None: + jid = _seed_approved_job_with_corrections( + corrections=[("header.perihal", "Penyelidikan", None)], + ) + with session_scope() as session: + samples = list(iter_ground_truth_samples(session, GroundTruthFilters())) + assert len(samples) == 1 + s = samples[0] + assert s.job_id == jid + assert s.approved is True + assert s.reviewed_by == "reviewer-a" + assert s.final_result is not None + assert s.final_result["header"]["perihal"] == "Penyelidikan" + # Initial reconstruction undoes the correction. + assert s.initial_result is not None + assert s.initial_result["header"].get("perihal") is None + assert len(s.corrections) == 1 + assert s.corrections[0].field_path == "header.perihal" + + +def test_iter_samples_respects_approved_only_default(db_ready: None) -> None: + """Unapproved jobs must be excluded unless approved_only=False.""" + _seed_approved_job_with_corrections(approved=False) + with session_scope() as session: + samples = list(iter_ground_truth_samples(session, GroundTruthFilters())) + assert samples == [] + with session_scope() as session: + samples = list(iter_ground_truth_samples(session, GroundTruthFilters(approved_only=False))) + assert len(samples) == 1 + + +def test_iter_samples_respects_has_corrections_filter(db_ready: None) -> None: + _seed_approved_job_with_corrections(corrections=None) # pristine + _seed_approved_job_with_corrections(corrections=[("header.perihal", "fill", None)]) + with session_scope() as session: + with_ = list(iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=True))) + without = list( + iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=False)) + ) + assert len(with_) == 1 + assert len(without) == 1 + assert with_[0].job_id != without[0].job_id + + +def test_iter_samples_respects_since_until_and_limit(db_ready: None) -> None: + _seed_approved_job_with_corrections() + _seed_approved_job_with_corrections() + _seed_approved_job_with_corrections() + + # Since far in the future → empty. + future = datetime(2999, 1, 1, tzinfo=timezone.utc) + with session_scope() as session: + out = list(iter_ground_truth_samples(session, GroundTruthFilters(since=future))) + assert out == [] + + # Limit caps the number emitted. + with session_scope() as session: + out = list(iter_ground_truth_samples(session, GroundTruthFilters(limit=2))) + assert len(out) == 2 + + +def test_stats_counts_rollup_and_top_fields(db_ready: None) -> None: + # 1 approved job with 2 corrections on ``header.perihal``. + _seed_approved_job_with_corrections( + corrections=[ + ("header.perihal", "first", None), + ] + ) + jid = _seed_approved_job_with_corrections( + corrections=[ + ("header.perihal", "second", None), + ("personel[0].nrp", "77060001", None), + ], + ) + assert jid # silence unused + + with session_scope() as session: + stats = ground_truth_stats(session) + + assert stats.total_jobs == 2 + assert stats.approved_jobs == 2 + assert stats.total_corrections == 3 + assert stats.jobs_with_corrections == 2 + # ``header.perihal`` is the most-corrected field (2 > 1). + assert stats.top_corrected_fields[0].field_path == "header.perihal" + assert stats.top_corrected_fields[0].count == 2 + assert {f.field_path for f in stats.top_corrected_fields} == { + "header.perihal", + "personel[0].nrp", + } + + +def test_serialize_is_valid_jsonl(db_ready: None) -> None: + _seed_approved_job_with_corrections(corrections=[("header.perihal", "X", None)]) + with session_scope() as session: + for sample in iter_ground_truth_samples(session, GroundTruthFilters()): + line = serialize_sample_to_jsonl(sample) + assert line.endswith("\n") + # Strip trailing newline and make sure the result is a single + # JSON object (not a list — JSONL must be one object per line). + parsed = json.loads(line) + assert isinstance(parsed, dict) + assert parsed["approved"] is True