Phase 7: ground-truth export (JSONL + stats) + CLI tool
- GET /api/v1/ground-truth/export streaming JSONL (approved_only, since, until, has_corrections, limit)
- GET /api/v1/ground-truth/stats total / approved / corrections counts + top-N most-corrected field paths
- python -m ocr_sprint.tools.export_ground_truth operator CLI with the same filters + optional --print-stats
- Ground-truth sample reconstructs the pipeline's original output by replaying job_corrections in reverse
- docs/ground-truth-format.md schema + fine-tuning guidance
- 17 new tests (service replay, endpoint filters, CLI)
- 201 total tests passing, ruff / mypy --strict clean

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
102
src/ocr_sprint/api/routes/ground_truth.py
Normal file
102
src/ocr_sprint/api/routes/ground_truth.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""Ground-truth export + statistics endpoints (Phase 7).
|
||||
|
||||
Two endpoints, both auth'd by the existing ``X-API-Key`` dependency:
|
||||
|
||||
* ``GET /ground-truth/export`` — streams JSONL of approved (or filtered)
|
||||
samples for downstream fine-tuning pipelines.
|
||||
* ``GET /ground-truth/stats`` — returns aggregate counts + top-corrected
|
||||
field paths so operators know when/where fine-tuning will pay off.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.api.deps.auth import require_api_key
|
||||
from ocr_sprint.api.deps.db import get_session
|
||||
from ocr_sprint.ground_truth import (
|
||||
GroundTruthFilters,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
from ocr_sprint.schemas.ground_truth import GroundTruthStats
|
||||
|
||||
# Router for the ground-truth endpoints. The API-key dependency is attached
# at router level, so every route registered below is authenticated.
router = APIRouter(
    dependencies=[Depends(require_api_key)],
    tags=["ground-truth"],
    prefix="/ground-truth",
)
|
||||
|
||||
|
||||
@router.get(
    "/export",
    response_class=StreamingResponse,
    responses={
        200: {
            "content": {"application/x-ndjson": {}},
            "description": "Newline-delimited JSON stream of training samples.",
        }
    },
)
def export_ground_truth(
    session: Annotated[Session, Depends(get_session)],
    since: Annotated[
        datetime | None,
        Query(description="Only include jobs created at or after this ISO timestamp."),
    ] = None,
    until: Annotated[
        datetime | None,
        Query(description="Only include jobs created at or before this ISO timestamp."),
    ] = None,
    approved_only: Annotated[
        bool,
        Query(description="Only export approved jobs (default true)."),
    ] = True,
    has_corrections: Annotated[
        bool | None,
        Query(
            description=(
                "Optional: true = only jobs that had at least one correction, "
                "false = only pristine (no-correction) jobs."
            )
        ),
    ] = None,
    limit: Annotated[
        int | None,
        Query(ge=1, le=100_000, description="Maximum rows to emit."),
    ] = None,
) -> StreamingResponse:
    """Stream ground-truth samples as newline-delimited JSON.

    The query parameters are forwarded unchanged into a
    ``GroundTruthFilters`` instance; each sample produced by
    ``iter_ground_truth_samples`` is serialized with
    ``serialize_sample_to_jsonl`` and sent as one UTF-8 encoded line.
    """
    samples = iter_ground_truth_samples(
        session,
        GroundTruthFilters(
            since=since,
            until=until,
            approved_only=approved_only,
            has_corrections=has_corrections,
            limit=limit,
        ),
    )
    # Lazy on purpose: nothing is read from the database until the response
    # body is iterated, which happens after this function has returned.
    # NOTE(review): this relies on the ``get_session`` dependency keeping the
    # session usable while the body is being sent — confirm against the
    # FastAPI version in use.
    encoded_lines: Iterator[bytes] = (
        serialize_sample_to_jsonl(sample).encode("utf-8") for sample in samples
    )
    return StreamingResponse(encoded_lines, media_type="application/x-ndjson")
|
||||
|
||||
|
||||
@router.get("/stats", response_model=GroundTruthStats)
def get_stats(
    session: Annotated[Session, Depends(get_session)],
    top_n: Annotated[
        int,
        Query(ge=1, le=100, description="How many top-corrected field paths to return."),
    ] = 10,
) -> GroundTruthStats:
    """Return aggregate ground-truth statistics.

    Thin HTTP wrapper: all counting is delegated to ``ground_truth_stats``;
    ``top_n`` bounds how many field paths are listed in the result.
    """
    stats = ground_truth_stats(session, top_n=top_n)
    return stats
|
||||
22
src/ocr_sprint/ground_truth/__init__.py
Normal file
22
src/ocr_sprint/ground_truth/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Phase 7 — ground-truth export service.
|
||||
|
||||
Consumes the ``jobs`` + ``job_corrections`` tables to produce JSONL
|
||||
training samples + aggregate statistics. No ML dependencies here: the
|
||||
actual fine-tuning runs in a separate project on dedicated hardware.
|
||||
"""
|
||||
|
||||
from ocr_sprint.ground_truth.service import (
|
||||
GroundTruthFilters,
|
||||
build_initial_result,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
|
||||
# Public surface of the package: exactly the names re-exported from
# ``ocr_sprint.ground_truth.service`` above.
__all__ = [
    "GroundTruthFilters",
    "build_initial_result",
    "ground_truth_stats",
    "iter_ground_truth_samples",
    "serialize_sample_to_jsonl",
]
|
||||
248
src/ocr_sprint/ground_truth/service.py
Normal file
248
src/ocr_sprint/ground_truth/service.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""Core ground-truth export logic.
|
||||
|
||||
The service reads from ``jobs`` and ``job_corrections`` and produces
|
||||
``GroundTruthSample`` rows. The HTTP layer and the CLI both go through
|
||||
this module, so the export logic stays in exactly one place.
|
||||
|
||||
Design notes
|
||||
------------
|
||||
* **Reverse-replay for ``initial_result``.** Audit rows store both
|
||||
``old_value`` and ``new_value``. We start from the current
|
||||
``final_result`` and undo each correction in reverse chronological
|
||||
order to recover the pipeline's original pre-HITL output. This is
|
||||
what fine-tuning needs: the raw mistake next to the reviewer fix.
|
||||
* **No deep copying in the caller.** The service always returns freshly
|
||||
computed dicts so mutating one sample can't leak into the next.
|
||||
* **Generator by design.** Large exports shouldn't load the whole table
|
||||
into memory; the HTTP route streams JSONL line-by-line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import re
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.db.models import JobCorrectionRow, JobRow
|
||||
from ocr_sprint.schemas.document import DocumentStatus
|
||||
from ocr_sprint.schemas.ground_truth import (
|
||||
FieldCorrectionCount,
|
||||
GroundTruthCorrection,
|
||||
GroundTruthSample,
|
||||
GroundTruthStats,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GroundTruthFilters:
|
||||
"""HTTP / CLI filters for ``iter_ground_truth_samples``.
|
||||
|
||||
All fields are optional; ``approved_only=True`` is the default because
|
||||
only approved samples are actually safe to train on.
|
||||
"""
|
||||
|
||||
since: datetime | None = None
|
||||
until: datetime | None = None
|
||||
approved_only: bool = True
|
||||
has_corrections: bool | None = None
|
||||
limit: int | None = None
|
||||
|
||||
|
||||
def build_initial_result(
    *,
    final_result: dict[str, Any] | None,
    corrections: list[JobCorrectionRow],
) -> dict[str, Any] | None:
    """Reconstruct the pre-HITL result by undoing ``corrections``.

    Works on a deep copy of ``final_result`` (the input is never mutated)
    and walks ``corrections`` from last to first, writing each event's
    ``old_value`` back at its ``field_path``. Returns ``None`` when
    ``final_result`` is ``None``, i.e. there is nothing to reconstruct.
    """
    if final_result is None:
        return None

    snapshot = copy.deepcopy(final_result)
    # Newest correction first, so later edits are unwound before earlier ones.
    for correction in corrections[::-1]:
        _assign_path(snapshot, correction.field_path, correction.old_value)
    return snapshot
|
||||
|
||||
|
||||
def _assign_path(data: dict[str, Any], path: str, value: Any) -> None:
    """Write ``value`` into ``data`` at the dotted ``path``, in place.

    ``path`` is split by ``_parse`` into ``(key, index)`` steps. Every step
    except the last is followed downwards; the last one is the assignment
    target. If the path is malformed or any step no longer resolves (missing
    key, wrong container type, index out of range) the function returns
    without touching ``data`` instead of raising.
    """
    steps = _parse(path)
    if not steps:
        return

    *parents, (leaf_key, leaf_index) = steps

    node: Any = data
    for key, index in parents:
        if not (isinstance(node, dict) and key in node):
            return
        node = node[key]
        if index is None:
            continue
        if not (isinstance(node, list) and index < len(node)):
            return
        node = node[index]

    # Whatever the leaf looks like, its parent has to be a mapping.
    if not isinstance(node, dict):
        return

    if leaf_index is None:
        # Plain key: assigned even when the key is currently absent.
        node[leaf_key] = value
        return

    # Indexed leaf: the list must already exist and be long enough.
    if leaf_key not in node:
        return
    target = node[leaf_key]
    if isinstance(target, list) and leaf_index < len(target):
        target[leaf_index] = value
|
||||
|
||||
|
||||
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
|
||||
|
||||
|
||||
def _parse(path: str) -> list[tuple[str, int | None]]:
|
||||
"""Forgiving dotted-path parser — returns ``[]`` on malformed input
|
||||
instead of raising, since audit-log replay must not blow up on a
|
||||
stale path. The strict variant used by the HITL PATCH route lives
|
||||
in ``ocr_sprint.db.repositories``.
|
||||
"""
|
||||
if not path or path.startswith(".") or path.endswith("."):
|
||||
return []
|
||||
out: list[tuple[str, int | None]] = []
|
||||
for part in path.split("."):
|
||||
match: re.Match[str] | None = _SEGMENT_RE.match(part)
|
||||
if match is None:
|
||||
return []
|
||||
out.append((match.group(1), int(match.group(2)) if match.group(2) else None))
|
||||
return out
|
||||
|
||||
|
||||
def iter_ground_truth_samples(
    session: Session,
    filters: GroundTruthFilters,
) -> Iterator[GroundTruthSample]:
    """Yield ``GroundTruthSample`` rows matching ``filters``.

    Jobs are ordered by ``created_at`` and then ``job_id`` so repeated
    exports are deterministic and downstream diffing works. Corrections of
    each job are ordered by ``corrected_at`` and then ``id``.

    ``filters.limit`` caps the number of samples *yielded*: jobs skipped by
    the ``has_corrections`` filter do not count against it. A limit of zero
    or less yields nothing.
    """
    remaining = filters.limit
    if remaining is not None and remaining <= 0:
        # Nothing can be emitted, so don't issue the query at all.
        return

    stmt = select(JobRow).order_by(JobRow.created_at, JobRow.job_id)
    if filters.approved_only:
        stmt = stmt.where(JobRow.approved.is_(True))
    if filters.since is not None:
        stmt = stmt.where(JobRow.created_at >= filters.since)
    if filters.until is not None:
        stmt = stmt.where(JobRow.created_at <= filters.until)

    for job_row in session.scalars(stmt):
        # One follow-up query per job fetches its audit trail in replay order.
        correction_rows = list(
            session.scalars(
                select(JobCorrectionRow)
                .where(JobCorrectionRow.job_id == job_row.job_id)
                .order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
            )
        )

        if filters.has_corrections is True and not correction_rows:
            continue
        if filters.has_corrections is False and correction_rows:
            continue

        final = job_row.result
        yield GroundTruthSample(
            job_id=job_row.job_id,
            filename=job_row.filename,
            source_kind=job_row.source_kind,
            approved=bool(job_row.approved),
            reviewed_by=job_row.reviewed_by,
            reviewed_at=job_row.reviewed_at,
            created_at=job_row.created_at,
            initial_result=build_initial_result(
                final_result=final,
                corrections=correction_rows,
            ),
            # Compare against None explicitly: an empty-dict result must be
            # exported as ``{}`` (matching ``initial_result``), not as null.
            final_result=copy.deepcopy(final) if final is not None else None,
            corrections=[
                GroundTruthCorrection(
                    field_path=c.field_path,
                    old_value=c.old_value,
                    new_value=c.new_value,
                    corrected_by=c.corrected_by,
                    reason=c.reason,
                    corrected_at=c.corrected_at,
                )
                for c in correction_rows
            ],
            review_flags=list(job_row.review_flags or []),
            confidence=job_row.confidence,
        )

        if remaining is not None:
            remaining -= 1
            if remaining == 0:
                # Stop right after the last allowed sample instead of
                # fetching one more job row just to discover the cap.
                return
|
||||
|
||||
|
||||
def ground_truth_stats(session: Session, *, top_n: int = 10) -> GroundTruthStats:
    """Aggregate job and correction counts into a ``GroundTruthStats``.

    Runs five aggregate queries: jobs per status, approved jobs, total
    corrections, distinct corrected jobs, and the ``top_n`` most frequently
    corrected field paths (ties broken by field path).
    """
    status_query = select(JobRow.status, func.count(JobRow.job_id)).group_by(JobRow.status)
    by_status: dict[str, int] = {}
    for status, n in session.execute(status_query).all():
        by_status[status] = int(n)
    total_jobs = sum(by_status.values())

    approved_query = select(func.count(JobRow.job_id)).where(JobRow.approved.is_(True))
    approved_jobs = int(session.execute(approved_query).scalar_one())

    corrections_query = select(func.count(JobCorrectionRow.id))
    total_corrections = int(session.execute(corrections_query).scalar_one())

    corrected_jobs_query = select(func.count(func.distinct(JobCorrectionRow.job_id)))
    jobs_with_corrections = int(session.execute(corrected_jobs_query).scalar_one())

    # Field-path histogram. Grouping and capping happen in SQL, so at most
    # ``top_n`` rows come back regardless of how large the audit table is.
    per_field = func.count(JobCorrectionRow.id)
    histogram_query = (
        select(JobCorrectionRow.field_path, per_field.label("n"))
        .group_by(JobCorrectionRow.field_path)
        .order_by(per_field.desc(), JobCorrectionRow.field_path)
        .limit(top_n)
    )
    top_rows = session.execute(histogram_query).all()

    return GroundTruthStats(
        total_jobs=total_jobs,
        completed_jobs=by_status.get(DocumentStatus.COMPLETED.value, 0),
        needs_review_jobs=by_status.get(DocumentStatus.NEEDS_REVIEW.value, 0),
        failed_jobs=by_status.get(DocumentStatus.FAILED.value, 0),
        approved_jobs=approved_jobs,
        total_corrections=total_corrections,
        jobs_with_corrections=jobs_with_corrections,
        top_corrected_fields=[
            FieldCorrectionCount(field_path=field_path, count=int(n))
            for field_path, n in top_rows
        ],
    )
|
||||
|
||||
|
||||
def serialize_sample_to_jsonl(sample: GroundTruthSample) -> str:
    """Return ``sample`` as one JSON document followed by a newline.

    Both the HTTP route and the CLI call this function, so a curl download
    and a CLI dump of the same data use the exact same representation.
    """
    return f"{sample.model_dump_json()}\n"
|
||||
@@ -7,7 +7,7 @@ from fastapi import FastAPI
|
||||
from ocr_sprint import __version__
|
||||
from ocr_sprint.api.errors import register_error_handlers
|
||||
from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint
|
||||
from ocr_sprint.api.routes import documents, health
|
||||
from ocr_sprint.api.routes import documents, ground_truth, health
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables)
|
||||
from ocr_sprint.db.base import Base, get_engine
|
||||
@@ -43,6 +43,7 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(MetricsMiddleware)
|
||||
app.include_router(health.router, prefix="/api/v1")
|
||||
app.include_router(documents.router, prefix="/api/v1")
|
||||
app.include_router(ground_truth.router, prefix="/api/v1")
|
||||
app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False)
|
||||
return app
|
||||
|
||||
|
||||
76
src/ocr_sprint/schemas/ground_truth.py
Normal file
76
src/ocr_sprint/schemas/ground_truth.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Schemas for the Phase 7 ground-truth export.
|
||||
|
||||
Each ``GroundTruthSample`` represents one training-ready example:
|
||||
|
||||
* ``initial_*`` snapshots the pipeline's original (pre-HITL) output,
|
||||
reconstructed by replaying the audit trail in reverse.
|
||||
* ``final_*`` is the current ``result`` on the ``jobs`` row — the
|
||||
reviewer-approved answer.
|
||||
* ``corrections`` is the raw audit trail so downstream fine-tuning can
|
||||
see *what* was changed, *why* (free-text reason), and by whom.
|
||||
|
||||
JSONL is emitted — one sample per line — so the file can be mmapped,
|
||||
streamed, or piped straight into an HF ``datasets.load_dataset("json",
|
||||
...)`` call.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class GroundTruthCorrection(BaseModel):
    """A single ``job_corrections`` audit entry in its exported form."""

    # Path of the field that was corrected.
    field_path: str
    # Value before and after the reviewer's edit.
    old_value: Any | None = None
    new_value: Any | None = None
    # Who made the correction and the optional free-text justification.
    corrected_by: str | None = None
    reason: str | None = None
    # Timestamp of the correction (required).
    corrected_at: datetime
|
||||
|
||||
|
||||
class GroundTruthSample(BaseModel):
    """A complete training example, serialized as exactly one JSONL line."""

    model_config = ConfigDict(populate_by_name=True)

    # Identity and provenance of the job.
    job_id: UUID
    filename: str
    source_kind: str

    # Review state.
    approved: bool = False
    reviewed_by: str | None = None
    reviewed_at: datetime | None = None
    created_at: datetime

    # ``initial_result`` is the pipeline's pre-HITL answer, rebuilt from the
    # audit trail; ``final_result`` is the reviewer-approved version.
    initial_result: dict[str, Any] | None = None
    final_result: dict[str, Any] | None = None

    # Audit trail entries for this job, plus the job's review flags and
    # confidence value.
    corrections: list[GroundTruthCorrection] = Field(default_factory=list)
    review_flags: list[str] = Field(default_factory=list)
    confidence: float | None = None
|
||||
|
||||
|
||||
class FieldCorrectionCount(BaseModel):
    """How many corrections were recorded against one field path."""

    # The corrected field's path.
    field_path: str
    # Number of correction rows recorded for that path.
    count: int
|
||||
|
||||
|
||||
class GroundTruthStats(BaseModel):
    """Dataset health summary returned by ``GET /ground-truth/stats``."""

    # Job counts: overall, then per status.
    total_jobs: int
    completed_jobs: int
    needs_review_jobs: int
    failed_jobs: int
    # Jobs flagged as approved.
    approved_jobs: int

    # Correction volume: total audit rows and distinct jobs they touch.
    total_corrections: int
    jobs_with_corrections: int

    # Most-corrected field paths, in descending order. Operators use this
    # to decide which fields to target first with prompt tweaks or
    # fine-tune data collection.
    top_corrected_fields: list[FieldCorrectionCount] = Field(default_factory=list)
|
||||
1
src/ocr_sprint/tools/__init__.py
Normal file
1
src/ocr_sprint/tools/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Operator CLI utilities. Invoked via ``python -m ocr_sprint.tools.<name>``."""
|
||||
137
src/ocr_sprint/tools/export_ground_truth.py
Normal file
137
src/ocr_sprint/tools/export_ground_truth.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""CLI: ``python -m ocr_sprint.tools.export_ground_truth``.
|
||||
|
||||
Dumps the HITL ground-truth corpus to a local JSONL file. The HTTP
|
||||
endpoint is fine for small queries, but weekly / monthly snapshots are
|
||||
easier to schedule on the host via cron than via curl.
|
||||
|
||||
Usage
|
||||
-----
|
||||
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl
|
||||
python -m ocr_sprint.tools.export_ground_truth --out corpus.jsonl \
|
||||
--since 2025-01-01 --has-corrections
|
||||
|
||||
The database URL is read from the same ``DATABASE_URL`` env var the API
|
||||
uses, so this script works out of the box on any box that can run the
|
||||
service.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Protocol
|
||||
|
||||
from ocr_sprint.db.base import session_scope
|
||||
from ocr_sprint.ground_truth import (
|
||||
GroundTruthFilters,
|
||||
ground_truth_stats,
|
||||
iter_ground_truth_samples,
|
||||
serialize_sample_to_jsonl,
|
||||
)
|
||||
|
||||
|
||||
def _parse_iso(raw: str | None) -> datetime | None:
|
||||
if raw is None:
|
||||
return None
|
||||
# ``fromisoformat`` supports dates too, so ``--since 2025-01-01``
|
||||
# behaves like ``2025-01-01T00:00:00``.
|
||||
return datetime.fromisoformat(raw)
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
p = argparse.ArgumentParser(
|
||||
prog="ocr_sprint.tools.export_ground_truth",
|
||||
description="Dump HITL ground-truth corpus to JSONL.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--out",
|
||||
required=True,
|
||||
type=Path,
|
||||
help="Output JSONL file. Use '-' to write to stdout.",
|
||||
)
|
||||
p.add_argument("--since", type=str, default=None, help="ISO date/datetime lower bound.")
|
||||
p.add_argument("--until", type=str, default=None, help="ISO date/datetime upper bound.")
|
||||
p.add_argument(
|
||||
"--include-unapproved",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Include jobs that haven't been approved yet. Not recommended "
|
||||
"for training data — use only for debugging exports."
|
||||
),
|
||||
)
|
||||
group = p.add_mutually_exclusive_group()
|
||||
group.add_argument(
|
||||
"--has-corrections",
|
||||
action="store_true",
|
||||
help="Only include jobs that had at least one correction.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--no-corrections",
|
||||
action="store_true",
|
||||
help="Only include pristine jobs (no corrections). Useful for sanity sets.",
|
||||
)
|
||||
p.add_argument("--limit", type=int, default=None, help="Maximum samples to emit.")
|
||||
p.add_argument(
|
||||
"--print-stats",
|
||||
action="store_true",
|
||||
help="Also print dataset summary to stderr after the dump.",
|
||||
)
|
||||
return p
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
    """Run the export CLI and return the process exit code.

    Parses ``argv`` (``None`` lets argparse read ``sys.argv``), writes the
    matching samples as JSONL to ``--out`` (``-`` means stdout), reports the
    number of samples written on stderr and, with ``--print-stats``, also
    prints the dataset summary as indented JSON on stderr.

    Returns ``0`` on success; failures propagate as exceptions.
    """
    args = _build_parser().parse_args(argv)

    # Collapse the two mutually exclusive flags into the tri-state filter.
    has_corrections: bool | None
    if args.has_corrections:
        has_corrections = True
    elif args.no_corrections:
        has_corrections = False
    else:
        has_corrections = None

    filters = GroundTruthFilters(
        since=_parse_iso(args.since),
        until=_parse_iso(args.until),
        approved_only=not args.include_unapproved,
        has_corrections=has_corrections,
        limit=args.limit,
    )

    out: Path = args.out
    if str(out) == "-":
        # Capture the count here too; dropping it made the summary below
        # always claim "wrote 0 sample(s)" for stdout exports.
        count = _write_stream(sys.stdout.buffer, filters)
    else:
        out.parent.mkdir(parents=True, exist_ok=True)
        with out.open("wb") as fh:
            count = _write_stream(fh, filters)

    # Status goes to stderr so it never mixes with JSONL written to stdout.
    print(f"wrote {count} sample(s) to {out}", file=sys.stderr)

    if args.print_stats:
        with session_scope() as session:
            stats = ground_truth_stats(session)
        print(stats.model_dump_json(indent=2), file=sys.stderr)

    return 0
|
||||
|
||||
|
||||
class _BinaryWriter(Protocol):
|
||||
def write(self, data: bytes) -> int: ...
|
||||
|
||||
|
||||
def _write_stream(fh: _BinaryWriter, filters: GroundTruthFilters) -> int:
    """Write every sample matching ``filters`` to ``fh`` as UTF-8 JSONL.

    A session is opened via ``session_scope`` for the duration of the dump.
    Returns how many samples were written (``0`` when nothing matched).
    """
    written = 0
    with session_scope() as session:
        samples = iter_ground_truth_samples(session, filters)
        for written, sample in enumerate(samples, start=1):
            line = serialize_sample_to_jsonl(sample)
            fh.write(line.encode("utf-8"))
    return written
|
||||
|
||||
|
||||
# Script entry point: the return value of ``main`` becomes the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
||||
Reference in New Issue
Block a user