Phase 6: HITL review endpoints + audit trail

- New job_corrections table (append-only audit log) + migration
- Add approved / reviewed_by / reviewed_at columns to jobs
- PATCH  /documents/{id}         apply field-level corrections
- GET    /documents/{id}/history return chronological audit trail
- POST   /documents/{id}/approve lock final version (idempotent)
- Dotted field-path applier with root allow-list + list-index support
- Auto-clear `missing_field` review flag when required header keys filled
- Atomic batch apply: malformed path in batch rolls back all changes
- 22 new tests (11 repository-level, 11 API-level); 184 total passing

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
Devin AI
2026-04-25 20:12:04 +00:00
parent 45fbfdabb7
commit 66247e39a5
9 changed files with 1058 additions and 22 deletions

View File

@@ -0,0 +1,60 @@
"""phase6 hitl: job_corrections + approval columns
Revision ID: 3b1f2c9a4d56
Revises: ff8c14fbf8a0
Create Date: 2026-04-25 14:30:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "3b1f2c9a4d56"
down_revision: str | None = "ff8c14fbf8a0"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
    """Add the HITL approval columns to ``jobs`` and create ``job_corrections``.

    ``approved`` is NOT NULL, so it carries a server-side default of false
    for rows that already exist. The audit table references ``jobs.job_id``
    with ``ON DELETE CASCADE`` and gets a non-unique index on ``job_id``.
    """
    approval_columns = (
        sa.Column("approved", sa.Boolean(), nullable=False, server_default=sa.false()),
        sa.Column("reviewed_by", sa.String(length=128), nullable=True),
        sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True),
    )
    with op.batch_alter_table("jobs") as batch:
        for column in approval_columns:
            batch.add_column(column)

    audit_table = "job_corrections"
    op.create_table(
        audit_table,
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("job_id", sa.Uuid(), nullable=False),
        sa.Column("field_path", sa.String(length=256), nullable=False),
        sa.Column("old_value", sa.JSON(), nullable=True),
        sa.Column("new_value", sa.JSON(), nullable=True),
        sa.Column("corrected_by", sa.String(length=128), nullable=True),
        sa.Column("reason", sa.String(length=512), nullable=True),
        sa.Column("corrected_at", sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(["job_id"], ["jobs.job_id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        op.f("ix_job_corrections_job_id"),
        audit_table,
        ["job_id"],
        unique=False,
    )
def downgrade() -> None:
    """Undo ``upgrade``: drop the audit index and table, then the approval columns."""
    audit_table = "job_corrections"
    op.drop_index(op.f("ix_job_corrections_job_id"), table_name=audit_table)
    op.drop_table(audit_table)

    # Columns are removed in the reverse of the order they were added.
    with op.batch_alter_table("jobs") as batch:
        for column_name in ("reviewed_at", "reviewed_by", "approved"):
            batch.drop_column(column_name)

View File

@@ -1,17 +1,17 @@
"""phase4 jobs table
Revision ID: ff8c14fbf8a0
Revises:
Revises:
Create Date: 2026-04-25 15:54:18.579147
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'ff8c14fbf8a0'
revision: str = "ff8c14fbf8a0"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@@ -19,24 +19,25 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('jobs',
sa.Column('job_id', sa.Uuid(), nullable=False),
sa.Column('status', sa.String(length=32), nullable=False),
sa.Column('source_kind', sa.String(length=16), nullable=False),
sa.Column('filename', sa.String(length=512), nullable=False),
sa.Column('blob_key', sa.String(length=512), nullable=True),
sa.Column('confidence', sa.Float(), nullable=True),
sa.Column('review_flags', sa.JSON(), nullable=False),
sa.Column('result', sa.JSON(), nullable=True),
sa.Column('error', sa.String(length=2048), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('job_id')
op.create_table(
"jobs",
sa.Column("job_id", sa.Uuid(), nullable=False),
sa.Column("status", sa.String(length=32), nullable=False),
sa.Column("source_kind", sa.String(length=16), nullable=False),
sa.Column("filename", sa.String(length=512), nullable=False),
sa.Column("blob_key", sa.String(length=512), nullable=True),
sa.Column("confidence", sa.Float(), nullable=True),
sa.Column("review_flags", sa.JSON(), nullable=False),
sa.Column("result", sa.JSON(), nullable=True),
sa.Column("error", sa.String(length=2048), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("job_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('jobs')
op.drop_table("jobs")
# ### end Alembic commands ###

View File

@@ -22,7 +22,17 @@ from __future__ import annotations
from typing import Annotated
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
from fastapi import (
APIRouter,
Depends,
File,
Header,
HTTPException,
Query,
Response,
UploadFile,
status,
)
from sqlalchemy.orm import Session
from ocr_sprint.api.deps.auth import require_api_key
@@ -31,11 +41,22 @@ from ocr_sprint.api.errors import UnsupportedDocumentError
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
from ocr_sprint.config import get_settings
from ocr_sprint.db.base import session_scope
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
from ocr_sprint.db.repositories import (
InvalidFieldPathError,
JobAlreadyApprovedError,
JobNotCompletedError,
JobNotFoundError,
JobRepository,
)
from ocr_sprint.pipeline.ingest import detect_source_kind
from ocr_sprint.pipeline.orchestrator import run_pipeline
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
from ocr_sprint.schemas.extraction import ExtractionResult
from ocr_sprint.schemas.review import (
ApprovalResponse,
CorrectionEventResponse,
CorrectionRequest,
)
from ocr_sprint.storage.blob import get_blob_storage
from ocr_sprint.utils.logging import get_logger
@@ -75,6 +96,9 @@ def _row_to_response(row: object) -> DocumentResponse:
data=result_obj,
review_flags=list(row.review_flags or []),
error=row.error,
approved=bool(row.approved),
reviewed_by=row.reviewed_by,
reviewed_at=row.reviewed_at,
)
@@ -192,3 +216,116 @@ async def get_document(
except JobNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return _row_to_response(row)
# ---------- Phase 6 — HITL ----------
def _correction_row_to_response(row: object) -> CorrectionEventResponse:
    """Map one ``JobCorrectionRow`` onto the ``CorrectionEventResponse`` schema.

    Every response field is copied from the attribute of the same name on
    the row.
    """
    # Imported lazily: a module-level import would be cyclic at load time.
    from ocr_sprint.db.models import JobCorrectionRow

    assert isinstance(row, JobCorrectionRow)
    copied_fields = (
        "id",
        "job_id",
        "field_path",
        "old_value",
        "new_value",
        "corrected_by",
        "reason",
        "corrected_at",
    )
    payload = {field: getattr(row, field) for field in copied_fields}
    return CorrectionEventResponse(**payload)
@router.patch(
    "/{job_id}",
    response_model=DocumentResponse,
)
async def patch_document(
    job_id: UUID,
    body: CorrectionRequest,
    session: Annotated[Session, Depends(get_session)],
    x_user_id: Annotated[
        str | None,
        Header(
            # The value is stored in the 128-character ``corrected_by``
            # column, so longer identifiers are rejected at the edge instead
            # of failing (or being stored oversized) at the database.
            max_length=128,
            description="Free-form reviewer identifier recorded on the audit row.",
        ),
    ] = None,
) -> DocumentResponse:
    """Apply one or more field-level corrections and record an audit trail.

    The whole batch is applied atomically: the repository validates every
    path against a working copy first, so an invalid path fails the request
    before any change is staged. Returns the updated document so the client
    doesn't need a follow-up GET.

    Responses
    ---------
    200
        Corrections applied; body is the refreshed document.
    400
        A correction path is malformed, disallowed, or cannot be resolved.
    404
        No job exists with ``job_id``.
    409
        The job is already approved, or it has no result to correct yet.
    """
    repo = JobRepository(session)
    try:
        repo.apply_corrections(
            job_id,
            corrections=[(c.path, c.value, c.reason) for c in body.corrections],
            corrected_by=x_user_id,
        )
        row = repo.get_or_raise(job_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
    except InvalidFieldPathError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
    except (JobAlreadyApprovedError, JobNotCompletedError) as exc:
        # Both mean "the job is in the wrong state for editing".
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
    _logger.info(
        "documents.patched",
        job_id=str(job_id),
        count=len(body.corrections),
        corrected_by=x_user_id or "",
    )
    return _row_to_response(row)
@router.get(
    "/{job_id}/history",
    response_model=list[CorrectionEventResponse],
)
async def get_history(
    job_id: UUID,
    session: Annotated[Session, Depends(get_session)],
) -> list[CorrectionEventResponse]:
    """Return the correction audit trail for ``job_id``.

    Responds with 404 when the job itself does not exist; a job with no
    corrections yields an empty list.
    """
    repo = JobRepository(session)
    try:
        audit_rows = repo.list_corrections(job_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc

    events: list[CorrectionEventResponse] = []
    for audit_row in audit_rows:
        events.append(_correction_row_to_response(audit_row))
    return events
@router.post(
    "/{job_id}/approve",
    response_model=ApprovalResponse,
)
async def approve_document(
    job_id: UUID,
    session: Annotated[Session, Depends(get_session)],
    x_user_id: Annotated[
        str | None,
        Header(
            # The value is stored in the 128-character ``reviewed_by``
            # column, so longer identifiers are rejected at the edge instead
            # of failing (or being stored oversized) at the database.
            max_length=128,
            description="Free-form reviewer identifier recorded on the job.",
        ),
    ] = None,
) -> ApprovalResponse:
    """Lock a job's final version.

    Idempotent: re-approving returns the existing row without overwriting
    ``reviewed_by`` or ``reviewed_at``.

    Responses
    ---------
    200
        The job is approved (newly or already); body echoes the review state.
    404
        No job exists with ``job_id``.
    409
        The job has no result to approve yet.
    """
    repo = JobRepository(session)
    try:
        row = repo.approve(job_id, reviewed_by=x_user_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
    except JobNotCompletedError as exc:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
    # Log the reviewer stored on the row, which on a repeat approval is the
    # original approver rather than the current caller.
    _logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "")
    return ApprovalResponse(
        job_id=row.job_id,
        approved=bool(row.approved),
        reviewed_by=row.reviewed_by,
        reviewed_at=row.reviewed_at,
    )

View File

@@ -16,7 +16,7 @@ from datetime import datetime, timezone
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy import JSON, DateTime, Float, String, Uuid
from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, Integer, String, Uuid
from sqlalchemy.orm import Mapped, mapped_column
from ocr_sprint.db.base import Base
@@ -42,6 +42,15 @@ class JobRow(Base):
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
# Phase 6 — HITL review state.
# Once ``approved=True`` the result is locked: further corrections are
# rejected by the repository and surfaced by the route as HTTP 409 (no
# admin override is implemented in this phase). ``reviewed_by`` stores
# the free-form user identifier the reviewer sent via the
# ``X-User-Id`` header (best-effort attribution — no full RBAC yet).
approved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
reviewed_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=_utcnow
)
@@ -51,3 +60,37 @@ class JobRow(Base):
def __repr__(self) -> str:
    """Return a short debug representation showing the job id and status."""
    return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"
class JobCorrectionRow(Base):
    """One correction event on a job's ``result``.

    The repository writes one row per correction entry in a PATCH batch so
    we have a full audit trail. Rows are intended to be append-only — never
    updated, never deleted by application code — so the history is
    reproducible and usable as ground-truth data for future fine-tuning.
    Note that the foreign key cascades: deleting the parent job deletes its
    correction rows too.
    """

    __tablename__ = "job_corrections"
    # Surrogate key; also used as the tie-breaker when ordering history.
    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Owning job; indexed because history is always queried per job.
    job_id: Mapped[UUID] = mapped_column(
        Uuid, ForeignKey("jobs.job_id", ondelete="CASCADE"), nullable=False, index=True
    )
    # Dotted JSON path into ExtractionResult, e.g. "header.nomor_sprint" or
    # "personel[3].nrp". Stored verbatim as a plain string; the repository
    # parses and validates it (segment syntax, allow-listed root) before the
    # correction is applied.
    field_path: Mapped[str] = mapped_column(String(256), nullable=False)
    # Value found at ``field_path`` before the correction (JSON, may be null).
    old_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
    # Value written by the correction (JSON, may be null).
    new_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
    # Free-form reviewer identifier; null when the caller sent none.
    corrected_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
    # Optional free-form justification supplied with the correction.
    reason: Mapped[str | None] = mapped_column(String(512), nullable=True)
    # Timestamp assigned by the ``_utcnow`` column default when not set explicitly.
    corrected_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), nullable=False, default=_utcnow
    )

    def __repr__(self) -> str:
        """Return a short debug representation: job id, field path, reviewer."""
        return (
            f"JobCorrectionRow(job_id={self.job_id!s}, "
            f"field={self.field_path!r}, by={self.corrected_by!r})"
        )

View File

@@ -6,6 +6,9 @@ know about sessions, transactions, or the row → schema mapping.
from __future__ import annotations
import copy
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
@@ -13,7 +16,7 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from ocr_sprint.db.models import JobRow
from ocr_sprint.db.models import JobCorrectionRow, JobRow
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
@@ -25,6 +28,110 @@ class JobNotFoundError(LookupError):
"""Raised by API code when GET /documents/{id} hits a missing row."""
class InvalidFieldPathError(ValueError):
    """Raised when a correction's field path cannot be used.

    Covers malformed paths, roots outside the allow-list, and paths that
    cannot be resolved against the stored result (missing parent, index out
    of range). The PATCH route maps it to HTTP 400.
    """
class JobAlreadyApprovedError(RuntimeError):
    """Raised when a PATCH is attempted against an already-approved job.

    The PATCH route maps it to HTTP 409.
    """
class JobNotCompletedError(RuntimeError):
    """Raised when a PATCH/approve is attempted against a job that hasn't
    produced a ``result`` payload yet (e.g. still pending or failed).

    The repository raises it when the row's ``result`` is ``None``; the
    routes map it to HTTP 409.
    """
@dataclass(frozen=True)
class AppliedCorrection:
    """Internal record of a correction that successfully applied to the
    in-memory ``result`` dict. The repository turns this into a persisted
    ``JobCorrectionRow`` after the whole batch is validated.
    """

    # Path exactly as supplied by the caller.
    field_path: str
    # Value that was at ``field_path`` before the correction.
    old_value: Any
    # Value supplied by the caller for ``field_path``.
    new_value: Any
    # Optional free-form justification supplied with the correction.
    reason: str | None
# Allow-list of top-level keys we let reviewers edit. Keeps the attack
# surface small: corrections can only touch these roots. (Keys beneath an
# allowed root are not checked — a correction may create a new key there.)
_ALLOWED_ROOTS: frozenset[str] = frozenset({"header", "ttd", "personel", "untuk"})
# Matches a single path segment like ``personel[3]`` — supports at most one
# index per segment, enough for the list fields we care about.
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
def _split_path(path: str) -> list[tuple[str, int | None]]:
    """Parse ``header.nomor_sprint`` or ``personel[2].nrp`` into segments.

    Returns a list of ``(name, index_or_none)`` tuples, one per dot-separated
    segment.

    Raises
    ------
    InvalidFieldPathError
        If the path is empty, starts or ends with a dot, contains a segment
        that is not ``name`` or ``name[digits]``, or its first segment is not
        an allow-listed root. The caller can surface this as a 400.
    """
    if not path or path.startswith(".") or path.endswith("."):
        raise InvalidFieldPathError(f"Invalid field path: {path!r}")

    segments: list[tuple[str, int | None]] = []
    for part in path.split("."):
        # ``fullmatch`` rather than ``match``: the pattern's ``$`` also matches
        # just before a trailing newline, so ``match`` would accept
        # ``"perihal\n"`` and the unnormalised path would reach the audit log.
        match = _SEGMENT_RE.fullmatch(part)
        if match is None:
            raise InvalidFieldPathError(f"Invalid segment in path: {part!r}")
        name, idx_raw = match.groups()
        segments.append((name, None if idx_raw is None else int(idx_raw)))

    root = segments[0][0]
    if root not in _ALLOWED_ROOTS:
        raise InvalidFieldPathError(
            f"Field path root {root!r} not in allowed roots {sorted(_ALLOWED_ROOTS)!r}"
        )
    return segments
def _apply_path(data: dict[str, Any], path: str, new_value: Any) -> Any:
    """Apply a single correction to ``data`` in place and return the old value.

    The value written into ``data`` is a deep copy of ``new_value``. Without
    the copy, a later correction in the same batch that edits *inside* a
    previously written container (e.g. ``personel[0]`` followed by
    ``personel[0].nrp``) would also mutate the caller's ``new_value`` object,
    corrupting the audit record of the earlier correction.

    Does NOT validate that the new value matches the field's expected
    type — that's the reviewer's responsibility; the whole point of HITL
    is to let humans override the model's typing.

    A non-indexed terminal key that is absent from its parent object is
    created, and ``None`` is returned as its old value.

    Raises
    ------
    InvalidFieldPathError
        If the path is malformed or disallowed, an intermediate key is
        missing, an index is out of range, or the terminal parent is not an
        object.
    """
    *parents, (leaf_name, leaf_idx) = _split_path(path)

    # Walk down to the container that holds the terminal segment.
    cursor: Any = data
    for name, idx in parents:
        if not isinstance(cursor, dict) or name not in cursor:
            raise InvalidFieldPathError(f"Cannot traverse to {path!r}: missing {name!r}")
        cursor = cursor[name]
        if idx is not None:
            if not isinstance(cursor, list) or idx >= len(cursor):
                raise InvalidFieldPathError(
                    f"Cannot traverse to {path!r}: index [{idx}] out of range"
                )
            cursor = cursor[idx]

    if leaf_idx is not None:
        # Terminal segment is a list-element, e.g. ``untuk[2]``.
        if not isinstance(cursor, dict) or leaf_name not in cursor:
            raise InvalidFieldPathError(
                f"Cannot apply to {path!r}: missing container {leaf_name!r}"
            )
        container = cursor[leaf_name]
        if not isinstance(container, list) or leaf_idx >= len(container):
            raise InvalidFieldPathError(
                f"Cannot apply to {path!r}: index [{leaf_idx}] out of range"
            )
        old = container[leaf_idx]
        container[leaf_idx] = copy.deepcopy(new_value)
        return old

    if not isinstance(cursor, dict):
        raise InvalidFieldPathError(f"Cannot apply to {path!r}: parent is not an object")
    old = cursor.get(leaf_name)
    cursor[leaf_name] = copy.deepcopy(new_value)
    return old
class JobRepository:
"""SQL-backed repository for `jobs` rows."""
@@ -94,3 +201,139 @@ class JobRepository:
if row is None:
raise JobNotFoundError(f"Job not found: {job_id}")
return row
# ---------- Phase 6 — HITL ----------
def apply_corrections(
    self,
    job_id: UUID,
    *,
    corrections: list[tuple[str, Any, str | None]],
    corrected_by: str | None,
) -> list[JobCorrectionRow]:
    """Apply a batch of field corrections to a job, all-or-nothing.

    Each entry of ``corrections`` is a ``(path, new_value, reason)`` tuple.
    Every path is applied to a private copy of the stored result first; the
    row and the session are only touched once the whole batch has applied
    without error. One audit row is created per entry and the list of those
    rows is returned.

    Raises
    ------
    JobNotFoundError
        If the row doesn't exist.
    JobNotCompletedError
        If the job's ``result`` is ``None``.
    JobAlreadyApprovedError
        If the job has been approved — edits are locked.
    InvalidFieldPathError
        If any path is malformed, disallowed, or cannot be resolved.
    """
    job = self._get_or_raise(job_id)
    if job.result is None:
        raise JobNotCompletedError(
            f"Job {job_id} has no result to correct (status={job.status})"
        )
    if job.approved:
        raise JobAlreadyApprovedError(f"Job {job_id} is already approved; edits are locked")

    # All edits land on a scratch copy; a failure part-way through the batch
    # therefore leaves the stored result untouched.
    draft = copy.deepcopy(job.result)
    applied = [
        AppliedCorrection(
            field_path=path,
            old_value=_apply_path(draft, path, value),
            new_value=value,
            reason=reason,
        )
        for path, value, reason in corrections
    ]

    # Stage one audit row per applied correction. They are flushed together
    # with the result update below, so both succeed or fail as a unit.
    audit_rows: list[JobCorrectionRow] = []
    for event in applied:
        audit_row = JobCorrectionRow(
            job_id=job_id,
            field_path=event.field_path,
            old_value=event.old_value,
            new_value=event.new_value,
            corrected_by=corrected_by,
            reason=event.reason,
        )
        self.session.add(audit_row)
        audit_rows.append(audit_row)

    # Swap in the corrected result and let the flag policy re-evaluate the
    # review flags against it.
    job.result = draft
    job.review_flags = _recompute_flags(
        original_flags=list(job.review_flags or []),
        applied=applied,
        working_result=draft,
    )
    job.updated_at = _utcnow()
    self.session.flush()
    return audit_rows
def list_corrections(self, job_id: UUID) -> list[JobCorrectionRow]:
    """Return every correction recorded for ``job_id``, oldest first.

    Rows are ordered by ``corrected_at`` with ``id`` as the tie-breaker.
    Raises ``JobNotFoundError`` when the job itself does not exist, so a
    missing job is distinguishable from a job with no corrections.
    """
    # Existence check first; the return value is not needed.
    self._get_or_raise(job_id)

    belongs_to_job = JobCorrectionRow.job_id == job_id
    query = select(JobCorrectionRow).where(belongs_to_job)
    query = query.order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
    return list(self.session.scalars(query))
def approve(self, job_id: UUID, *, reviewed_by: str | None) -> JobRow:
    """Mark a job as approved and return its row.

    Approving an already-approved job changes nothing: the original
    ``reviewed_by`` and ``reviewed_at`` are kept. Raises
    ``JobNotFoundError`` for an unknown job and ``JobNotCompletedError``
    when the job has no ``result``.
    """
    job = self._get_or_raise(job_id)
    if job.result is None:
        raise JobNotCompletedError(
            f"Job {job_id} has no result to approve (status={job.status})"
        )

    if not job.approved:
        approved_at = _utcnow()
        job.approved = True
        job.reviewed_by = reviewed_by
        # The approval moment doubles as the row's last-modified time.
        job.reviewed_at = approved_at
        job.updated_at = approved_at
    return job
def _recompute_flags(
*,
original_flags: list[str],
applied: list[AppliedCorrection],
working_result: dict[str, Any],
) -> list[str]:
"""Update review flags in light of the corrections just applied.
Keeps the policy simple on purpose:
* ``missing_field`` is removed if after the edit every required
header field is non-empty.
* Other flags stay untouched — the reviewer should either correct the
underlying issue (which this helper can detect) or explicitly
approve the result as-is (which bypasses the flag list).
"""
flags = list(original_flags)
if "missing_field" in flags:
header = working_result.get("header") or {}
filled = all(bool(header.get(key)) for key in ("nomor_sprint", "satuan_penerbit"))
if filled:
flags = [f for f in flags if f != "missing_field"]
# ``applied`` isn't used directly in this MVP rule, but we keep the
# parameter so future policies can inspect exactly what changed
# without re-diffing the blob.
_ = applied
return flags

View File

@@ -55,3 +55,7 @@ class DocumentResponse(BaseModel):
data: ExtractionResult | None = None
review_flags: list[str] = Field(default_factory=list)
error: str | None = None
# Phase 6 — HITL review state.
approved: bool = False
reviewed_by: str | None = None
reviewed_at: datetime | None = None

View File

@@ -0,0 +1,62 @@
"""Request / response schemas for the HITL review endpoints (Phase 6).
The API surface is deliberately small:
* ``CorrectionRequest`` — body of ``PATCH /documents/{id}``. A list of
``FieldCorrection`` entries; the batch as a whole is applied atomically
(all-or-nothing) and every entry is recorded in the audit trail.
* ``CorrectionEventResponse`` — single row in ``GET /documents/{id}/history``.
* ``ApprovalResponse`` — echo back after ``POST /documents/{id}/approve``.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field
class FieldCorrection(BaseModel):
    """One field-level correction.

    ``path`` is a dotted JSON path into ``ExtractionResult``; list elements
    are addressed with a 0-based index, e.g. ``personel[0].nrp``. The first
    segment must be one of the supported roots: ``header``, ``ttd``,
    ``personel``, ``untuk``. The path is validated by the repository before
    being applied; malformed paths and unknown roots return 400.

    ``path`` and ``reason`` are capped at the sizes of the audit-table
    columns they are stored in (256 and 512 characters respectively), so an
    oversized value is rejected during request validation rather than at the
    database.
    """

    path: str = Field(
        ..., max_length=256, description="Dotted JSON path, e.g. 'header.nomor_sprint'."
    )
    value: Any = Field(..., description="New value (any JSON-serialisable payload).")
    reason: str | None = Field(
        None, max_length=512, description="Optional free-form reason for the correction."
    )
class CorrectionRequest(BaseModel):
    """PATCH body — one or more field corrections, applied atomically.

    An empty ``corrections`` list fails validation (``min_length=1``).
    """

    # The batch of corrections; at least one entry is required.
    corrections: list[FieldCorrection] = Field(..., min_length=1)
class CorrectionEventResponse(BaseModel):
    """One row of the audit log surfaced by GET /history."""

    # Identifier of the audit row.
    id: int
    # Job the correction belongs to.
    job_id: UUID
    # Path that was corrected, as recorded on the audit row.
    field_path: str
    # Value before the correction (may be null).
    old_value: Any | None = None
    # Value written by the correction (may be null).
    new_value: Any | None = None
    # Reviewer identifier, if one was recorded.
    corrected_by: str | None = None
    # Free-form justification, if one was given.
    reason: str | None = None
    # When the correction was recorded.
    corrected_at: datetime
class ApprovalResponse(BaseModel):
    """Echo returned after a job is approved."""

    # Job that was approved.
    job_id: UUID
    # Approval state stored on the job.
    approved: bool
    # Reviewer identifier stored on the job, if any.
    reviewed_by: str | None = None
    # Approval timestamp stored on the job, if any.
    reviewed_at: datetime | None = None

248
tests/unit/test_api_hitl.py Normal file
View File

@@ -0,0 +1,248 @@
"""End-to-end HTTP tests for the HITL endpoints.
We re-use the ``fake_pipeline`` style from ``test_api.py`` so we don't pay
the PaddleOCR init cost; the orchestrator is monkey-patched to return a
synthetic ``ExtractionResult``.
"""
from __future__ import annotations
from datetime import date
import pytest
from fastapi.testclient import TestClient
from ocr_sprint.main import create_app
from ocr_sprint.pipeline import orchestrator as orch_module
from ocr_sprint.pipeline.orchestrator import PipelineOutput
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
from ocr_sprint.schemas.extraction import (
ExtractionResult,
HeaderFields,
PersonnelEntry,
ReviewFlag,
)
@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` bound to a freshly created application."""
    app = create_app()
    return TestClient(app)
@pytest.fixture
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
    """Replace ``run_pipeline`` with a stub returning a canned output.

    The canned result has one personnel entry, a ``missing_field`` review
    flag, and ``perihal`` left empty so tests have a gap to correct. The stub
    is installed on the orchestrator module, the documents route module and
    the worker tasks module.
    """
    header = HeaderFields(
        nomor_sprint="Sprin/1/I/2025",
        tanggal=date(2025, 1, 1),
        satuan_penerbit="POLRES TEST",
        perihal=None,  # deliberately empty so a PATCH can fill it
    )
    officer = PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA")
    extraction = ExtractionResult(
        header=header,
        personel=[officer],
        review_flags=[ReviewFlag.MISSING_FIELD],
        confidence=0.7,
    )
    canned = PipelineOutput(
        source_kind=SourceKind.PDF,
        status=DocumentStatus.NEEDS_REVIEW,
        confidence=0.7,
        result=extraction,
    )

    def _stub(_content: bytes) -> PipelineOutput:
        return canned

    target = "run_pipeline"
    monkeypatch.setattr(orch_module, target, _stub)

    from ocr_sprint.api.routes import documents as docs_module

    monkeypatch.setattr(docs_module, target, _stub)

    from ocr_sprint.worker import tasks as tasks_module

    monkeypatch.setattr(tasks_module, target, _stub)
    return canned
def _create_job(client: TestClient) -> str:
    """Upload a dummy PDF synchronously and return the new job's id.

    Asserts that the upload succeeded and the job ended in ``needs_review``.
    """
    upload = {"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}
    response = client.post("/api/v1/documents?sync=true", files=upload)
    assert response.status_code == 200, response.text
    created = response.json()
    assert created["status"] == "needs_review"
    return str(created["job_id"])
def test_patch_applies_correction_and_clears_missing_field(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """A valid PATCH updates the field and drops the ``missing_field`` flag."""
    job_id = _create_job(client)
    new_perihal = "Penyelidikan kasus X"
    correction = {
        "path": "header.perihal",
        "value": new_perihal,
        "reason": "LLM missed it",
    }

    response = client.patch(
        f"/api/v1/documents/{job_id}",
        json={"corrections": [correction]},
        headers={"X-User-Id": "reviewer-a"},
    )

    assert response.status_code == 200, response.text
    document = response.json()
    assert document["data"]["header"]["perihal"] == new_perihal
    # Both required header fields are already present in the fake pipeline
    # output, and the flag policy is re-evaluated on every edit, so the flag
    # disappears as soon as any correction lands.
    assert "missing_field" not in document["review_flags"]
def test_patch_returns_400_for_unknown_path(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """A path whose root is not allow-listed is rejected with 400."""
    job_id = _create_job(client)
    payload = {"corrections": [{"path": "bogus.field", "value": "x"}]}
    response = client.patch(f"/api/v1/documents/{job_id}", json=payload)
    assert response.status_code == 400
def test_patch_is_atomic_on_partial_failure(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """One bad entry in a batch prevents the good entry from persisting."""
    document_url = f"/api/v1/documents/{_create_job(client)}"
    batch = [
        {"path": "header.perihal", "value": "OK"},
        {"path": "bogus.root", "value": "X"},
    ]

    response = client.patch(document_url, json={"corrections": batch})
    assert response.status_code == 400

    # The valid first entry must have been discarded along with the batch.
    stored = client.get(document_url).json()
    assert stored["data"]["header"]["perihal"] is None
def test_history_returns_corrections_in_order(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """History lists corrections oldest-first with reviewer and old value."""
    document_url = f"/api/v1/documents/{_create_job(client)}"
    edits = (("first", "reviewer-a"), ("second", "reviewer-b"))
    for value, reviewer in edits:
        client.patch(
            document_url,
            json={"corrections": [{"path": "header.perihal", "value": value}]},
            headers={"X-User-Id": reviewer},
        )

    response = client.get(f"{document_url}/history")
    assert response.status_code == 200
    events = response.json()
    assert [event["new_value"] for event in events] == ["first", "second"]
    assert [event["corrected_by"] for event in events] == ["reviewer-a", "reviewer-b"]
    # The second event's old value is whatever the first edit wrote.
    assert events[1]["old_value"] == "first"
def test_history_returns_empty_list_for_untouched_job(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """A job that was never corrected has an empty history."""
    job_id = _create_job(client)
    response = client.get(f"/api/v1/documents/{job_id}/history")
    assert response.status_code == 200
    assert response.json() == []
def test_history_returns_404_for_unknown_job(client: TestClient) -> None:
    """Asking for the history of a non-existent job yields 404."""
    unknown_id = "00000000-0000-0000-0000-000000000000"
    response = client.get(f"/api/v1/documents/{unknown_id}/history")
    assert response.status_code == 404
def test_approve_locks_subsequent_patches(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """Approval is reported by the API and blocks later corrections."""
    document_url = f"/api/v1/documents/{_create_job(client)}"

    approval = client.post(
        f"{document_url}/approve",
        headers={"X-User-Id": "reviewer-a"},
    )
    assert approval.status_code == 200, approval.text
    review = approval.json()
    assert review["approved"] is True
    assert review["reviewed_by"] == "reviewer-a"
    assert review["reviewed_at"]  # a timestamp was recorded

    # The document itself now reports the approval.
    assert client.get(document_url).json()["approved"] is True

    # Editing an approved document is a conflict.
    late_edit = {"corrections": [{"path": "header.perihal", "value": "X"}]}
    rejected = client.patch(document_url, json=late_edit)
    assert rejected.status_code == 409
def test_approve_is_idempotent(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """A second approval succeeds but keeps the first reviewer and timestamp."""
    approve_url = f"/api/v1/documents/{_create_job(client)}/approve"

    def _naive_timestamp(payload: dict) -> str:
        # SQLite drops tzinfo on roundtrip, which changes how Pydantic
        # serialises the value between calls; compare without the offset.
        return payload["reviewed_at"].rstrip("Z").split("+")[0]

    first = client.post(approve_url, headers={"X-User-Id": "reviewer-a"})
    second = client.post(approve_url, headers={"X-User-Id": "reviewer-b"})

    assert first.status_code == 200
    assert second.status_code == 200
    # The repeat call must not take over the attribution.
    assert second.json()["reviewed_by"] == "reviewer-a"
    assert _naive_timestamp(second.json()) == _naive_timestamp(first.json())
def test_patch_requires_at_least_one_correction(
    client: TestClient,
    fake_pipeline: PipelineOutput,
) -> None:
    """An empty ``corrections`` list fails request validation."""
    job_id = _create_job(client)
    empty_batch = {"corrections": []}
    response = client.patch(f"/api/v1/documents/{job_id}", json=empty_batch)
    # The schema's ``min_length=1`` constraint turns this into a 422.
    assert response.status_code == 422
def test_patch_missing_job_returns_404(client: TestClient) -> None:
    """Correcting a non-existent job yields 404."""
    unknown_id = "00000000-0000-0000-0000-000000000000"
    payload = {"corrections": [{"path": "header.perihal", "value": "X"}]}
    response = client.patch(f"/api/v1/documents/{unknown_id}", json=payload)
    assert response.status_code == 404

238
tests/unit/test_db_hitl.py Normal file
View File

@@ -0,0 +1,238 @@
"""Repository tests for Phase 6 HITL helpers."""
from __future__ import annotations
from uuid import UUID, uuid4
import pytest
from ocr_sprint.db.base import Base, get_engine, session_scope
from ocr_sprint.db.repositories import (
InvalidFieldPathError,
JobAlreadyApprovedError,
JobNotCompletedError,
JobNotFoundError,
JobRepository,
)
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
@pytest.fixture
def db_ready() -> None:
    """Create every table declared on ``Base`` using the configured engine."""
    engine = get_engine()
    Base.metadata.create_all(bind=engine)
def _seed_completed_job(
    *,
    result: dict[str, object] | None = None,
    flags: list[str] | None = None,
) -> UUID:
    """Create a job, mark it completed (``needs_review``) and return its id.

    ``result`` defaults to a small payload with a header (``perihal`` left
    empty), one personnel entry and one ``untuk`` item; ``flags`` defaults to
    no review flags. Creation and completion run in separate sessions.
    """
    default_result: dict[str, object] = {
        "header": {
            "nomor_sprint": "Sprin/1/I/2025",
            "satuan_penerbit": "POLRES X",
            "perihal": None,
        },
        "personel": [
            {"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"},
        ],
        "untuk": ["Melaksanakan tugas"],
    }
    jid = uuid4()
    with session_scope() as session:
        repo = JobRepository(session)
        repo.create(
            job_id=jid,
            filename="x.pdf",
            source_kind=SourceKind.PDF,
            blob_key="k",
        )
    with session_scope() as session:
        JobRepository(session).mark_completed(
            jid,
            status=DocumentStatus.NEEDS_REVIEW,
            confidence=0.7,
            result=result or default_result,
            review_flags=flags or [],
        )
    return jid
def test_apply_corrections_updates_nested_header_field(db_ready: None) -> None:
    """A correction to ``header.perihal`` is visible on the stored result."""
    jid = _seed_completed_job()
    expected = "Penyelidikan kasus X"
    with session_scope() as session:
        repo = JobRepository(session)
        repo.apply_corrections(
            jid,
            corrections=[("header.perihal", expected, "regex miss")],
            corrected_by="reviewer-a",
        )
        stored = repo.get_or_raise(jid).result
        assert stored is not None
        assert stored["header"]["perihal"] == expected
def test_apply_corrections_writes_audit_row(db_ready: None) -> None:
    """One applied correction yields exactly one audit entry with old/new values."""
    job_id = _seed_completed_job()

    with session_scope() as write_session:
        JobRepository(write_session).apply_corrections(
            job_id,
            corrections=[("header.perihal", "Penyelidikan", None)],
            corrected_by="reviewer-a",
        )

    # Read the audit trail back in a separate session scope.
    with session_scope() as read_session:
        audit = JobRepository(read_session).list_corrections(job_id)
        assert len(audit) == 1
        entry = audit[0]
        assert entry.field_path == "header.perihal"
        # The default seed stores ``perihal`` as None, hence the old value.
        assert entry.old_value is None
        assert entry.new_value == "Penyelidikan"
        assert entry.corrected_by == "reviewer-a"
def test_apply_corrections_supports_list_index(db_ready: None) -> None:
    """A ``personel[0].nrp`` path updates the first entry of the ``personel`` list."""
    job_id = _seed_completed_job()
    indexed_edit = [("personel[0].nrp", "77060001", None)]

    with session_scope() as session:
        JobRepository(session).apply_corrections(
            job_id,
            corrections=indexed_edit,
            corrected_by=None,
        )
        job = JobRepository(session).get_or_raise(job_id)
        assert job.result is not None
        first_person = job.result["personel"][0]
        assert first_person["nrp"] == "77060001"
def test_apply_corrections_is_atomic_on_invalid_path(db_ready: None) -> None:
    """If a later correction in a batch is invalid, earlier ones must not persist."""
    job_id = _seed_completed_job()
    batch = [
        ("header.perihal", "OK", None),
        ("bogus.root", "X", None),
    ]

    # ``pytest.raises`` is the inner context manager, so the error is absorbed
    # before ``session_scope`` exits and the session scope itself ends normally.
    with session_scope() as session, pytest.raises(InvalidFieldPathError):
        JobRepository(session).apply_corrections(
            job_id,
            corrections=batch,
            corrected_by=None,
        )

    # A fresh session must still see the seeded (empty) ``perihal``.
    with session_scope() as session:
        job = JobRepository(session).get_or_raise(job_id)
        assert job.result is not None
        header = job.result["header"]
        assert header.get("perihal") is None
def test_apply_corrections_rejects_out_of_range_index(db_ready: None) -> None:
    """A ``personel[99]`` path on the default seed raises ``InvalidFieldPathError``."""
    job_id = _seed_completed_job()
    beyond_end = [("personel[99].nrp", "77060001", None)]

    with session_scope() as session, pytest.raises(InvalidFieldPathError):
        JobRepository(session).apply_corrections(
            job_id,
            corrections=beyond_end,
            corrected_by=None,
        )
def test_apply_corrections_rejects_after_approve(db_ready: None) -> None:
    """Correcting a job that was already approved raises ``JobAlreadyApprovedError``."""
    job_id = _seed_completed_job()

    # Approve first, in its own session scope.
    with session_scope() as session:
        JobRepository(session).approve(job_id, reviewed_by="reviewer-a")

    late_edit = [("header.perihal", "X", None)]
    with session_scope() as session, pytest.raises(JobAlreadyApprovedError):
        JobRepository(session).apply_corrections(
            job_id,
            corrections=late_edit,
            corrected_by="reviewer-a",
        )
def test_apply_corrections_rejects_missing_job(db_ready: None) -> None:
    """Correcting a freshly generated, never-inserted job id raises ``JobNotFoundError``."""
    unknown_id = uuid4()
    edit = [("header.perihal", "X", None)]

    with session_scope() as session, pytest.raises(JobNotFoundError):
        JobRepository(session).apply_corrections(
            unknown_id,
            corrections=edit,
            corrected_by=None,
        )
def test_apply_corrections_rejects_pending_job(db_ready: None) -> None:
    """Correcting a job that was created but never completed raises ``JobNotCompletedError``."""
    pending_id = uuid4()

    # Only ``create`` is called; the job is never marked completed.
    with session_scope() as session:
        JobRepository(session).create(
            job_id=pending_id,
            filename="x",
            source_kind=SourceKind.PDF,
            blob_key="k",
        )

    edit = [("header.perihal", "X", None)]
    with session_scope() as session, pytest.raises(JobNotCompletedError):
        JobRepository(session).apply_corrections(
            pending_id,
            corrections=edit,
            corrected_by=None,
        )
def test_missing_field_flag_cleared_when_header_gap_filled(db_ready: None) -> None:
    """Filling the empty ``nomor_sprint`` drops ``missing_field`` but keeps other flags."""
    incomplete_result = {
        "header": {
            "nomor_sprint": None,
            "satuan_penerbit": "POLRES X",
        }
    }
    job_id = _seed_completed_job(
        result=incomplete_result,
        flags=["missing_field", "low_ocr_confidence"],
    )

    with session_scope() as session:
        JobRepository(session).apply_corrections(
            job_id,
            corrections=[("header.nomor_sprint", "Sprin/2/I/2025", None)],
            corrected_by="reviewer-a",
        )
        job = JobRepository(session).get_or_raise(job_id)
        # Expected outcome: ``missing_field`` disappears once every required
        # header key is non-empty, while ``low_ocr_confidence`` survives
        # because a manual correction does not address that signal.
        remaining_flags = list(job.review_flags)
        assert remaining_flags == ["low_ocr_confidence"]
def test_approve_sets_timestamps_and_is_idempotent(db_ready: None) -> None:
    """A second ``approve`` keeps the first reviewer and the first review timestamp."""
    job_id = _seed_completed_job()

    with session_scope() as session:
        approved = JobRepository(session).approve(job_id, reviewed_by="reviewer-a")
        original_stamp = approved.reviewed_at
        assert original_stamp is not None

    with session_scope() as session:
        again = JobRepository(session).approve(job_id, reviewed_by="reviewer-b")
        # The repeat call by a different reviewer must leave both
        # ``reviewed_by`` and ``reviewed_at`` exactly as first recorded.
        assert again.approved is True
        assert again.reviewed_by == "reviewer-a"
        assert again.reviewed_at is not None
        # SQLite loses tzinfo on a round trip, so compare naive datetimes.
        stored_naive = again.reviewed_at.replace(tzinfo=None)
        original_naive = original_stamp.replace(tzinfo=None)
        assert stored_naive == original_naive
def test_approve_rejects_pending_job(db_ready: None) -> None:
    """Approving a job that was created but never completed raises ``JobNotCompletedError``."""
    pending_id = uuid4()

    # Only ``create`` is called; the job is never marked completed.
    with session_scope() as session:
        JobRepository(session).create(
            job_id=pending_id,
            filename="x",
            source_kind=SourceKind.PDF,
            blob_key="k",
        )

    with session_scope() as session, pytest.raises(JobNotCompletedError):
        JobRepository(session).approve(pending_id, reviewed_by="rev")
def test_history_returns_events_in_order(db_ready: None) -> None:
    """Corrections from two separate batches are listed in the order they were applied."""
    job_id = _seed_completed_job()
    first_batch = [("header.perihal", "one", None)]
    second_batch = [
        ("header.perihal", "two", None),
        ("personel[0].nama", "ANDI", None),
    ]

    # Each batch goes through its own session scope with a different reviewer.
    with session_scope() as session:
        JobRepository(session).apply_corrections(
            job_id,
            corrections=first_batch,
            corrected_by="r1",
        )
    with session_scope() as session:
        JobRepository(session).apply_corrections(
            job_id,
            corrections=second_batch,
            corrected_by="r2",
        )

    with session_scope() as session:
        history = JobRepository(session).list_corrections(job_id)
        assert [event.new_value for event in history] == ["one", "two", "ANDI"]
        assert [event.corrected_by for event in history] == ["r1", "r2", "r2"]