Merge pull request #6 from Adriankf59/devin/1777148426-phase-7-ground-truth
Phase 7: ground-truth export (JSONL + stats) + CLI tool
This commit is contained in:
112
docs/ground-truth-format.md
Normal file
112
docs/ground-truth-format.md
Normal file
@@ -0,0 +1,112 @@
|
||||
# Ground-truth export format (Phase 7)
|
||||
|
||||
The service exposes the HITL corpus as [JSONL](https://jsonlines.org/) —
|
||||
one training sample per line — via the HTTP endpoint
|
||||
`GET /api/v1/ground-truth/export` and the equivalent CLI:
|
||||
|
||||
```bash
|
||||
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl
|
||||
```
|
||||
|
||||
Both paths read the same database and emit byte-identical JSONL, so a
|
||||
cron-scheduled dump and an ad-hoc curl download are interchangeable.
|
||||
|
||||
## Sample schema
|
||||
|
||||
Each line is a single JSON object with the following shape:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"job_id": "c5da6747-...",
|
||||
"filename": "sprint-042.pdf",
|
||||
"source_kind": "pdf",
|
||||
"approved": true,
|
||||
"reviewed_by": "reviewer-a", // free-form; comes from X-User-Id
|
||||
"reviewed_at": "2025-06-01T10:15:00Z",
|
||||
"created_at": "2025-05-28T08:02:17Z",
|
||||
|
||||
// The pipeline's original pre-HITL output, reconstructed by replaying
|
||||
// the audit trail backwards. `null` for jobs that never produced a
|
||||
// result (e.g. hard-failed on OCR).
|
||||
"initial_result": {
|
||||
"header": { "nomor_sprint": "Sprin/1/I/2025", "perihal": null, ... },
|
||||
"personel": [ { "pangkat": "AIPDA", "nrp": "77060000", ... } ],
|
||||
...
|
||||
},
|
||||
|
||||
// The reviewer-approved answer (current value of jobs.result).
|
||||
"final_result": { ...same shape as initial_result... },
|
||||
|
||||
// Every correction event, in chronological order.
|
||||
"corrections": [
|
||||
{
|
||||
"field_path": "header.perihal",
|
||||
"old_value": null,
|
||||
"new_value": "Penyelidikan kasus pencurian",
|
||||
"corrected_by": "reviewer-a",
|
||||
"reason": "LLM missed it",
|
||||
"corrected_at": "2025-05-30T14:00:00Z"
|
||||
}
|
||||
],
|
||||
|
||||
"review_flags": ["llm_fallback"],
|
||||
"confidence": 0.78
|
||||
}
|
||||
```
|
||||
|
||||
## Recommended filters
|
||||
|
||||
* `approved_only=true` (default) — **do not** train on unreviewed
|
||||
samples; they can still contain OCR mistakes.
|
||||
* `has_corrections=true` — for a "hard examples" set where the pipeline
|
||||
was originally wrong.
|
||||
* `has_corrections=false` — for a "sanity" set where the pipeline was
|
||||
already right. Good for regression tests after fine-tuning.
|
||||
* `since` / `until` — build incremental snapshots without re-processing
|
||||
the full history.
|
||||
|
||||
## When is the dataset big enough to fine-tune?
|
||||
|
||||
Rough operational checklist (rules of thumb — adjust based on your own
|
||||
error analysis):
|
||||
|
||||
| Bucket | Minimum rows | Notes |
|
||||
|---------------------------------|--------------|---------------------------------------------------------------|
|
||||
| LoRA on header extraction (LLM) | ~200–500 | Per-field error signal must be > random noise. |
|
||||
| Per-satuan prompt tuning | ~50 / satuan | Helps when formats differ sharply between Polda/Polres units. |
|
||||
| PP-Structure table fine-tune | ~1 000+ | Layout models are data-hungry; hold off until HITL is steady. |
|
||||
|
||||
Use `GET /api/v1/ground-truth/stats` to check coverage:
|
||||
|
||||
```json
|
||||
{
|
||||
"total_jobs": 842,
|
||||
"approved_jobs": 613,
|
||||
"total_corrections": 1204,
|
||||
"jobs_with_corrections": 431,
|
||||
"top_corrected_fields": [
|
||||
{ "field_path": "header.perihal", "count": 289 },
|
||||
{ "field_path": "personel[0].nrp", "count": 51 },
|
||||
...
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
Fields at the top of `top_corrected_fields` are the highest-leverage
|
||||
targets for prompt tweaks, regex upgrades, or (eventually) fine-tuning.
|
||||
|
||||
## Fine-tuning outside this repo
|
||||
|
||||
The export is deliberately framework-agnostic. Suggested follow-ups on
|
||||
dedicated GPU hardware:
|
||||
|
||||
* [**Unsloth**](https://github.com/unslothai/unsloth) — LoRA on
|
||||
Qwen2.5 / Llama 3.1 with 2–4× speedups on a single GPU.
|
||||
* [**Axolotl**](https://github.com/axolotl-ai-cloud/axolotl) — more
|
||||
batteries-included; good for multi-GPU runs.
|
||||
|
||||
Typical prompt-completion conversion: feed `initial_result` (or the raw
|
||||
OCR text, if your pipeline keeps it) as the "input" and `final_result`
|
||||
as the "output". The `corrections` list is only needed if you want to
|
||||
build an error-class analysis — the model itself trains on the final
|
||||
answer.
|
||||
102
src/ocr_sprint/api/routes/ground_truth.py
Normal file
102
src/ocr_sprint/api/routes/ground_truth.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Ground-truth export + statistics endpoints (Phase 7).
|
||||
|
||||
Two endpoints, both auth'd by the existing ``X-API-Key`` dependency:
|
||||
|
||||
* ``GET /ground-truth/export`` — streams JSONL of approved (or filtered)
|
||||
samples for downstream fine-tuning pipelines.
|
||||
* ``GET /ground-truth/stats`` — returns aggregate counts + top-corrected
|
||||
field paths so operators know when/where fine-tuning will pay off.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.api.deps.auth import require_api_key
|
||||
from ocr_sprint.api.deps.db import get_session
|
||||
from ocr_sprint.ground_truth import (
|
||||
GroundTruthFilters,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
from ocr_sprint.schemas.ground_truth import GroundTruthStats
|
||||
|
||||
# Router for the Phase 7 ground-truth endpoints. Every route registered on
# it is served under the ``/ground-truth`` prefix, tagged ``ground-truth``,
# and runs the ``require_api_key`` dependency before the handler.
router = APIRouter(
    prefix="/ground-truth",
    tags=["ground-truth"],
    dependencies=[Depends(require_api_key)],
)
|
||||
|
||||
|
||||
@router.get(
    "/export",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"application/x-ndjson": {}},
            "description": "Newline-delimited JSON stream of training samples.",
        }
    },
)
def export_ground_truth(
    session: Annotated[Session, Depends(get_session)],
    since: Annotated[
        datetime | None,
        Query(description="Only include jobs created at or after this ISO timestamp."),
    ] = None,
    until: Annotated[
        datetime | None,
        Query(description="Only include jobs created at or before this ISO timestamp."),
    ] = None,
    approved_only: Annotated[
        bool,
        Query(description="Only export approved jobs (default true)."),
    ] = True,
    has_corrections: Annotated[
        bool | None,
        Query(
            description=(
                "Optional: true = only jobs that had at least one correction, "
                "false = only pristine (no-correction) jobs."
            )
        ),
    ] = None,
    limit: Annotated[
        int | None,
        Query(ge=1, le=100_000, description="Maximum rows to emit."),
    ] = None,
) -> StreamingResponse:
    """Stream the ground-truth corpus as newline-delimited JSON.

    The query parameters are packed into a ``GroundTruthFilters`` value and
    handed to ``iter_ground_truth_samples``; each yielded sample is
    serialized with ``serialize_sample_to_jsonl`` and UTF-8 encoded.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson``.
    """
    filters = GroundTruthFilters(
        since=since,
        until=until,
        approved_only=approved_only,
        has_corrections=has_corrections,
        limit=limit,
    )

    # Lazy pipeline: nothing is read from the database until the response
    # body is iterated.
    # NOTE(review): that means ``session`` must still be usable while the
    # body is being sent — confirm ``get_session``'s teardown timing.
    encoded_lines = (
        serialize_sample_to_jsonl(sample).encode("utf-8")
        for sample in iter_ground_truth_samples(session, filters)
    )
    return StreamingResponse(encoded_lines, media_type="application/x-ndjson")
|
||||
|
||||
|
||||
@router.get(
    "/stats",
    response_model=GroundTruthStats,
)
def get_stats(
    session: Annotated[Session, Depends(get_session)],
    top_n: Annotated[
        int,
        Query(ge=1, le=100, description="How many top-corrected field paths to return."),
    ] = 10,
) -> GroundTruthStats:
    """Return aggregate dataset statistics.

    Delegates straight to ``ground_truth_stats``; ``top_n`` (validated to
    1..100, default 10) is forwarded unchanged.
    """
    return ground_truth_stats(session, top_n=top_n)
|
||||
22
src/ocr_sprint/ground_truth/__init__.py
Normal file
22
src/ocr_sprint/ground_truth/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Phase 7 — ground-truth export service.
|
||||
|
||||
Consumes the ``jobs`` + ``job_corrections`` tables to produce JSONL
|
||||
training samples + aggregate statistics. No ML dependencies here: the
|
||||
actual fine-tuning runs in a separate project on dedicated hardware.
|
||||
"""
|
||||
|
||||
from ocr_sprint.ground_truth.service import (
|
||||
GroundTruthFilters,
|
||||
build_initial_result,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GroundTruthFilters",
|
||||
"build_initial_result",
|
||||
"ground_truth_stats",
|
||||
"iter_ground_truth_samples",
|
||||
"serialize_sample_to_jsonl",
|
||||
]
|
||||
248
src/ocr_sprint/ground_truth/service.py
Normal file
248
src/ocr_sprint/ground_truth/service.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Core ground-truth export logic.
|
||||
|
||||
The service reads from ``jobs`` and ``job_corrections`` and produces
|
||||
``GroundTruthSample`` rows. The HTTP layer and the CLI both go through
|
||||
this module, so the export logic stays in exactly one place.
|
||||
|
||||
Design notes
|
||||
------------
|
||||
* **Reverse-replay for ``initial_result``.** Audit rows store both
|
||||
``old_value`` and ``new_value``. We start from the current
|
||||
``final_result`` and undo each correction in reverse chronological
|
||||
order to recover the pipeline's original pre-HITL output. This is
|
||||
what fine-tuning needs: the raw mistake next to the reviewer fix.
|
||||
* **No deep copying in the caller.** The service always returns freshly
|
||||
computed dicts so mutating one sample can't leak into the next.
|
||||
* **Generator by design.** Large exports shouldn't load the whole table
|
||||
into memory; the HTTP route streams JSONL line-by-line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.db.models import JobCorrectionRow, JobRow
|
||||
from ocr_sprint.schemas.document import DocumentStatus
|
||||
from ocr_sprint.schemas.ground_truth import (
|
||||
FieldCorrectionCount,
|
||||
GroundTruthCorrection,
|
||||
GroundTruthSample,
|
||||
GroundTruthStats,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GroundTruthFilters:
    """HTTP / CLI filters for ``iter_ground_truth_samples``.

    All fields are optional; ``approved_only=True`` is the default because
    only approved samples are actually safe to train on. Instances are
    immutable (``frozen=True``).
    """

    # Inclusive lower bound on the job's ``created_at``; ``None`` = unbounded.
    since: datetime | None = None
    # Inclusive upper bound on the job's ``created_at``; ``None`` = unbounded.
    until: datetime | None = None
    # When true, only jobs whose ``approved`` column is true are selected.
    approved_only: bool = True
    # Tri-state: True = only jobs with corrections, False = only jobs
    # without, None = no filtering on corrections.
    has_corrections: bool | None = None
    # Maximum number of samples to yield; ``None`` = no cap.
    limit: int | None = None
|
||||
|
||||
|
||||
def build_initial_result(
    *,
    final_result: dict[str, Any] | None,
    corrections: list[JobCorrectionRow],
) -> dict[str, Any] | None:
    """Reconstruct the pre-HITL result by undoing ``corrections``.

    Works on a deep copy of ``final_result`` and walks ``corrections`` from
    last to first, writing each event's ``old_value`` back at its
    ``field_path``. The input dict is never mutated.

    Returns:
        The restored dict, or ``None`` when ``final_result`` is ``None``.
    """
    if final_result is None:
        return None

    snapshot = copy.deepcopy(final_result)
    # Newest correction first, so the oldest ``old_value`` wins in the end.
    for correction in corrections[::-1]:
        _assign_path(snapshot, correction.field_path, correction.old_value)
    return snapshot
|
||||
|
||||
|
||||
def _assign_path(data: dict[str, Any], path: str, value: Any) -> None:
    """Write ``value`` at the dotted ``path`` inside ``data``, in place.

    ``path`` is parsed with ``_parse``; each segment is a dict key with an
    optional ``[index]`` into a list stored under that key. Nothing is
    written (and nothing is raised) when the path is malformed, an
    intermediate key is missing, a container has the wrong type, or a list
    index is out of range. A plain final key is set even if it did not
    exist before; an indexed final segment requires the list to exist.
    """
    segments = _parse(path)
    if not segments:
        return

    *parents, (leaf_key, leaf_index) = segments

    # Descend through every segment except the last one.
    node: Any = data
    for key, position in parents:
        if not (isinstance(node, dict) and key in node):
            return
        node = node[key]
        if position is None:
            continue
        if not (isinstance(node, list) and position < len(node)):
            return
        node = node[position]

    # Both leaf forms need a dict to hold the final key.
    if not isinstance(node, dict):
        return

    if leaf_index is None:
        node[leaf_key] = value
        return

    if leaf_key not in node:
        return
    target = node[leaf_key]
    if isinstance(target, list) and leaf_index < len(target):
        target[leaf_index] = value
|
||||
|
||||
|
||||
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
|
||||
|
||||
|
||||
def _parse(path: str) -> list[tuple[str, int | None]]:
|
||||
"""Forgiving dotted-path parser — returns ``[]`` on malformed input
|
||||
instead of raising, since audit-log replay must not blow up on a
|
||||
stale path. The strict variant used by the HITL PATCH route lives
|
||||
in ``ocr_sprint.db.repositories``.
|
||||
"""
|
||||
if not path or path.startswith(".") or path.endswith("."):
|
||||
return []
|
||||
out: list[tuple[str, int | None]] = []
|
||||
for part in path.split("."):
|
||||
match: re.Match[str] | None = _SEGMENT_RE.match(part)
|
||||
if match is None:
|
||||
return []
|
||||
out.append((match.group(1), int(match.group(2)) if match.group(2) else None))
|
||||
return out
|
||||
|
||||
|
||||
def iter_ground_truth_samples(
    session: Session,
    filters: GroundTruthFilters,
) -> Iterator[GroundTruthSample]:
    """Yield ``GroundTruthSample`` rows matching ``filters``.

    Jobs are selected ordered by ``created_at`` then ``job_id``, so repeated
    exports are deterministic and downstream diffing works. For each job
    the correction rows are loaded ordered by ``corrected_at`` then ``id``
    and used both for the exported audit trail and to rebuild
    ``initial_result``.

    Args:
        session: Open SQLAlchemy session used for all queries.
        filters: Date range, approval, correction, and limit filters.

    Yields:
        One freshly built sample per matching job, at most ``filters.limit``
        samples when a limit is set (a limit of zero or less yields nothing).
    """
    limit = filters.limit
    if limit is not None and limit <= 0:
        return

    stmt = select(JobRow).order_by(JobRow.created_at, JobRow.job_id)
    if filters.approved_only:
        stmt = stmt.where(JobRow.approved.is_(True))
    if filters.since is not None:
        stmt = stmt.where(JobRow.created_at >= filters.since)
    if filters.until is not None:
        stmt = stmt.where(JobRow.created_at <= filters.until)

    emitted = 0
    for job_row in session.scalars(stmt):
        # NOTE(review): one extra query per job (N+1). Fine for modest
        # corpora; consider a joined/batched load if exports get slow.
        correction_rows = list(
            session.scalars(
                select(JobCorrectionRow)
                .where(JobCorrectionRow.job_id == job_row.job_id)
                .order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
            )
        )

        # The correction filter runs here in Python, which is why the limit
        # is counted in emitted samples rather than applied in SQL.
        if filters.has_corrections is True and not correction_rows:
            continue
        if filters.has_corrections is False and correction_rows:
            continue

        initial = build_initial_result(final_result=job_row.result, corrections=correction_rows)
        # Use the same ``is None`` test as ``build_initial_result`` so an
        # empty-dict result exports as ``{}`` on both sides, not ``{}``
        # versus ``null``.
        final = copy.deepcopy(job_row.result) if job_row.result is not None else None

        yield GroundTruthSample(
            job_id=job_row.job_id,
            filename=job_row.filename,
            source_kind=job_row.source_kind,
            approved=bool(job_row.approved),
            reviewed_by=job_row.reviewed_by,
            reviewed_at=job_row.reviewed_at,
            created_at=job_row.created_at,
            initial_result=initial,
            final_result=final,
            corrections=[
                GroundTruthCorrection(
                    field_path=c.field_path,
                    old_value=c.old_value,
                    new_value=c.new_value,
                    corrected_by=c.corrected_by,
                    reason=c.reason,
                    corrected_at=c.corrected_at,
                )
                for c in correction_rows
            ],
            review_flags=list(job_row.review_flags or []),
            confidence=job_row.confidence,
        )

        emitted += 1
        # Stop right after the last allowed sample instead of fetching one
        # more job row just to discover the limit was hit.
        if limit is not None and emitted >= limit:
            return
|
||||
|
||||
|
||||
def ground_truth_stats(session: Session, *, top_n: int = 10) -> GroundTruthStats:
    """Compute aggregate statistics over the jobs + corrections tables.

    Runs five aggregate queries: job counts grouped by status, approved-job
    count, total corrections, distinct jobs with corrections, and the
    ``top_n`` most-corrected field paths.

    Args:
        session: Open SQLAlchemy session.
        top_n: Maximum number of field paths in ``top_corrected_fields``.
    """
    status_query = select(JobRow.status, func.count(JobRow.job_id)).group_by(JobRow.status)
    by_status: dict[str, int] = {}
    for status, job_count in session.execute(status_query).all():
        by_status[status] = int(job_count)
    total_jobs = int(sum(by_status.values()))

    approved_query = select(func.count(JobRow.job_id)).where(JobRow.approved.is_(True))
    approved_jobs = int(session.execute(approved_query).scalar_one())

    corrections_query = select(func.count(JobCorrectionRow.id))
    total_corrections = int(session.execute(corrections_query).scalar_one())

    corrected_jobs_query = select(func.count(func.distinct(JobCorrectionRow.job_id)))
    jobs_with_corrections = int(session.execute(corrected_jobs_query).scalar_one())

    # Field-path histogram, aggregated in SQL and capped at ``top_n`` rows.
    # Ties on the count are broken by field path so the order is stable.
    per_field = func.count(JobCorrectionRow.id)
    histogram_query = (
        select(JobCorrectionRow.field_path, per_field.label("n"))
        .group_by(JobCorrectionRow.field_path)
        .order_by(func.count(JobCorrectionRow.id).desc(), JobCorrectionRow.field_path)
        .limit(top_n)
    )
    top_corrected_fields = [
        FieldCorrectionCount(field_path=field_path, count=int(hits))
        for field_path, hits in session.execute(histogram_query).all()
    ]

    return GroundTruthStats(
        total_jobs=total_jobs,
        completed_jobs=by_status.get(DocumentStatus.COMPLETED.value, 0),
        needs_review_jobs=by_status.get(DocumentStatus.NEEDS_REVIEW.value, 0),
        failed_jobs=by_status.get(DocumentStatus.FAILED.value, 0),
        approved_jobs=approved_jobs,
        total_corrections=total_corrections,
        jobs_with_corrections=jobs_with_corrections,
        top_corrected_fields=top_corrected_fields,
    )
|
||||
|
||||
|
||||
def serialize_sample_to_jsonl(sample: GroundTruthSample) -> str:
    """Render one sample as a single JSONL line.

    The result is ``sample.model_dump_json()`` followed by exactly one
    ``"\\n"``. Both the HTTP route and the CLI call this function, so the
    two outputs share one representation.
    """
    encoded = sample.model_dump_json()
    return f"{encoded}\n"
|
||||
@@ -7,7 +7,7 @@ from fastapi import FastAPI
|
||||
from ocr_sprint import __version__
|
||||
from ocr_sprint.api.errors import register_error_handlers
|
||||
from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint
|
||||
from ocr_sprint.api.routes import documents, health
|
||||
from ocr_sprint.api.routes import documents, ground_truth, health
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables)
|
||||
from ocr_sprint.db.base import Base, get_engine
|
||||
@@ -43,6 +43,7 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(MetricsMiddleware)
|
||||
app.include_router(health.router, prefix="/api/v1")
|
||||
app.include_router(documents.router, prefix="/api/v1")
|
||||
app.include_router(ground_truth.router, prefix="/api/v1")
|
||||
app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False)
|
||||
return app
|
||||
|
||||
|
||||
76
src/ocr_sprint/schemas/ground_truth.py
Normal file
76
src/ocr_sprint/schemas/ground_truth.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Schemas for the Phase 7 ground-truth export.
|
||||
|
||||
Each ``GroundTruthSample`` represents one training-ready example:
|
||||
|
||||
* ``initial_*`` snapshots the pipeline's original (pre-HITL) output,
|
||||
reconstructed by replaying the audit trail in reverse.
|
||||
* ``final_*`` is the current ``result`` on the ``jobs`` row — the
|
||||
reviewer-approved answer.
|
||||
* ``corrections`` is the raw audit trail so downstream fine-tuning can
|
||||
see *what* was changed, *why* (free-text reason), and by whom.
|
||||
|
||||
JSONL is emitted — one sample per line — so the file can be mmapped,
|
||||
streamed, or piped straight into an HF ``datasets.load_dataset("json",
|
||||
...)`` call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class GroundTruthCorrection(BaseModel):
    """One row of the ``job_corrections`` audit trail, as exported."""

    # Dotted path of the corrected field, e.g. ``header.perihal``.
    field_path: str
    # Value before the correction; any JSON-compatible value or ``None``.
    old_value: Any | None = None
    # Value after the correction.
    new_value: Any | None = None
    # Identifier of the reviewer who made the change, when recorded.
    corrected_by: str | None = None
    # Free-text reason supplied with the correction, when present.
    reason: str | None = None
    # Timestamp of the correction event.
    corrected_at: datetime
|
||||
|
||||
|
||||
class GroundTruthSample(BaseModel):
    """One training sample written as a single JSONL line."""

    model_config = ConfigDict(populate_by_name=True)

    # Identity and provenance of the job this sample was built from.
    job_id: UUID
    filename: str
    source_kind: str
    # Review state of the job.
    approved: bool = False
    reviewed_by: str | None = None
    reviewed_at: datetime | None = None
    created_at: datetime
    # ``initial_*`` is the pipeline's pre-HITL answer, reconstructed from
    # the audit trail. ``final_*`` is the reviewer-approved version.
    initial_result: dict[str, Any] | None = None
    final_result: dict[str, Any] | None = None
    # Audit-trail rows for this job.
    corrections: list[GroundTruthCorrection] = Field(default_factory=list)
    # Review flags recorded on the job; empty list when there are none.
    review_flags: list[str] = Field(default_factory=list)
    # Confidence value stored on the job, if any.
    confidence: float | None = None
|
||||
|
||||
|
||||
class FieldCorrectionCount(BaseModel):
    """How many correction rows exist for one field path."""

    # Dotted field path, e.g. ``header.perihal``.
    field_path: str
    # Number of correction rows recorded for that path.
    count: int
|
||||
|
||||
|
||||
class GroundTruthStats(BaseModel):
    """High-level dataset health report surfaced by ``GET /ground-truth/stats``."""

    # Job counts: the overall total, then the per-status breakdown.
    total_jobs: int
    completed_jobs: int
    needs_review_jobs: int
    failed_jobs: int
    # Jobs marked approved.
    approved_jobs: int
    # Total correction rows, and the number of distinct jobs having any.
    total_corrections: int
    jobs_with_corrections: int
    # Most-corrected field paths (descending). Operators use this to
    # prioritise which fields to target with prompt tweaks or fine-tune
    # data collection first.
    top_corrected_fields: list[FieldCorrectionCount] = Field(default_factory=list)
|
||||
1
src/ocr_sprint/tools/__init__.py
Normal file
1
src/ocr_sprint/tools/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Operator CLI utilities. Invoked via ``python -m ocr_sprint.tools.<name>``."""
|
||||
136
src/ocr_sprint/tools/export_ground_truth.py
Normal file
136
src/ocr_sprint/tools/export_ground_truth.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""CLI: ``python -m ocr_sprint.tools.export_ground_truth``.
|
||||
|
||||
Dumps the HITL ground-truth corpus to a local JSONL file. The HTTP
|
||||
endpoint is fine for small queries, but weekly / monthly snapshots are
|
||||
easier to schedule on the host via cron than via curl.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl
|
||||
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl \
|
||||
--since 2025-01-01 --has-corrections
|
||||
|
||||
The database URL is read from the same ``DATABASE_URL`` env var the API
|
||||
uses, so this script works out of the box on any host that can run the
|
||||
service.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from ocr_sprint.db.base import session_scope
|
||||
from ocr_sprint.ground_truth import (
|
||||
GroundTruthFilters,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
|
||||
|
||||
def _parse_iso(raw: str | None) -> datetime | None:
|
||||
if raw is None:
|
||||
return None
|
||||
# ``fromisoformat`` supports dates too, so ``--since 2025-01-01``
|
||||
# behaves like ``2025-01-01T00:00:00``.
|
||||
return datetime.fromisoformat(raw)
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="ocr_sprint.tools.export_ground_truth",
|
||||
description="Dump HITL ground-truth corpus to JSONL.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--out",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Output JSONL file. Use '-' to write to stdout.",
|
||||
)
|
||||
p.add_argument("--since", type=str, default=None, help="ISO date/datetime lower bound.")
|
||||
p.add_argument("--until", type=str, default=None, help="ISO date/datetime upper bound.")
|
||||
p.add_argument(
|
||||
"--include-unapproved",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Include jobs that haven't been approved yet. Not recommended "
|
||||
"for training data — use only for debugging exports."
|
||||
),
|
||||
)
|
||||
group = p.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
"--has-corrections",
|
||||
action="store_true",
|
||||
help="Only include jobs that had at least one correction.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--no-corrections",
|
||||
action="store_true",
|
||||
help="Only include pristine jobs (no corrections). Useful for sanity sets.",
|
||||
)
|
||||
p.add_argument("--limit", type=int, default=None, help="Maximum samples to emit.")
|
||||
p.add_argument(
|
||||
"--print-stats",
|
||||
action="store_true",
|
||||
help="Also print dataset summary to stderr after the dump.",
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Run the export CLI and return the process exit code.

    Parses ``argv``, builds the filters, writes the JSONL to stdout
    (``--out -``) or to the given file (creating parent directories),
    reports the sample count on stderr, and optionally prints the dataset
    statistics as indented JSON on stderr.

    Returns:
        ``0`` on success.
    """
    args = _build_parser().parse_args(argv)

    # Collapse the two mutually exclusive flags into one tri-state value.
    has_corrections: bool | None = None
    if args.has_corrections:
        has_corrections = True
    elif args.no_corrections:
        has_corrections = False

    filters = GroundTruthFilters(
        since=_parse_iso(args.since),
        until=_parse_iso(args.until),
        approved_only=not args.include_unapproved,
        has_corrections=has_corrections,
        limit=args.limit,
    )

    destination: Path = args.out
    if str(destination) != "-":
        destination.parent.mkdir(parents=True, exist_ok=True)
        with destination.open("wb") as sink:
            written = _write_stream(sink, filters)
    else:
        # Bytes go to the raw stdout buffer; status messages go to stderr so
        # they never mix with the JSONL.
        written = _write_stream(sys.stdout.buffer, filters)

    print(f"wrote {written} sample(s) to {destination}", file=sys.stderr)

    if args.print_stats:
        with session_scope() as session:
            summary = ground_truth_stats(session)
            print(summary.model_dump_json(indent=2), file=sys.stderr)

    return 0
|
||||
|
||||
|
||||
class _BinaryWriter(Protocol):
    """Structural type for the byte sinks accepted by ``_write_stream``.

    Anything exposing ``write(bytes) -> int`` qualifies; ``main`` passes
    either ``sys.stdout.buffer`` or a file opened in ``"wb"`` mode.
    """

    def write(self, data: bytes) -> int: ...
|
||||
|
||||
|
||||
def _write_stream(fh: _BinaryWriter, filters: GroundTruthFilters) -> int:
    """Dump every sample matching ``filters`` to ``fh`` as UTF-8 JSONL.

    Opens its own ``session_scope()`` for the duration of the dump.

    Returns:
        The number of samples written.
    """
    written = 0
    with session_scope() as session:
        samples = iter_ground_truth_samples(session, filters)
        for written, sample in enumerate(samples, start=1):
            line = serialize_sample_to_jsonl(sample)
            fh.write(line.encode("utf-8"))
    return written
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
158
tests/unit/test_api_ground_truth.py
Normal file
158
tests/unit/test_api_ground_truth.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""HTTP tests for the ground-truth export endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ocr_sprint.main import create_app
|
||||
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||
from ocr_sprint.pipeline.orchestrator import PipelineOutput
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
from ocr_sprint.schemas.extraction import (
|
||||
ExtractionResult,
|
||||
HeaderFields,
|
||||
PersonnelEntry,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` bound to a freshly created application."""
    return TestClient(create_app())
|
||||
|
||||
|
||||
@pytest.fixture
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
    """Replace ``run_pipeline`` with a stub returning a fixed result.

    The canned output is a completed PDF job with one personnel entry and
    confidence 0.9. Note that ``header.perihal`` is not set, which is the
    field the correction tests patch. Returns the canned ``PipelineOutput``.
    """
    result = ExtractionResult(
        header=HeaderFields(
            nomor_sprint="Sprin/1/I/2025",
            tanggal=date(2025, 1, 1),
            satuan_penerbit="POLRES TEST",
        ),
        personel=[
            PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"),
        ],
        confidence=0.9,
    )
    output = PipelineOutput(
        source_kind=SourceKind.PDF,
        status=DocumentStatus.COMPLETED,
        confidence=0.9,
        result=result,
    )

    def _fake_run(_content: bytes) -> PipelineOutput:
        # Ignore the uploaded bytes and always return the canned output.
        return output

    # Patch the name in all three modules: the orchestrator itself, the
    # documents route, and the worker tasks.
    # NOTE(review): presumably the latter two import ``run_pipeline`` by
    # name, so patching only the orchestrator would not reach them.
    monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
    from ocr_sprint.api.routes import documents as docs_module

    monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
    from ocr_sprint.worker import tasks as tasks_module

    monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
    return output
|
||||
|
||||
|
||||
def _create_and_approve(client: TestClient, *, correction_value: str | None = None) -> str:
    """Create a job through the API, optionally correct it, approve it.

    Uploads a fake PDF synchronously. When ``correction_value`` is given,
    patches ``header.perihal`` to that value before approving. Every call
    must answer 200. Returns the job id as a string.
    """
    upload = client.post(
        "/api/v1/documents?sync=true",
        files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
    )
    assert upload.status_code == 200, upload.text
    job_id = str(upload.json()["job_id"])

    if correction_value is not None:
        payload = {"corrections": [{"path": "header.perihal", "value": correction_value}]}
        correction = client.patch(f"/api/v1/documents/{job_id}", json=payload)
        assert correction.status_code == 200

    approval = client.post(f"/api/v1/documents/{job_id}/approve")
    assert approval.status_code == 200
    return job_id
|
||||
|
||||
|
||||
def test_stats_empty_dataset(client: TestClient) -> None:
    """With no jobs at all, the counters are zero and the top list is empty."""
    response = client.get("/api/v1/ground-truth/stats")
    assert response.status_code == 200

    stats = response.json()
    for counter in ("total_jobs", "approved_jobs", "total_corrections"):
        assert stats[counter] == 0
    assert stats["top_corrected_fields"] == []
|
||||
|
||||
|
||||
def test_stats_rolls_up_counts(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """Two corrected jobs plus one pristine job roll up into the stats."""
    # The final ``None`` creates the pristine (uncorrected) job.
    for value in ("Penyelidikan-1", "Penyelidikan-2", None):
        _create_and_approve(client, correction_value=value)

    response = client.get("/api/v1/ground-truth/stats")
    assert response.status_code == 200

    stats = response.json()
    expected_counts = {
        "total_jobs": 3,
        "approved_jobs": 3,
        "total_corrections": 2,
        "jobs_with_corrections": 2,
    }
    for key, expected in expected_counts.items():
        assert stats[key] == expected

    top_entry = stats["top_corrected_fields"][0]
    assert top_entry["field_path"] == "header.perihal"
    assert top_entry["count"] == 2
|
||||
|
||||
|
||||
def test_export_streams_jsonl(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """The export answers with NDJSON, one parseable object per approved job."""
    _create_and_approve(client, correction_value="Penyelidikan")
    _create_and_approve(client, correction_value=None)

    response = client.get("/api/v1/ground-truth/export")
    assert response.status_code == 200
    assert response.headers["content-type"].startswith("application/x-ndjson")

    non_blank = [raw for raw in response.text.splitlines() if raw.strip()]
    assert len(non_blank) == 2

    for raw in non_blank:
        sample = json.loads(raw)
        assert sample["approved"] is True
        assert "initial_result" in sample
        assert "final_result" in sample
|
||||
|
||||
|
||||
def test_export_approved_only_default(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """Unapproved jobs shouldn't appear in the default export.

    Seeds one approved job and one job that is uploaded but never approved,
    then checks that the default export lists only the former while
    ``approved_only=false`` lists both.
    """
    # One approved job ...
    _create_and_approve(client, correction_value=None)
    # ... and one that is only uploaded (no approve call). The upload response
    # itself is not inspected; the ``approved_only=false`` count further down
    # fails if this request did not produce a job.
    client.post(
        "/api/v1/documents?sync=true",
        files={"file": ("y.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
    )

    resp = client.get("/api/v1/ground-truth/export")
    # Check the status before counting lines: an error body could otherwise be
    # miscounted as a single exported sample and satisfy the assertion below.
    assert resp.status_code == 200
    lines = [line for line in resp.text.splitlines() if line.strip()]
    assert len(lines) == 1

    # Toggle approved_only=false to include both.
    resp = client.get("/api/v1/ground-truth/export?approved_only=false")
    assert resp.status_code == 200
    lines = [line for line in resp.text.splitlines() if line.strip()]
    assert len(lines) == 2
|
||||
|
||||
|
||||
def test_export_has_corrections_filter(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """``has_corrections`` separates the corrected job from the untouched one."""
    _create_and_approve(client, correction_value="Penyelidikan")
    _create_and_approve(client, correction_value=None)

    def export_rows(flag: str) -> list[str]:
        """Fetch the export for one ``has_corrections`` value, minus blank lines."""
        response = client.get(f"/api/v1/ground-truth/export?has_corrections={flag}")
        return [row for row in response.text.splitlines() if row.strip()]

    corrected = export_rows("true")
    assert len(corrected) == 1
    first_sample = json.loads(corrected[0])
    assert first_sample["corrections"][0]["new_value"] == "Penyelidikan"

    untouched = export_rows("false")
    assert len(untouched) == 1
    assert json.loads(untouched[0])["corrections"] == []
|
||||
|
||||
|
||||
def test_export_respects_limit(client: TestClient, fake_pipeline: PipelineOutput) -> None:
    """With five approved jobs stored, ``limit=2`` yields two exported lines."""
    for _round in range(5):
        _create_and_approve(client, correction_value=None)

    response = client.get("/api/v1/ground-truth/export?limit=2")
    non_blank = [row for row in response.text.splitlines() if row.strip()]
    assert len(non_blank) == 2
|
||||
96
tests/unit/test_cli_export_ground_truth.py
Normal file
96
tests/unit/test_cli_export_ground_truth.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Smoke tests for the ``ocr_sprint.tools.export_ground_truth`` CLI."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||
from ocr_sprint.db.repositories import JobRepository
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
from ocr_sprint.tools.export_ground_truth import main as export_main
|
||||
|
||||
|
||||
@pytest.fixture
def db_ready() -> None:
    """Create all tables registered on ``Base.metadata`` before the test runs."""
    engine = get_engine()
    Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def _seed_two_approved_jobs() -> None:
    """Insert two jobs and take each through create, complete and approve.

    Each repository call runs inside its own ``session_scope()`` block.
    """
    for job_id in (uuid4(), uuid4()):
        with session_scope() as session:
            JobRepository(session).create(
                job_id=job_id,
                filename="x.pdf",
                source_kind=SourceKind.PDF,
                blob_key="k",
            )

        with session_scope() as session:
            repo = JobRepository(session)
            repo.mark_completed(
                job_id,
                status=DocumentStatus.COMPLETED,
                confidence=0.9,
                result={"header": {"nomor_sprint": "SPR/1/2025"}},
                review_flags=[],
            )

        with session_scope() as session:
            JobRepository(session).approve(job_id, reviewed_by="rev")
|
||||
|
||||
|
||||
def test_cli_writes_expected_number_of_lines(
    db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """The CLI writes one approved sample per seeded job and reports the count."""
    _seed_two_approved_jobs()
    target = tmp_path / "corpus.jsonl"

    status = export_main(["--out", str(target)])
    assert status == 0

    rows = [row for row in target.read_text().splitlines() if row.strip()]
    assert len(rows) == 2
    for row in rows:
        sample = json.loads(row)
        assert sample["approved"] is True

    # The summary message goes to stderr, not into the output file.
    captured = capsys.readouterr()
    assert "wrote 2 sample(s)" in captured.err
|
||||
|
||||
|
||||
def test_cli_respects_limit(db_ready: None, tmp_path: Path) -> None:
    """``--limit 1`` leaves a single line in the output file despite two jobs."""
    _seed_two_approved_jobs()
    target = tmp_path / "corpus.jsonl"

    status = export_main(["--out", str(target), "--limit", "1"])
    assert status == 0

    non_blank = [row for row in target.read_text().splitlines() if row.strip()]
    assert len(non_blank) == 1
|
||||
|
||||
|
||||
def test_cli_stdout_reports_correct_count(
    db_ready: None, capsys: pytest.CaptureFixture[str]
) -> None:
    """With ``--out -`` the JSONL lands on stdout and the "wrote N" message
    on stderr must match the number of samples streamed (2, not 0)."""
    _seed_two_approved_jobs()

    status = export_main(["--out", "-"])
    assert status == 0

    streams = capsys.readouterr()
    emitted = [row for row in streams.out.splitlines() if row.strip()]
    assert len(emitted) == 2
    for row in emitted:
        sample = json.loads(row)
        assert sample["approved"] is True
    assert "wrote 2 sample(s)" in streams.err
|
||||
|
||||
|
||||
def test_cli_print_stats_emits_json_to_stderr(
    db_ready: None, tmp_path: Path, capsys: pytest.CaptureFixture[str]
) -> None:
    """``--print-stats`` adds a JSON stats document to stderr."""
    _seed_two_approved_jobs()
    target = tmp_path / "corpus.jsonl"

    status = export_main(["--out", str(target), "--print-stats"])
    assert status == 0

    err_text = capsys.readouterr().err
    # Everything from the first "{" to the end of stderr must parse as the
    # stats JSON; whatever precedes it (the "wrote N" line) is skipped.
    stats = json.loads(err_text[err_text.index("{"):])
    assert stats["total_jobs"] == 2
    assert stats["approved_jobs"] == 2
|
||||
210
tests/unit/test_ground_truth_service.py
Normal file
210
tests/unit/test_ground_truth_service.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Unit tests for the ground-truth export service."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||
from ocr_sprint.db.repositories import JobRepository
|
||||
from ocr_sprint.ground_truth import (
|
||||
GroundTruthFilters,
|
||||
build_initial_result,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
|
||||
|
||||
@pytest.fixture
def db_ready() -> None:
    """Create all tables registered on ``Base.metadata`` before the test runs."""
    engine = get_engine()
    Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def _seed_approved_job_with_corrections(
    *,
    final_result: dict[str, object] | None = None,
    corrections: list[tuple[str, object, object]] | None = None,
    approved: bool = True,
) -> UUID:
    """Seed one job through create, complete, optional correction and approval.

    Args:
        final_result: Payload passed as ``result`` to ``mark_completed``.
            ``None`` selects a small default document; a dict passed
            explicitly, including an empty one, is used as given. Despite the
            name, this is the result as it stands *before* ``corrections``
            are applied, because ``mark_completed`` runs first.
        corrections: ``(field_path, new_value, old_value)`` tuples. Only the
            path and the new value are forwarded to ``apply_corrections``;
            the third element is ignored and ``None`` is sent in its place.
        approved: When true, the job is approved with
            ``reviewed_by="reviewer-a"``.

    Returns:
        The UUID of the created job.
    """
    # Compare against None rather than relying on truthiness, so that an
    # explicitly supplied empty dict is not silently replaced by the default.
    if final_result is None:
        final_result = {
            "header": {"nomor_sprint": "SPR/1/2025", "satuan_penerbit": "POLRES X"},
            "personel": [{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}],
        }

    jid = uuid4()
    with session_scope() as session:
        repo = JobRepository(session)
        repo.create(job_id=jid, filename="x.pdf", source_kind=SourceKind.PDF, blob_key="k")
    with session_scope() as session:
        JobRepository(session).mark_completed(
            jid,
            status=DocumentStatus.NEEDS_REVIEW,
            confidence=0.8,
            result=final_result,
            review_flags=[],
        )
    if corrections:
        # Go through the repository's own correction call so the stored audit
        # rows come from the same code path as real reviewer edits.
        # NOTE(review): the meaning of the third tuple slot expected by
        # ``apply_corrections`` is not visible here; ``None`` is always sent.
        with session_scope() as session:
            JobRepository(session).apply_corrections(
                jid,
                corrections=[(p, new, None) for (p, new, _old) in corrections],
                corrected_by="reviewer-a",
            )
    if approved:
        with session_scope() as session:
            JobRepository(session).approve(jid, reviewed_by="reviewer-a")
    return jid
|
||||
|
||||
|
||||
def test_build_initial_result_reverses_corrections() -> None:
    """Given known old/new pairs, the replay must hand back the pre-HITL dict."""

    class _StubCorrection:
        """Carries just the three attributes this test sets on a correction."""

        def __init__(self, path: str, old: object, new: object) -> None:
            self.field_path, self.old_value, self.new_value = path, old, new

    approved_result = {
        "header": {"nomor_sprint": "SPR/1/2025", "perihal": "Penyelidikan"},
    }
    audit_trail = [
        _StubCorrection(path, old, new)
        for path, old, new in (
            ("header.perihal", None, "Penyelidikan"),
            ("header.nomor_sprint", "SPR/OLD/2025", "SPR/1/2025"),
        )
    ]

    restored = build_initial_result(final_result=approved_result, corrections=audit_trail)  # type: ignore[arg-type]

    expected = {
        "header": {"nomor_sprint": "SPR/OLD/2025", "perihal": None},
    }
    assert restored == expected
    # The input dict must come back untouched.
    assert approved_result["header"]["nomor_sprint"] == "SPR/1/2025"
|
||||
|
||||
|
||||
def test_build_initial_result_ignores_unresolvable_paths() -> None:
    """Paths that cannot be resolved are skipped instead of raising."""

    class _StubCorrection:
        """Carries just the three attributes this test sets on a correction."""

        def __init__(self, path: str, old: object, new: object) -> None:
            self.field_path, self.old_value, self.new_value = path, old, new

    current = {"header": {"nomor_sprint": "SPR/1"}}
    stale_entries = [
        # ``personel`` is absent from ``current``, so there is no container.
        _StubCorrection("personel[99].nrp", "stale", "new"),
        _StubCorrection("bogus..path", "x", "y"),
    ]

    restored = build_initial_result(final_result=current, corrections=stale_entries)  # type: ignore[arg-type]
    assert restored == current
|
||||
|
||||
|
||||
def test_iter_samples_reconstructs_initial_and_final(db_ready: None) -> None:
    """A corrected job yields one sample exposing both result versions."""
    job_id = _seed_approved_job_with_corrections(
        corrections=[("header.perihal", "Penyelidikan", None)],
    )

    with session_scope() as session:
        found = list(iter_ground_truth_samples(session, GroundTruthFilters()))
        assert len(found) == 1
        (sample,) = found

        assert sample.job_id == job_id
        assert sample.approved is True
        assert sample.reviewed_by == "reviewer-a"

        # The final result carries the corrected value ...
        assert sample.final_result is not None
        assert sample.final_result["header"]["perihal"] == "Penyelidikan"

        # ... while the reconstructed initial result has the correction undone.
        assert sample.initial_result is not None
        assert sample.initial_result["header"].get("perihal") is None

        assert len(sample.corrections) == 1
        assert sample.corrections[0].field_path == "header.perihal"
|
||||
|
||||
|
||||
def test_iter_samples_respects_approved_only_default(db_ready: None) -> None:
    """An unapproved job is hidden unless ``approved_only=False`` is passed."""
    _seed_approved_job_with_corrections(approved=False)

    with session_scope() as session:
        default_view = list(iter_ground_truth_samples(session, GroundTruthFilters()))
        assert default_view == []

    with session_scope() as session:
        include_all = GroundTruthFilters(approved_only=False)
        full_view = list(iter_ground_truth_samples(session, include_all))
        assert len(full_view) == 1
|
||||
|
||||
|
||||
def test_iter_samples_respects_has_corrections_filter(db_ready: None) -> None:
    """``has_corrections`` True/False each select a different one of two jobs."""
    # One pristine job and one carrying a single correction.
    _seed_approved_job_with_corrections(corrections=None)
    _seed_approved_job_with_corrections(corrections=[("header.perihal", "fill", None)])

    with session_scope() as session:
        corrected = list(
            iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=True))
        )
        pristine = list(
            iter_ground_truth_samples(session, GroundTruthFilters(has_corrections=False))
        )

        assert len(corrected) == 1
        assert len(pristine) == 1
        assert corrected[0].job_id != pristine[0].job_id
|
||||
|
||||
|
||||
def test_iter_samples_respects_since_until_and_limit(db_ready: None) -> None:
    """``since`` and ``limit`` narrow the samples returned for three jobs.

    Despite the test name, no ``until`` filter is exercised here.
    """
    for _round in range(3):
        _seed_approved_job_with_corrections()

    # A ``since`` far in the future leaves nothing to return.
    far_future = datetime(2999, 1, 1, tzinfo=timezone.utc)
    with session_scope() as session:
        after_cutoff = list(
            iter_ground_truth_samples(session, GroundTruthFilters(since=far_future))
        )
        assert after_cutoff == []

    # ``limit`` caps how many samples come back.
    with session_scope() as session:
        capped = list(iter_ground_truth_samples(session, GroundTruthFilters(limit=2)))
        assert len(capped) == 2
|
||||
|
||||
|
||||
def test_stats_counts_rollup_and_top_fields(db_ready: None) -> None:
    """Stats count jobs and corrections and rank fields by correction count.

    Two approved jobs are seeded: the first with one correction on
    ``header.perihal``; the second with one on ``header.perihal`` and one on
    ``personel[0].nrp``. That gives three corrections in total, two of them on
    ``header.perihal``.
    """
    # Job 1: a single correction on ``header.perihal``.
    _seed_approved_job_with_corrections(
        corrections=[
            ("header.perihal", "first", None),
        ]
    )
    # Job 2: one more on ``header.perihal`` plus one on ``personel[0].nrp``.
    _seed_approved_job_with_corrections(
        corrections=[
            ("header.perihal", "second", None),
            ("personel[0].nrp", "77060001", None),
        ],
    )

    with session_scope() as session:
        stats = ground_truth_stats(session)

        assert stats.total_jobs == 2
        assert stats.approved_jobs == 2
        assert stats.total_corrections == 3
        assert stats.jobs_with_corrections == 2
        # ``header.perihal`` is the most-corrected field (2 > 1).
        assert stats.top_corrected_fields[0].field_path == "header.perihal"
        assert stats.top_corrected_fields[0].count == 2
        assert {f.field_path for f in stats.top_corrected_fields} == {
            "header.perihal",
            "personel[0].nrp",
        }
|
||||
|
||||
|
||||
def test_serialize_is_valid_jsonl(db_ready: None) -> None:
    """Each serialized sample is one newline-terminated JSON object.

    Seeds one approved job with a correction, serializes every sample the
    iterator yields, and checks the JSONL shape of each resulting line.
    """
    _seed_approved_job_with_corrections(corrections=[("header.perihal", "X", None)])
    with session_scope() as session:
        # Serialize while the session is still open; the checks below only
        # touch the resulting strings.
        lines = [
            serialize_sample_to_jsonl(sample)
            for sample in iter_ground_truth_samples(session, GroundTruthFilters())
        ]

    # Guard against a vacuous pass: with no samples, the loop below would not
    # execute a single assertion.
    assert lines

    for line in lines:
        assert line.endswith("\n")
        # JSONL puts one record on one line, so the terminator must be the
        # only newline in the serialized sample.
        assert line.count("\n") == 1
        # ``json.loads`` accepts the trailing newline; the record must be a
        # single JSON object, not a list.
        parsed = json.loads(line)
        assert isinstance(parsed, dict)
        assert parsed["approved"] is True
|
||||
Reference in New Issue
Block a user