Merge pull request #4 from Adriankf59/devin/1777135879-phase-5-llm-hybrid
Phase 5: hybrid LLM extraction (Ollama) for header gaps
This commit is contained in:
60
alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py
Normal file
60
alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""phase6 hitl: job_corrections + approval columns
|
||||
|
||||
Revision ID: 3b1f2c9a4d56
|
||||
Revises: ff8c14fbf8a0
|
||||
Create Date: 2026-04-25 14:30:00.000000
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "3b1f2c9a4d56"
|
||||
down_revision: str | None = "ff8c14fbf8a0"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("jobs") as batch:
|
||||
batch.add_column(
|
||||
sa.Column(
|
||||
"approved",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
)
|
||||
batch.add_column(sa.Column("reviewed_by", sa.String(length=128), nullable=True))
|
||||
batch.add_column(sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True))
|
||||
|
||||
op.create_table(
|
||||
"job_corrections",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("job_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("field_path", sa.String(length=256), nullable=False),
|
||||
sa.Column("old_value", sa.JSON(), nullable=True),
|
||||
sa.Column("new_value", sa.JSON(), nullable=True),
|
||||
sa.Column("corrected_by", sa.String(length=128), nullable=True),
|
||||
sa.Column("reason", sa.String(length=512), nullable=True),
|
||||
sa.Column("corrected_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(["job_id"], ["jobs.job_id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_job_corrections_job_id"),
|
||||
"job_corrections",
|
||||
["job_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_job_corrections_job_id"), table_name="job_corrections")
|
||||
op.drop_table("job_corrections")
|
||||
with op.batch_alter_table("jobs") as batch:
|
||||
batch.drop_column("reviewed_at")
|
||||
batch.drop_column("reviewed_by")
|
||||
batch.drop_column("approved")
|
||||
@@ -1,17 +1,17 @@
|
||||
"""phase4 jobs table
|
||||
|
||||
Revision ID: ff8c14fbf8a0
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2026-04-25 15:54:18.579147
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'ff8c14fbf8a0'
|
||||
revision: str = "ff8c14fbf8a0"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
@@ -19,24 +19,25 @@ depends_on: str | Sequence[str] | None = None
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('jobs',
|
||||
sa.Column('job_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('status', sa.String(length=32), nullable=False),
|
||||
sa.Column('source_kind', sa.String(length=16), nullable=False),
|
||||
sa.Column('filename', sa.String(length=512), nullable=False),
|
||||
sa.Column('blob_key', sa.String(length=512), nullable=True),
|
||||
sa.Column('confidence', sa.Float(), nullable=True),
|
||||
sa.Column('review_flags', sa.JSON(), nullable=False),
|
||||
sa.Column('result', sa.JSON(), nullable=True),
|
||||
sa.Column('error', sa.String(length=2048), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint('job_id')
|
||||
op.create_table(
|
||||
"jobs",
|
||||
sa.Column("job_id", sa.Uuid(), nullable=False),
|
||||
sa.Column("status", sa.String(length=32), nullable=False),
|
||||
sa.Column("source_kind", sa.String(length=16), nullable=False),
|
||||
sa.Column("filename", sa.String(length=512), nullable=False),
|
||||
sa.Column("blob_key", sa.String(length=512), nullable=True),
|
||||
sa.Column("confidence", sa.Float(), nullable=True),
|
||||
sa.Column("review_flags", sa.JSON(), nullable=False),
|
||||
sa.Column("result", sa.JSON(), nullable=True),
|
||||
sa.Column("error", sa.String(length=2048), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.PrimaryKeyConstraint("job_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('jobs')
|
||||
op.drop_table("jobs")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -22,7 +22,17 @@ from __future__ import annotations
|
||||
from typing import Annotated
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
File,
|
||||
Header,
|
||||
HTTPException,
|
||||
Query,
|
||||
Response,
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.api.deps.auth import require_api_key
|
||||
@@ -31,11 +41,22 @@ from ocr_sprint.api.errors import UnsupportedDocumentError
|
||||
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.db.base import session_scope
|
||||
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
|
||||
from ocr_sprint.db.repositories import (
|
||||
InvalidFieldPathError,
|
||||
JobAlreadyApprovedError,
|
||||
JobNotCompletedError,
|
||||
JobNotFoundError,
|
||||
JobRepository,
|
||||
)
|
||||
from ocr_sprint.pipeline.ingest import detect_source_kind
|
||||
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
||||
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
|
||||
from ocr_sprint.schemas.extraction import ExtractionResult
|
||||
from ocr_sprint.schemas.review import (
|
||||
ApprovalResponse,
|
||||
CorrectionEventResponse,
|
||||
CorrectionRequest,
|
||||
)
|
||||
from ocr_sprint.storage.blob import get_blob_storage
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
@@ -75,6 +96,9 @@ def _row_to_response(row: object) -> DocumentResponse:
|
||||
data=result_obj,
|
||||
review_flags=list(row.review_flags or []),
|
||||
error=row.error,
|
||||
approved=bool(row.approved),
|
||||
reviewed_by=row.reviewed_by,
|
||||
reviewed_at=row.reviewed_at,
|
||||
)
|
||||
|
||||
|
||||
@@ -192,3 +216,116 @@ async def get_document(
|
||||
except JobNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
return _row_to_response(row)
|
||||
|
||||
|
||||
# ---------- Phase 6 — HITL ----------
|
||||
|
||||
|
||||
def _correction_row_to_response(row: object) -> CorrectionEventResponse:
|
||||
# Local import to avoid a cyclic import at module load time.
|
||||
from ocr_sprint.db.models import JobCorrectionRow
|
||||
|
||||
assert isinstance(row, JobCorrectionRow)
|
||||
return CorrectionEventResponse(
|
||||
id=row.id,
|
||||
job_id=row.job_id,
|
||||
field_path=row.field_path,
|
||||
old_value=row.old_value,
|
||||
new_value=row.new_value,
|
||||
corrected_by=row.corrected_by,
|
||||
reason=row.reason,
|
||||
corrected_at=row.corrected_at,
|
||||
)
|
||||
|
||||
|
||||
@router.patch(
|
||||
"/{job_id}",
|
||||
response_model=DocumentResponse,
|
||||
)
|
||||
async def patch_document(
|
||||
job_id: UUID,
|
||||
body: CorrectionRequest,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
x_user_id: Annotated[
|
||||
str | None,
|
||||
Header(description="Free-form reviewer identifier recorded on the audit row."),
|
||||
] = None,
|
||||
) -> DocumentResponse:
|
||||
"""Apply one or more field-level corrections and record an audit trail.
|
||||
|
||||
The whole batch is applied atomically — if any path is invalid the
|
||||
request fails with 400 and no side effects are written. Returns the
|
||||
updated document so the client doesn't need a follow-up GET.
|
||||
"""
|
||||
repo = JobRepository(session)
|
||||
try:
|
||||
repo.apply_corrections(
|
||||
job_id,
|
||||
corrections=[(c.path, c.value, c.reason) for c in body.corrections],
|
||||
corrected_by=x_user_id,
|
||||
)
|
||||
row = repo.get_or_raise(job_id)
|
||||
except JobNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
except InvalidFieldPathError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
except JobAlreadyApprovedError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||
except JobNotCompletedError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||
|
||||
_logger.info(
|
||||
"documents.patched",
|
||||
job_id=str(job_id),
|
||||
count=len(body.corrections),
|
||||
corrected_by=x_user_id or "",
|
||||
)
|
||||
return _row_to_response(row)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{job_id}/history",
|
||||
response_model=list[CorrectionEventResponse],
|
||||
)
|
||||
async def get_history(
|
||||
job_id: UUID,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> list[CorrectionEventResponse]:
|
||||
repo = JobRepository(session)
|
||||
try:
|
||||
rows = repo.list_corrections(job_id)
|
||||
except JobNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
return [_correction_row_to_response(r) for r in rows]
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{job_id}/approve",
|
||||
response_model=ApprovalResponse,
|
||||
)
|
||||
async def approve_document(
|
||||
job_id: UUID,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
x_user_id: Annotated[
|
||||
str | None,
|
||||
Header(description="Free-form reviewer identifier recorded on the job."),
|
||||
] = None,
|
||||
) -> ApprovalResponse:
|
||||
"""Lock a job's final version. Idempotent: re-approving returns the
|
||||
existing row without overwriting ``reviewed_at``.
|
||||
"""
|
||||
repo = JobRepository(session)
|
||||
try:
|
||||
row = repo.approve(job_id, reviewed_by=x_user_id)
|
||||
except JobNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
except JobNotCompletedError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||
|
||||
_logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "")
|
||||
return ApprovalResponse(
|
||||
job_id=row.job_id,
|
||||
approved=bool(row.approved),
|
||||
reviewed_by=row.reviewed_by,
|
||||
reviewed_at=row.reviewed_at,
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import JSON, DateTime, Float, String, Uuid
|
||||
from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, Integer, String, Uuid
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from ocr_sprint.db.base import Base
|
||||
@@ -42,6 +42,15 @@ class JobRow(Base):
|
||||
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
|
||||
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||
|
||||
# Phase 6 — HITL review state.
|
||||
# Once ``approved=True`` the row is immutable except to admin users;
|
||||
# corrections after that point are rejected by the route. ``reviewed_by``
|
||||
# stores the free-form user identifier the reviewer sent via the
|
||||
# ``X-User-Id`` header (best-effort attribution — no full RBAC yet).
|
||||
approved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
reviewed_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, default=_utcnow
|
||||
)
|
||||
@@ -51,3 +60,37 @@ class JobRow(Base):
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"
|
||||
|
||||
|
||||
class JobCorrectionRow(Base):
|
||||
"""One correction event on a job's ``result``.
|
||||
|
||||
Each PATCH call writes one row per changed field path so we have a
|
||||
full audit trail. Rows are append-only — never updated, never
|
||||
deleted — so the history is reproducible and usable as ground-truth
|
||||
data for future fine-tuning.
|
||||
"""
|
||||
|
||||
__tablename__ = "job_corrections"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
job_id: Mapped[UUID] = mapped_column(
|
||||
Uuid, ForeignKey("jobs.job_id", ondelete="CASCADE"), nullable=False, index=True
|
||||
)
|
||||
# Dotted JSON path into ExtractionResult, e.g. "header.nomor_sprint" or
|
||||
# "personel[3].nrp". Kept as a plain string for simplicity — we don't
|
||||
# parse it server-side beyond the allow-list check in the repository.
|
||||
field_path: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||
old_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
|
||||
new_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
|
||||
corrected_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
reason: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||
corrected_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, default=_utcnow
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"JobCorrectionRow(job_id={self.job_id!s}, "
|
||||
f"field={self.field_path!r}, by={self.corrected_by!r})"
|
||||
)
|
||||
|
||||
@@ -6,6 +6,9 @@ know about sessions, transactions, or the row → schema mapping.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
@@ -13,7 +16,7 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.db.models import JobRow
|
||||
from ocr_sprint.db.models import JobCorrectionRow, JobRow
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
|
||||
|
||||
@@ -25,6 +28,110 @@ class JobNotFoundError(LookupError):
|
||||
"""Raised by API code when GET /documents/{id} hits a missing row."""
|
||||
|
||||
|
||||
class InvalidFieldPathError(ValueError):
|
||||
"""Raised when a PATCH request references an unsupported field path."""
|
||||
|
||||
|
||||
class JobAlreadyApprovedError(RuntimeError):
|
||||
"""Raised when a PATCH is attempted against an already-approved job."""
|
||||
|
||||
|
||||
class JobNotCompletedError(RuntimeError):
|
||||
"""Raised when a PATCH/approve is attempted against a job that hasn't
|
||||
produced a ``result`` payload yet (e.g. still pending or failed).
|
||||
"""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AppliedCorrection:
|
||||
"""Internal record of a correction that successfully applied to the
|
||||
in-memory ``result`` dict. The repository turns this into a persisted
|
||||
``JobCorrectionRow`` after the whole batch is validated.
|
||||
"""
|
||||
|
||||
field_path: str
|
||||
old_value: Any
|
||||
new_value: Any
|
||||
reason: str | None
|
||||
|
||||
|
||||
# Allow-list of top-level keys we let reviewers edit. Keeps the attack
|
||||
# surface small: they can't inject arbitrary fields into the JSON blob.
|
||||
_ALLOWED_ROOTS: frozenset[str] = frozenset({"header", "ttd", "personel", "untuk"})
|
||||
|
||||
# Matches a single path segment like ``personel[3]`` — supports at most one
|
||||
# index per segment, enough for the list fields we care about.
|
||||
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
|
||||
|
||||
|
||||
def _split_path(path: str) -> list[tuple[str, int | None]]:
|
||||
"""Parse ``header.nomor_sprint`` or ``personel[2].nrp`` into segments.
|
||||
|
||||
Returns list of ``(name, index_or_none)`` tuples. Raises
|
||||
``InvalidFieldPathError`` on malformed input so the caller can surface
|
||||
a 400 to the client.
|
||||
"""
|
||||
if not path or path.startswith(".") or path.endswith("."):
|
||||
raise InvalidFieldPathError(f"Invalid field path: {path!r}")
|
||||
|
||||
parts = path.split(".")
|
||||
out: list[tuple[str, int | None]] = []
|
||||
for part in parts:
|
||||
match = _SEGMENT_RE.match(part)
|
||||
if match is None:
|
||||
raise InvalidFieldPathError(f"Invalid segment in path: {part!r}")
|
||||
name = match.group(1)
|
||||
idx_raw = match.group(2)
|
||||
idx = int(idx_raw) if idx_raw is not None else None
|
||||
out.append((name, idx))
|
||||
|
||||
if out[0][0] not in _ALLOWED_ROOTS:
|
||||
raise InvalidFieldPathError(
|
||||
f"Field path root {out[0][0]!r} not in allowed roots {sorted(_ALLOWED_ROOTS)!r}"
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def _apply_path(data: dict[str, Any], path: str, new_value: Any) -> Any:
|
||||
"""Apply a single correction to ``data`` in place. Returns the old
|
||||
value so the caller can record it in the audit row.
|
||||
|
||||
Does NOT validate that the new value matches the field's expected
|
||||
type — that's the reviewer's responsibility; the whole point of HITL
|
||||
is to let humans override the model's typing.
|
||||
"""
|
||||
segments = _split_path(path)
|
||||
cursor: Any = data
|
||||
for name, idx in segments[:-1]:
|
||||
if not isinstance(cursor, dict) or name not in cursor:
|
||||
raise InvalidFieldPathError(f"Cannot traverse to {path!r}: missing {name!r}")
|
||||
cursor = cursor[name]
|
||||
if idx is not None:
|
||||
if not isinstance(cursor, list) or idx >= len(cursor):
|
||||
raise InvalidFieldPathError(
|
||||
f"Cannot traverse to {path!r}: index [{idx}] out of range"
|
||||
)
|
||||
cursor = cursor[idx]
|
||||
|
||||
name, idx = segments[-1]
|
||||
if idx is not None:
|
||||
# Terminal segment is a list-element, e.g. ``untuk[2]``.
|
||||
if not isinstance(cursor, dict) or name not in cursor:
|
||||
raise InvalidFieldPathError(f"Cannot apply to {path!r}: missing container {name!r}")
|
||||
container = cursor[name]
|
||||
if not isinstance(container, list) or idx >= len(container):
|
||||
raise InvalidFieldPathError(f"Cannot apply to {path!r}: index [{idx}] out of range")
|
||||
old = container[idx]
|
||||
container[idx] = new_value
|
||||
return old
|
||||
|
||||
if not isinstance(cursor, dict):
|
||||
raise InvalidFieldPathError(f"Cannot apply to {path!r}: parent is not an object")
|
||||
old = cursor.get(name)
|
||||
cursor[name] = new_value
|
||||
return old
|
||||
|
||||
|
||||
class JobRepository:
|
||||
"""SQL-backed repository for `jobs` rows."""
|
||||
|
||||
@@ -94,3 +201,139 @@ class JobRepository:
|
||||
if row is None:
|
||||
raise JobNotFoundError(f"Job not found: {job_id}")
|
||||
return row
|
||||
|
||||
# ---------- Phase 6 — HITL ----------
|
||||
|
||||
def apply_corrections(
|
||||
self,
|
||||
job_id: UUID,
|
||||
*,
|
||||
corrections: list[tuple[str, Any, str | None]],
|
||||
corrected_by: str | None,
|
||||
) -> list[JobCorrectionRow]:
|
||||
"""Apply a batch of field corrections atomically.
|
||||
|
||||
``corrections`` is a list of ``(path, new_value, reason)`` tuples.
|
||||
Returns the persisted audit rows so the caller can surface them in
|
||||
the response.
|
||||
|
||||
Raises
|
||||
------
|
||||
JobNotFoundError
|
||||
If the row doesn't exist.
|
||||
JobNotCompletedError
|
||||
If the job hasn't produced a result yet (status pending /
|
||||
processing / failed).
|
||||
JobAlreadyApprovedError
|
||||
If the job has been approved — edits are locked.
|
||||
InvalidFieldPathError
|
||||
If any path is malformed or references a disallowed root.
|
||||
"""
|
||||
row = self._get_or_raise(job_id)
|
||||
if row.result is None:
|
||||
raise JobNotCompletedError(
|
||||
f"Job {job_id} has no result to correct (status={row.status})"
|
||||
)
|
||||
if row.approved:
|
||||
raise JobAlreadyApprovedError(f"Job {job_id} is already approved; edits are locked")
|
||||
|
||||
# Deep-copy so we can roll back in memory if any correction fails.
|
||||
# The underlying JSON column will only be re-assigned once every
|
||||
# path applied cleanly.
|
||||
working = copy.deepcopy(row.result)
|
||||
applied: list[AppliedCorrection] = []
|
||||
for path, new_value, reason in corrections:
|
||||
old_value = _apply_path(working, path, new_value)
|
||||
applied.append(
|
||||
AppliedCorrection(
|
||||
field_path=path, old_value=old_value, new_value=new_value, reason=reason
|
||||
)
|
||||
)
|
||||
|
||||
# Persist audit rows first; if they fail the session rollback also
|
||||
# undoes the result-column update we're about to do.
|
||||
persisted: list[JobCorrectionRow] = []
|
||||
for event in applied:
|
||||
row_event = JobCorrectionRow(
|
||||
job_id=job_id,
|
||||
field_path=event.field_path,
|
||||
old_value=event.old_value,
|
||||
new_value=event.new_value,
|
||||
corrected_by=corrected_by,
|
||||
reason=event.reason,
|
||||
)
|
||||
self.session.add(row_event)
|
||||
persisted.append(row_event)
|
||||
|
||||
# Clear review flags that the correction has resolved. Right now we
|
||||
# only auto-clear MISSING_FIELD when any corrected field previously
|
||||
# held a null/empty value — the reviewer explicitly filled a gap.
|
||||
row.result = working
|
||||
row.review_flags = _recompute_flags(
|
||||
original_flags=list(row.review_flags or []),
|
||||
applied=applied,
|
||||
working_result=working,
|
||||
)
|
||||
row.updated_at = _utcnow()
|
||||
|
||||
self.session.flush()
|
||||
return persisted
|
||||
|
||||
def list_corrections(self, job_id: UUID) -> list[JobCorrectionRow]:
|
||||
"""Return the full audit trail for ``job_id`` in chronological order."""
|
||||
# ``get_or_raise`` so callers get a 404 instead of an empty list
|
||||
# when the job itself doesn't exist.
|
||||
self._get_or_raise(job_id)
|
||||
stmt = (
|
||||
select(JobCorrectionRow)
|
||||
.where(JobCorrectionRow.job_id == job_id)
|
||||
.order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
|
||||
)
|
||||
return list(self.session.scalars(stmt))
|
||||
|
||||
def approve(self, job_id: UUID, *, reviewed_by: str | None) -> JobRow:
|
||||
"""Mark a job as approved. Idempotent — re-approving is a no-op
|
||||
that keeps the original ``reviewed_at`` (so the audit trail stays
|
||||
intact).
|
||||
"""
|
||||
row = self._get_or_raise(job_id)
|
||||
if row.result is None:
|
||||
raise JobNotCompletedError(
|
||||
f"Job {job_id} has no result to approve (status={row.status})"
|
||||
)
|
||||
if row.approved:
|
||||
return row
|
||||
row.approved = True
|
||||
row.reviewed_by = reviewed_by
|
||||
row.reviewed_at = _utcnow()
|
||||
row.updated_at = row.reviewed_at
|
||||
return row
|
||||
|
||||
|
||||
def _recompute_flags(
|
||||
*,
|
||||
original_flags: list[str],
|
||||
applied: list[AppliedCorrection],
|
||||
working_result: dict[str, Any],
|
||||
) -> list[str]:
|
||||
"""Update review flags in light of the corrections just applied.
|
||||
|
||||
Keeps the policy simple on purpose:
|
||||
* ``missing_field`` is removed if after the edit every required
|
||||
header field is non-empty.
|
||||
* Other flags stay untouched — the reviewer should either correct the
|
||||
underlying issue (which this helper can detect) or explicitly
|
||||
approve the result as-is (which bypasses the flag list).
|
||||
"""
|
||||
flags = list(original_flags)
|
||||
if "missing_field" in flags:
|
||||
header = working_result.get("header") or {}
|
||||
filled = all(bool(header.get(key)) for key in ("nomor_sprint", "satuan_penerbit"))
|
||||
if filled:
|
||||
flags = [f for f in flags if f != "missing_field"]
|
||||
|
||||
# ``applied`` isn't used directly in this MVP rule, but we keep the
|
||||
# parameter so future policies can inspect exactly what changed
|
||||
# without re-diffing the blob.
|
||||
_ = applied
|
||||
return flags
|
||||
|
||||
18
src/ocr_sprint/llm/__init__.py
Normal file
18
src/ocr_sprint/llm/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""LLM-based extraction (Phase 5).
|
||||
|
||||
The hybrid extractor first runs the deterministic regex layer and then —
|
||||
only for fields that came back missing or low-confidence — calls a local
|
||||
Ollama model with a Pydantic-typed prompt. Everything is gated by
|
||||
``LLM_ENABLED``; if the flag is off or the Ollama server is unreachable,
|
||||
the pipeline degrades gracefully back to the regex result.
|
||||
"""
|
||||
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
from ocr_sprint.llm.extractor import LLMHeaderResult, llm_fill_header
|
||||
|
||||
__all__ = [
|
||||
"LLMHeaderResult",
|
||||
"LLMUnavailableError",
|
||||
"OllamaClient",
|
||||
"llm_fill_header",
|
||||
]
|
||||
97
src/ocr_sprint/llm/client.py
Normal file
97
src/ocr_sprint/llm/client.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Ollama HTTP client.
|
||||
|
||||
We deliberately avoid the ``ollama`` Python package — the wire format is a
|
||||
single ``POST /api/chat`` with ``format="json"`` and a system + user message,
|
||||
so a small ``httpx`` wrapper is enough. This keeps the runtime dependency
|
||||
footprint smaller and makes the mock-based unit tests trivial.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class LLMUnavailableError(RuntimeError):
|
||||
"""Raised when the Ollama server is unreachable, times out, or returns
|
||||
a malformed payload. The pipeline catches this and falls back to the
|
||||
regex-only result with a ``llm_fallback`` review flag.
|
||||
"""
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Tiny synchronous HTTP wrapper around the Ollama ``/api/chat`` endpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
base_url:
|
||||
Ollama server URL, e.g. ``http://localhost:11434``.
|
||||
model:
|
||||
Model tag to invoke (default ``qwen2.5:1.5b`` — chosen for CPU
|
||||
latency at acceptable accuracy).
|
||||
timeout_s:
|
||||
Hard wall-clock timeout for a single request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout_s: int | None = None,
|
||||
) -> None:
|
||||
s = get_settings()
|
||||
self.base_url = (base_url or s.llm_base_url).rstrip("/")
|
||||
self.model = model or s.llm_model
|
||||
self.timeout_s = timeout_s if timeout_s is not None else s.llm_timeout_s
|
||||
|
||||
# ---------- public API ----------
|
||||
|
||||
def chat_json(self, system: str, user: str, schema_cls: type[T]) -> T:
|
||||
"""Run a single chat completion in JSON mode and validate the
|
||||
response against ``schema_cls``. Raises ``LLMUnavailableError`` on
|
||||
any transport / parse / validation failure so callers only have one
|
||||
exception to handle.
|
||||
"""
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"stream": False,
|
||||
"format": "json",
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
# Keep determinism reasonable — we want extraction, not creativity.
|
||||
"options": {"temperature": 0.0, "num_ctx": 4096},
|
||||
}
|
||||
url = f"{self.base_url}/api/chat"
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=self.timeout_s) as client:
|
||||
response = client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except (httpx.HTTPError, ValueError) as exc:
|
||||
_logger.warning("llm.transport_error", url=url, error=str(exc))
|
||||
raise LLMUnavailableError(f"Ollama request failed: {exc}") from exc
|
||||
|
||||
# Ollama returns {"message": {"role": "assistant", "content": "<json>"}}.
|
||||
try:
|
||||
content = data["message"]["content"]
|
||||
except (KeyError, TypeError) as exc:
|
||||
_logger.warning("llm.bad_envelope", payload=data)
|
||||
raise LLMUnavailableError(f"Ollama response missing message.content: {data!r}") from exc
|
||||
|
||||
try:
|
||||
return schema_cls.model_validate_json(content)
|
||||
except ValidationError as exc:
|
||||
_logger.warning("llm.validation_error", error=str(exc), content=content[:400])
|
||||
raise LLMUnavailableError(f"LLM JSON failed schema: {exc}") from exc
|
||||
84
src/ocr_sprint/llm/extractor.py
Normal file
84
src/ocr_sprint/llm/extractor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""High-level LLM extractor.
|
||||
|
||||
The job is *narrow*: take the raw OCR text plus the partial header that
|
||||
came back from the regex layer, and return an LLM-derived header that the
|
||||
caller can merge in. We never let the LLM populate the personnel table —
|
||||
PP-Structure is more accurate and cheaper for that.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
from ocr_sprint.llm.prompts import SYSTEM_HEADER, build_user_prompt
|
||||
from ocr_sprint.schemas.extraction import HeaderFields
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LLMHeaderResult(BaseModel):
|
||||
"""Schema we ask the model to fill. Mirrors ``HeaderFields`` but is
|
||||
intentionally separate so we control exactly what the prompt and
|
||||
validation surface look like — the public ``HeaderFields`` may grow
|
||||
fields later that we don't want the LLM touching.
|
||||
"""
|
||||
|
||||
nomor_sprint: str | None = None
|
||||
tanggal: date | None = None
|
||||
satuan_penerbit: str | None = None
|
||||
perihal: str | None = None
|
||||
dasar: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
def llm_fill_header(
|
||||
raw_text: str,
|
||||
regex_header: HeaderFields,
|
||||
*,
|
||||
client: OllamaClient | None = None,
|
||||
) -> HeaderFields | None:
|
||||
"""Run the LLM extractor and return a *merged* HeaderFields.
|
||||
|
||||
Returns ``None`` if the model is unavailable so the caller can decide
|
||||
what to do (typically: keep the regex result and emit a fallback
|
||||
review flag).
|
||||
"""
|
||||
client = client or OllamaClient()
|
||||
|
||||
user = build_user_prompt(
|
||||
raw_text=raw_text,
|
||||
regex_partial=regex_header.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
try:
|
||||
llm = client.chat_json(SYSTEM_HEADER, user, LLMHeaderResult)
|
||||
except LLMUnavailableError as exc:
|
||||
_logger.warning("llm.unavailable", error=str(exc))
|
||||
return None
|
||||
|
||||
return _merge(regex_header, llm)
|
||||
|
||||
|
||||
def _merge(regex: HeaderFields, llm: LLMHeaderResult) -> HeaderFields:
|
||||
"""Merge LLM output into the regex result.
|
||||
|
||||
Policy: regex wins for any field it already filled. The LLM only fills
|
||||
the *gaps*. This keeps deterministic / verifiable extractions for the
|
||||
fields where regex is reliable and prevents the LLM from "correcting"
|
||||
a value that happens to look unusual but is in fact correct.
|
||||
"""
|
||||
merged = regex.model_copy(deep=True)
|
||||
if merged.nomor_sprint is None and llm.nomor_sprint:
|
||||
merged.nomor_sprint = llm.nomor_sprint
|
||||
if merged.tanggal is None and llm.tanggal is not None:
|
||||
merged.tanggal = llm.tanggal
|
||||
if not merged.satuan_penerbit and llm.satuan_penerbit:
|
||||
merged.satuan_penerbit = llm.satuan_penerbit
|
||||
if not merged.perihal and llm.perihal:
|
||||
merged.perihal = llm.perihal
|
||||
if not merged.dasar and llm.dasar:
|
||||
merged.dasar = list(llm.dasar)
|
||||
return merged
|
||||
48
src/ocr_sprint/llm/prompts.py
Normal file
48
src/ocr_sprint/llm/prompts.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Prompt builders for the LLM extractor.
|
||||
|
||||
Kept in their own module so the prompts can be edited / version-tracked
|
||||
without touching the orchestration logic. We build prompts in Indonesian
|
||||
because the source documents are too — the model performs better when the
|
||||
field labels in the prompt match the OCR text it's being asked about.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
SYSTEM_HEADER = (
|
||||
"Anda adalah asisten ekstraksi data untuk dokumen Surat Perintah (Sprint) "
|
||||
"Kepolisian Republik Indonesia (POLRI). Pengguna akan memberikan teks hasil "
|
||||
"OCR sebuah surat sprint, dan Anda harus mengembalikan JSON yang sesuai "
|
||||
"dengan skema yang diberikan.\n\n"
|
||||
"Aturan keras:\n"
|
||||
"1. Jangan mengarang. Jika sebuah field tidak terlihat di teks, kembalikan null.\n"
|
||||
"2. Jangan menerjemahkan field. Output harus identik ejaannya dengan teks "
|
||||
"sumber (kecuali normalisasi spasi/kapitalisasi yang jelas hasil OCR error).\n"
|
||||
"3. Tanggal: kembalikan format ISO YYYY-MM-DD jika tanggal terlihat, "
|
||||
"selain itu null.\n"
|
||||
"4. Dasar hukum: array string berisi tiap butir, urut sesuai teks.\n"
|
||||
"5. Jangan menambahkan field apa pun di luar skema. Output WAJIB JSON valid."
|
||||
)
|
||||
|
||||
|
||||
def build_user_prompt(raw_text: str, regex_partial: dict[str, object]) -> str:
|
||||
"""Construct the user message: OCR text + a hint about which fields the
|
||||
deterministic regex layer already filled. Telling the LLM what we
|
||||
*already have* keeps it from "creatively" overwriting good values.
|
||||
"""
|
||||
known_fields = "\n".join(f" - {k}: {v!r}" for k, v in sorted(regex_partial.items()) if v)
|
||||
known_block = (
|
||||
f"\nField yang sudah berhasil diekstrak dengan regex:\n{known_fields}\n"
|
||||
if known_fields
|
||||
else ""
|
||||
)
|
||||
|
||||
return (
|
||||
"Teks OCR:\n"
|
||||
"----------\n"
|
||||
f"{raw_text}\n"
|
||||
"----------\n"
|
||||
f"{known_block}"
|
||||
"Tugas: kembalikan JSON dengan field nomor_sprint, tanggal (ISO date | null), "
|
||||
"satuan_penerbit, perihal, dasar (array string). Hanya field yang terlihat — "
|
||||
"yang tidak ada di teks isi null (atau array kosong untuk dasar)."
|
||||
)
|
||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.llm.extractor import llm_fill_header
|
||||
from ocr_sprint.pipeline.confidence import compute_confidence, route
|
||||
from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct
|
||||
from ocr_sprint.pipeline.extract.personnel import extract_personnel
|
||||
@@ -35,6 +36,18 @@ _logger = get_logger(__name__)
|
||||
_OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80
|
||||
|
||||
|
||||
def _header_has_gaps(header: object) -> bool:
|
||||
"""True if any header field worth asking the LLM about is missing.
|
||||
|
||||
Using ``getattr`` so this stays decoupled from the exact attribute
|
||||
names; the schema change cost was too large last time we hard-coded.
|
||||
"""
|
||||
for field in ("nomor_sprint", "tanggal", "satuan_penerbit", "perihal"):
|
||||
if not getattr(header, field, None):
|
||||
return True
|
||||
return not getattr(header, "dasar", None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineOutput:
|
||||
"""Bundle returned by the orchestrator."""
|
||||
@@ -84,6 +97,20 @@ def run_pipeline(content: bytes) -> PipelineOutput:
|
||||
header = extract_header(full_text)
|
||||
ttd = find_signatory(full_text)
|
||||
|
||||
# Phase 5 — hybrid extraction. The regex layer is deterministic but
|
||||
# brittle to layout variants between satuan; if any header field is
|
||||
# still missing we ask the local LLM to fill the gaps. The merger
|
||||
# never lets the LLM overwrite a field that regex already captured.
|
||||
llm_flags: list[ReviewFlag] = []
|
||||
if s.llm_enabled and _header_has_gaps(header):
|
||||
merged = llm_fill_header(full_text, header)
|
||||
if merged is None:
|
||||
llm_flags.append(ReviewFlag.LLM_UNAVAILABLE)
|
||||
else:
|
||||
if merged.model_dump() != header.model_dump():
|
||||
llm_flags.append(ReviewFlag.LLM_FALLBACK)
|
||||
header = merged
|
||||
|
||||
personel: list[PersonnelEntry] = []
|
||||
if s.tables_enabled and cleaned_pages:
|
||||
all_tables: list[DetectedTable] = []
|
||||
@@ -99,7 +126,7 @@ def run_pipeline(content: bytes) -> PipelineOutput:
|
||||
personel_rows=len(personel),
|
||||
)
|
||||
|
||||
initial_flags: list[ReviewFlag] = []
|
||||
initial_flags: list[ReviewFlag] = list(llm_flags)
|
||||
if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD:
|
||||
initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE)
|
||||
|
||||
|
||||
@@ -55,3 +55,7 @@ class DocumentResponse(BaseModel):
|
||||
data: ExtractionResult | None = None
|
||||
review_flags: list[str] = Field(default_factory=list)
|
||||
error: str | None = None
|
||||
# Phase 6 — HITL review state.
|
||||
approved: bool = False
|
||||
reviewed_by: str | None = None
|
||||
reviewed_at: datetime | None = None
|
||||
|
||||
@@ -19,6 +19,8 @@ class ReviewFlag(str, Enum):
|
||||
UNKNOWN_PANGKAT = "unknown_pangkat"
|
||||
PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch"
|
||||
DATE_PARSE_FAILED = "date_parse_failed"
|
||||
LLM_FALLBACK = "llm_fallback"
|
||||
LLM_UNAVAILABLE = "llm_unavailable"
|
||||
|
||||
|
||||
class Signatory(BaseModel):
|
||||
|
||||
62
src/ocr_sprint/schemas/review.py
Normal file
62
src/ocr_sprint/schemas/review.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Request / response schemas for the HITL review endpoints (Phase 6).
|
||||
|
||||
The API surface is deliberately small:
|
||||
|
||||
* ``CorrectionRequest`` — body of ``PATCH /documents/{id}``. A list of
|
||||
``FieldCorrection`` entries; each one is applied atomically (all-or-
|
||||
nothing) and recorded in the audit trail.
|
||||
* ``CorrectionEventResponse`` — single row in ``GET /documents/{id}/history``.
|
||||
* ``ApprovalResponse`` — echo back after ``POST /documents/{id}/approve``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class FieldCorrection(BaseModel):
|
||||
"""One field-level correction.
|
||||
|
||||
``path`` is a dotted JSON path into ``ExtractionResult``. Supported
|
||||
roots: ``header``, ``ttd``, ``personel[n]`` (n is a 0-based index),
|
||||
``untuk``. The path is validated by the repository before being
|
||||
applied; unknown roots return 400.
|
||||
"""
|
||||
|
||||
path: str = Field(..., description="Dotted JSON path, e.g. 'header.nomor_sprint'.")
|
||||
value: Any = Field(..., description="New value (any JSON-serialisable payload).")
|
||||
reason: str | None = Field(
|
||||
None, max_length=512, description="Optional free-form reason for the correction."
|
||||
)
|
||||
|
||||
|
||||
class CorrectionRequest(BaseModel):
|
||||
"""PATCH body — one or more field corrections, applied atomically."""
|
||||
|
||||
corrections: list[FieldCorrection] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class CorrectionEventResponse(BaseModel):
|
||||
"""One row of the audit log surfaced by GET /history."""
|
||||
|
||||
id: int
|
||||
job_id: UUID
|
||||
field_path: str
|
||||
old_value: Any | None = None
|
||||
new_value: Any | None = None
|
||||
corrected_by: str | None = None
|
||||
reason: str | None = None
|
||||
corrected_at: datetime
|
||||
|
||||
|
||||
class ApprovalResponse(BaseModel):
|
||||
"""Echo returned after a job is approved."""
|
||||
|
||||
job_id: UUID
|
||||
approved: bool
|
||||
reviewed_by: str | None = None
|
||||
reviewed_at: datetime | None = None
|
||||
248
tests/unit/test_api_hitl.py
Normal file
248
tests/unit/test_api_hitl.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""End-to-end HTTP tests for the HITL endpoints.
|
||||
|
||||
We re-use the ``fake_pipeline`` style from ``test_api.py`` so we don't pay
|
||||
the PaddleOCR init cost; the orchestrator is monkey-patched to return a
|
||||
synthetic ``ExtractionResult``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from ocr_sprint.main import create_app
|
||||
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||
from ocr_sprint.pipeline.orchestrator import PipelineOutput
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
from ocr_sprint.schemas.extraction import (
|
||||
ExtractionResult,
|
||||
HeaderFields,
|
||||
PersonnelEntry,
|
||||
ReviewFlag,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client() -> TestClient:
|
||||
return TestClient(create_app())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
|
||||
result = ExtractionResult(
|
||||
header=HeaderFields(
|
||||
nomor_sprint="Sprin/1/I/2025",
|
||||
tanggal=date(2025, 1, 1),
|
||||
satuan_penerbit="POLRES TEST",
|
||||
perihal=None, # intentional gap so a PATCH can fill it
|
||||
),
|
||||
personel=[
|
||||
PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"),
|
||||
],
|
||||
review_flags=[ReviewFlag.MISSING_FIELD],
|
||||
confidence=0.7,
|
||||
)
|
||||
output = PipelineOutput(
|
||||
source_kind=SourceKind.PDF,
|
||||
status=DocumentStatus.NEEDS_REVIEW,
|
||||
confidence=0.7,
|
||||
result=result,
|
||||
)
|
||||
|
||||
def _fake_run(_content: bytes) -> PipelineOutput:
|
||||
return output
|
||||
|
||||
monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
|
||||
from ocr_sprint.api.routes import documents as docs_module
|
||||
|
||||
monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
|
||||
from ocr_sprint.worker import tasks as tasks_module
|
||||
|
||||
monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
|
||||
return output
|
||||
|
||||
|
||||
def _create_job(client: TestClient) -> str:
|
||||
post = client.post(
|
||||
"/api/v1/documents?sync=true",
|
||||
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||
)
|
||||
assert post.status_code == 200, post.text
|
||||
body = post.json()
|
||||
assert body["status"] == "needs_review"
|
||||
return str(body["job_id"])
|
||||
|
||||
|
||||
def test_patch_applies_correction_and_clears_missing_field(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
patched = client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={
|
||||
"corrections": [
|
||||
{
|
||||
"path": "header.perihal",
|
||||
"value": "Penyelidikan kasus X",
|
||||
"reason": "LLM missed it",
|
||||
}
|
||||
]
|
||||
},
|
||||
headers={"X-User-Id": "reviewer-a"},
|
||||
)
|
||||
assert patched.status_code == 200, patched.text
|
||||
body = patched.json()
|
||||
assert body["data"]["header"]["perihal"] == "Penyelidikan kasus X"
|
||||
# The fake pipeline has both required header fields filled, so the
|
||||
# ``missing_field`` flag is auto-cleared as soon as any correction
|
||||
# lands (the policy re-evaluates required-field coverage on every
|
||||
# edit).
|
||||
assert "missing_field" not in body["review_flags"]
|
||||
|
||||
|
||||
def test_patch_returns_400_for_unknown_path(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
resp = client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={"corrections": [{"path": "bogus.field", "value": "x"}]},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_patch_is_atomic_on_partial_failure(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
resp = client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={
|
||||
"corrections": [
|
||||
{"path": "header.perihal", "value": "OK"},
|
||||
{"path": "bogus.root", "value": "X"},
|
||||
]
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
# The first correction must not have persisted.
|
||||
got = client.get(f"/api/v1/documents/{job_id}")
|
||||
assert got.json()["data"]["header"]["perihal"] is None
|
||||
|
||||
|
||||
def test_history_returns_corrections_in_order(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={"corrections": [{"path": "header.perihal", "value": "first"}]},
|
||||
headers={"X-User-Id": "reviewer-a"},
|
||||
)
|
||||
client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={"corrections": [{"path": "header.perihal", "value": "second"}]},
|
||||
headers={"X-User-Id": "reviewer-b"},
|
||||
)
|
||||
|
||||
history = client.get(f"/api/v1/documents/{job_id}/history")
|
||||
assert history.status_code == 200
|
||||
events = history.json()
|
||||
assert [e["new_value"] for e in events] == ["first", "second"]
|
||||
assert [e["corrected_by"] for e in events] == ["reviewer-a", "reviewer-b"]
|
||||
# old_value of the second event should reflect the first edit.
|
||||
assert events[1]["old_value"] == "first"
|
||||
|
||||
|
||||
def test_history_returns_empty_list_for_untouched_job(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
history = client.get(f"/api/v1/documents/{job_id}/history")
|
||||
assert history.status_code == 200
|
||||
assert history.json() == []
|
||||
|
||||
|
||||
def test_history_returns_404_for_unknown_job(client: TestClient) -> None:
|
||||
resp = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000/history")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_approve_locks_subsequent_patches(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
approved = client.post(
|
||||
f"/api/v1/documents/{job_id}/approve",
|
||||
headers={"X-User-Id": "reviewer-a"},
|
||||
)
|
||||
assert approved.status_code == 200, approved.text
|
||||
body = approved.json()
|
||||
assert body["approved"] is True
|
||||
assert body["reviewed_by"] == "reviewer-a"
|
||||
assert body["reviewed_at"] # non-empty timestamp
|
||||
|
||||
# GET reflects the approval state.
|
||||
got = client.get(f"/api/v1/documents/{job_id}").json()
|
||||
assert got["approved"] is True
|
||||
|
||||
# PATCH after approve must be rejected with 409.
|
||||
patched = client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={"corrections": [{"path": "header.perihal", "value": "X"}]},
|
||||
)
|
||||
assert patched.status_code == 409
|
||||
|
||||
|
||||
def test_approve_is_idempotent(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
first = client.post(
|
||||
f"/api/v1/documents/{job_id}/approve",
|
||||
headers={"X-User-Id": "reviewer-a"},
|
||||
)
|
||||
second = client.post(
|
||||
f"/api/v1/documents/{job_id}/approve",
|
||||
headers={"X-User-Id": "reviewer-b"},
|
||||
)
|
||||
assert first.status_code == 200
|
||||
assert second.status_code == 200
|
||||
# Second approve must NOT change the attribution. (SQLite drops tzinfo
|
||||
# on roundtrip, which changes Pydantic's serialization between the two
|
||||
# calls; compare the naive components.)
|
||||
assert second.json()["reviewed_by"] == "reviewer-a"
|
||||
assert (
|
||||
second.json()["reviewed_at"].rstrip("Z").split("+")[0]
|
||||
== (first.json()["reviewed_at"].rstrip("Z").split("+")[0])
|
||||
)
|
||||
|
||||
|
||||
def test_patch_requires_at_least_one_correction(
|
||||
client: TestClient,
|
||||
fake_pipeline: PipelineOutput,
|
||||
) -> None:
|
||||
job_id = _create_job(client)
|
||||
resp = client.patch(
|
||||
f"/api/v1/documents/{job_id}",
|
||||
json={"corrections": []},
|
||||
)
|
||||
assert resp.status_code == 422 # Pydantic min_length=1 violation
|
||||
|
||||
|
||||
def test_patch_missing_job_returns_404(client: TestClient) -> None:
|
||||
resp = client.patch(
|
||||
"/api/v1/documents/00000000-0000-0000-0000-000000000000",
|
||||
json={"corrections": [{"path": "header.perihal", "value": "X"}]},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
238
tests/unit/test_db_hitl.py
Normal file
238
tests/unit/test_db_hitl.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Repository tests for Phase 6 HITL helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||
from ocr_sprint.db.repositories import (
|
||||
InvalidFieldPathError,
|
||||
JobAlreadyApprovedError,
|
||||
JobNotCompletedError,
|
||||
JobNotFoundError,
|
||||
JobRepository,
|
||||
)
|
||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_ready() -> None:
|
||||
Base.metadata.create_all(bind=get_engine())
|
||||
|
||||
|
||||
def _seed_completed_job(
|
||||
*,
|
||||
result: dict[str, object] | None = None,
|
||||
flags: list[str] | None = None,
|
||||
) -> uuid4: # type: ignore[type-arg]
|
||||
jid = uuid4()
|
||||
with session_scope() as session:
|
||||
repo = JobRepository(session)
|
||||
repo.create(
|
||||
job_id=jid,
|
||||
filename="x.pdf",
|
||||
source_kind=SourceKind.PDF,
|
||||
blob_key="k",
|
||||
)
|
||||
with session_scope() as session:
|
||||
JobRepository(session).mark_completed(
|
||||
jid,
|
||||
status=DocumentStatus.NEEDS_REVIEW,
|
||||
confidence=0.7,
|
||||
result=result
|
||||
or {
|
||||
"header": {
|
||||
"nomor_sprint": "Sprin/1/I/2025",
|
||||
"satuan_penerbit": "POLRES X",
|
||||
"perihal": None,
|
||||
},
|
||||
"personel": [
|
||||
{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"},
|
||||
],
|
||||
"untuk": ["Melaksanakan tugas"],
|
||||
},
|
||||
review_flags=flags or [],
|
||||
)
|
||||
return jid
|
||||
|
||||
|
||||
def test_apply_corrections_updates_nested_header_field(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session:
|
||||
repo = JobRepository(session)
|
||||
repo.apply_corrections(
|
||||
jid,
|
||||
corrections=[("header.perihal", "Penyelidikan kasus X", "regex miss")],
|
||||
corrected_by="reviewer-a",
|
||||
)
|
||||
row = repo.get_or_raise(jid)
|
||||
assert row.result is not None
|
||||
assert row.result["header"]["perihal"] == "Penyelidikan kasus X"
|
||||
|
||||
|
||||
def test_apply_corrections_writes_audit_row(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session:
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("header.perihal", "Penyelidikan", None)],
|
||||
corrected_by="reviewer-a",
|
||||
)
|
||||
with session_scope() as session:
|
||||
events = JobRepository(session).list_corrections(jid)
|
||||
assert len(events) == 1
|
||||
assert events[0].field_path == "header.perihal"
|
||||
assert events[0].old_value is None
|
||||
assert events[0].new_value == "Penyelidikan"
|
||||
assert events[0].corrected_by == "reviewer-a"
|
||||
|
||||
|
||||
def test_apply_corrections_supports_list_index(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session:
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("personel[0].nrp", "77060001", None)],
|
||||
corrected_by=None,
|
||||
)
|
||||
row = JobRepository(session).get_or_raise(jid)
|
||||
assert row.result is not None
|
||||
assert row.result["personel"][0]["nrp"] == "77060001"
|
||||
|
||||
|
||||
def test_apply_corrections_is_atomic_on_invalid_path(db_ready: None) -> None:
|
||||
"""A second-correction failure must roll back the first one."""
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session, pytest.raises(InvalidFieldPathError):
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[
|
||||
("header.perihal", "OK", None),
|
||||
("bogus.root", "X", None),
|
||||
],
|
||||
corrected_by=None,
|
||||
)
|
||||
# The first correction must not have persisted.
|
||||
with session_scope() as session:
|
||||
row = JobRepository(session).get_or_raise(jid)
|
||||
assert row.result is not None
|
||||
assert row.result["header"].get("perihal") is None
|
||||
|
||||
|
||||
def test_apply_corrections_rejects_out_of_range_index(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session, pytest.raises(InvalidFieldPathError):
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("personel[99].nrp", "77060001", None)],
|
||||
corrected_by=None,
|
||||
)
|
||||
|
||||
|
||||
def test_apply_corrections_rejects_after_approve(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session:
|
||||
JobRepository(session).approve(jid, reviewed_by="reviewer-a")
|
||||
with session_scope() as session, pytest.raises(JobAlreadyApprovedError):
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("header.perihal", "X", None)],
|
||||
corrected_by="reviewer-a",
|
||||
)
|
||||
|
||||
|
||||
def test_apply_corrections_rejects_missing_job(db_ready: None) -> None:
|
||||
with session_scope() as session, pytest.raises(JobNotFoundError):
|
||||
JobRepository(session).apply_corrections(
|
||||
uuid4(),
|
||||
corrections=[("header.perihal", "X", None)],
|
||||
corrected_by=None,
|
||||
)
|
||||
|
||||
|
||||
def test_apply_corrections_rejects_pending_job(db_ready: None) -> None:
|
||||
jid = uuid4()
|
||||
with session_scope() as session:
|
||||
JobRepository(session).create(
|
||||
job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k"
|
||||
)
|
||||
with session_scope() as session, pytest.raises(JobNotCompletedError):
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("header.perihal", "X", None)],
|
||||
corrected_by=None,
|
||||
)
|
||||
|
||||
|
||||
def test_missing_field_flag_cleared_when_header_gap_filled(db_ready: None) -> None:
|
||||
jid = _seed_completed_job(
|
||||
result={
|
||||
"header": {
|
||||
"nomor_sprint": None,
|
||||
"satuan_penerbit": "POLRES X",
|
||||
}
|
||||
},
|
||||
flags=["missing_field", "low_ocr_confidence"],
|
||||
)
|
||||
with session_scope() as session:
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("header.nomor_sprint", "Sprin/2/I/2025", None)],
|
||||
corrected_by="reviewer-a",
|
||||
)
|
||||
row = JobRepository(session).get_or_raise(jid)
|
||||
# ``low_ocr_confidence`` stays (correction doesn't resolve that signal),
|
||||
# but ``missing_field`` is gone because every required header key is
|
||||
# now non-empty.
|
||||
assert list(row.review_flags) == ["low_ocr_confidence"]
|
||||
|
||||
|
||||
def test_approve_sets_timestamps_and_is_idempotent(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session:
|
||||
row = JobRepository(session).approve(jid, reviewed_by="reviewer-a")
|
||||
first_at = row.reviewed_at
|
||||
assert first_at is not None
|
||||
with session_scope() as session:
|
||||
row = JobRepository(session).approve(jid, reviewed_by="reviewer-b")
|
||||
# Second call must NOT overwrite reviewed_by or reviewed_at.
|
||||
# SQLite drops tzinfo on roundtrip, so compare the naive components.
|
||||
assert row.approved is True
|
||||
assert row.reviewed_by == "reviewer-a"
|
||||
assert row.reviewed_at is not None
|
||||
assert row.reviewed_at.replace(tzinfo=None) == first_at.replace(tzinfo=None)
|
||||
|
||||
|
||||
def test_approve_rejects_pending_job(db_ready: None) -> None:
|
||||
jid = uuid4()
|
||||
with session_scope() as session:
|
||||
JobRepository(session).create(
|
||||
job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k"
|
||||
)
|
||||
with session_scope() as session, pytest.raises(JobNotCompletedError):
|
||||
JobRepository(session).approve(jid, reviewed_by="rev")
|
||||
|
||||
|
||||
def test_history_returns_events_in_order(db_ready: None) -> None:
|
||||
jid = _seed_completed_job()
|
||||
with session_scope() as session:
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[("header.perihal", "one", None)],
|
||||
corrected_by="r1",
|
||||
)
|
||||
with session_scope() as session:
|
||||
JobRepository(session).apply_corrections(
|
||||
jid,
|
||||
corrections=[
|
||||
("header.perihal", "two", None),
|
||||
("personel[0].nama", "ANDI", None),
|
||||
],
|
||||
corrected_by="r2",
|
||||
)
|
||||
with session_scope() as session:
|
||||
events = JobRepository(session).list_corrections(jid)
|
||||
assert [e.new_value for e in events] == ["one", "two", "ANDI"]
|
||||
assert [e.corrected_by for e in events] == ["r1", "r2", "r2"]
|
||||
108
tests/unit/test_llm_client.py
Normal file
108
tests/unit/test_llm_client.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Unit tests for the Ollama HTTP client wrapper.
|
||||
|
||||
We swap ``httpx.Client`` inside ``ocr_sprint.llm.client`` for a builder that
|
||||
returns a real ``httpx.Client`` wrapping a ``MockTransport``. Capturing the
|
||||
original constructor *before* patching avoids infinite recursion in the
|
||||
patched callable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import ocr_sprint.llm.client as llm_client_module
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
|
||||
|
||||
class _Schema(BaseModel):
|
||||
foo: str
|
||||
bar: int
|
||||
|
||||
|
||||
def _ollama_envelope(content: str) -> dict[str, object]:
|
||||
"""Mimic the shape Ollama's /api/chat returns."""
|
||||
return {"message": {"role": "assistant", "content": content}, "done": True}
|
||||
|
||||
|
||||
def _patch_transport(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
handler: Any,
|
||||
) -> None:
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client # capture before patching
|
||||
|
||||
def _factory(*_args: object, **kwargs: object) -> httpx.Client:
|
||||
# Strip any caller-provided transport kwarg; we always inject ours.
|
||||
kwargs.pop("transport", None)
|
||||
return real_client(transport=transport, **kwargs)
|
||||
|
||||
monkeypatch.setattr(llm_client_module.httpx, "Client", _factory)
|
||||
|
||||
|
||||
def test_chat_json_returns_validated_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _handler(request: httpx.Request) -> httpx.Response:
|
||||
captured["url"] = str(request.url)
|
||||
captured["body"] = request.read()
|
||||
return httpx.Response(200, json=_ollama_envelope('{"foo": "x", "bar": 7}'))
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://ollama:11434", model="m", timeout_s=5)
|
||||
out = client.chat_json("system msg", "user msg", _Schema)
|
||||
|
||||
assert out == _Schema(foo="x", bar=7)
|
||||
assert captured["url"] == "http://ollama:11434/api/chat"
|
||||
body = captured["body"]
|
||||
assert isinstance(body, bytes)
|
||||
assert b'"format":"json"' in body
|
||||
assert b'"system msg"' in body
|
||||
|
||||
|
||||
def test_chat_json_raises_on_http_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(500, text="boom")
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||
with pytest.raises(LLMUnavailableError, match="Ollama request failed"):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
|
||||
|
||||
def test_chat_json_raises_on_invalid_json(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json=_ollama_envelope("this is not json"))
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||
with pytest.raises(LLMUnavailableError, match="schema"):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
|
||||
|
||||
def test_chat_json_raises_on_missing_envelope(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json={"oops": True})
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||
with pytest.raises(LLMUnavailableError, match=r"message\.content"):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
|
||||
|
||||
def test_chat_json_raises_on_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ConnectError("nobody home", request=request)
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=1)
|
||||
with pytest.raises(LLMUnavailableError):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
90
tests/unit/test_llm_extractor.py
Normal file
90
tests/unit/test_llm_extractor.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Unit tests for the hybrid LLM header extractor / merger."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
from ocr_sprint.llm.extractor import LLMHeaderResult, _merge, llm_fill_header
|
||||
from ocr_sprint.schemas.extraction import HeaderFields
|
||||
|
||||
|
||||
class _StubClient(OllamaClient):
|
||||
"""Test double that bypasses HTTP entirely."""
|
||||
|
||||
def __init__(self, payload: LLMHeaderResult | Exception) -> None:
|
||||
# Skip the real __init__ — we don't need any real config.
|
||||
self._payload = payload
|
||||
|
||||
def chat_json( # type: ignore[override]
|
||||
self, system: str, user: str, schema_cls: type[BaseModel]
|
||||
) -> BaseModel:
|
||||
if isinstance(self._payload, Exception):
|
||||
raise self._payload
|
||||
return self._payload
|
||||
|
||||
|
||||
def test_merge_keeps_regex_when_present() -> None:
|
||||
regex = HeaderFields(nomor_sprint="Sprin/123/IV/2025/Reskrim", tanggal=date(2025, 4, 21))
|
||||
llm = LLMHeaderResult(nomor_sprint="HALLUCINATED", tanggal=date(1999, 1, 1), perihal="ok")
|
||||
out = _merge(regex, llm)
|
||||
assert out.nomor_sprint == "Sprin/123/IV/2025/Reskrim"
|
||||
assert out.tanggal == date(2025, 4, 21)
|
||||
# Gaps get filled.
|
||||
assert out.perihal == "ok"
|
||||
|
||||
|
||||
def test_merge_fills_gaps() -> None:
|
||||
regex = HeaderFields() # all None
|
||||
llm = LLMHeaderResult(
|
||||
nomor_sprint="Sprin/9/IX/2024",
|
||||
tanggal=date(2024, 9, 1),
|
||||
satuan_penerbit="Polres Bandung",
|
||||
perihal="Penyelidikan",
|
||||
dasar=["UU 2/2002", "Perkap 6/2017"],
|
||||
)
|
||||
out = _merge(regex, llm)
|
||||
assert out.nomor_sprint == "Sprin/9/IX/2024"
|
||||
assert out.tanggal == date(2024, 9, 1)
|
||||
assert out.satuan_penerbit == "Polres Bandung"
|
||||
assert out.perihal == "Penyelidikan"
|
||||
assert out.dasar == ["UU 2/2002", "Perkap 6/2017"]
|
||||
|
||||
|
||||
def test_llm_fill_header_returns_merged_when_client_succeeds() -> None:
|
||||
regex = HeaderFields(nomor_sprint="Sprin/1/I/2025") # has nomor, missing rest
|
||||
stub = _StubClient(
|
||||
LLMHeaderResult(
|
||||
satuan_penerbit="Polres Bandung",
|
||||
perihal="Penyelidikan",
|
||||
dasar=["UU 2/2002"],
|
||||
)
|
||||
)
|
||||
out = llm_fill_header(raw_text="...", regex_header=regex, client=stub)
|
||||
assert out is not None
|
||||
assert out.nomor_sprint == "Sprin/1/I/2025"
|
||||
assert out.satuan_penerbit == "Polres Bandung"
|
||||
assert out.perihal == "Penyelidikan"
|
||||
assert out.dasar == ["UU 2/2002"]
|
||||
|
||||
|
||||
def test_llm_fill_header_returns_none_when_unavailable() -> None:
|
||||
stub = _StubClient(LLMUnavailableError("server down"))
|
||||
out = llm_fill_header(raw_text="...", regex_header=HeaderFields(), client=stub)
|
||||
assert out is None
|
||||
|
||||
|
||||
def test_merge_does_not_overwrite_dasar_when_regex_has_it() -> None:
|
||||
regex = HeaderFields(dasar=["UU 2/2002"])
|
||||
llm = LLMHeaderResult(dasar=["something else", "more"])
|
||||
out = _merge(regex, llm)
|
||||
assert out.dasar == ["UU 2/2002"]
|
||||
|
||||
|
||||
def test_llm_extractor_unused_argument_kept_silent() -> None:
|
||||
# A trivial sanity check that the public function signature accepts
|
||||
# keyword-only `client` — this matches how the orchestrator calls it.
|
||||
pytest.importorskip("ocr_sprint.llm.extractor")
|
||||
171
tests/unit/test_orchestrator_llm.py
Normal file
171
tests/unit/test_orchestrator_llm.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Orchestrator-level tests for the Phase 5 hybrid LLM wiring.
|
||||
|
||||
These tests stub out the heavy stages (ingest / preprocess / OCR / table)
|
||||
so we can verify the *branching* behaviour around the LLM step without
|
||||
booting Paddle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||
from ocr_sprint.pipeline.orchestrator import _header_has_gaps, run_pipeline
|
||||
from ocr_sprint.schemas.document import SourceKind
|
||||
from ocr_sprint.schemas.extraction import HeaderFields, ReviewFlag, Signatory
|
||||
|
||||
|
||||
def test_header_has_gaps_detects_missing_fields() -> None:
|
||||
full = HeaderFields(
|
||||
nomor_sprint="Sprin/1/I/2025",
|
||||
tanggal=date(2025, 1, 1),
|
||||
satuan_penerbit="Polres X",
|
||||
perihal="ok",
|
||||
dasar=["UU 2/2002"],
|
||||
)
|
||||
assert _header_has_gaps(full) is False
|
||||
|
||||
assert _header_has_gaps(HeaderFields()) is True
|
||||
assert _header_has_gaps(full.model_copy(update={"perihal": None})) is True
|
||||
assert _header_has_gaps(full.model_copy(update={"dasar": []})) is True
|
||||
|
||||
|
||||
def _stub_pipeline_stages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
raw_text: str,
|
||||
regex_header: HeaderFields,
|
||||
) -> None:
|
||||
"""Replace ingest -> ocr -> tables with cheap fakes so the orchestrator
|
||||
runs without Paddle / PyMuPDF.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from ocr_sprint.pipeline import ingest as ingest_module
|
||||
from ocr_sprint.pipeline import ocr as ocr_module
|
||||
from ocr_sprint.pipeline.ingest import IngestedPage
|
||||
|
||||
img = np.full((100, 100, 3), 255, dtype=np.uint8)
|
||||
fake_page = IngestedPage(image=img, page_index=0)
|
||||
fake_ocr_page = ocr_module.OCRPage(
|
||||
lines=[
|
||||
ocr_module.OCRLine(text=raw_text, confidence=0.95, box=((0, 0), (1, 0), (1, 1), (0, 1)))
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(orch_module, "detect_source_kind", lambda _: SourceKind.PDF)
|
||||
monkeypatch.setattr(orch_module, "ingest", lambda *a, **k: [fake_page])
|
||||
monkeypatch.setattr(orch_module, "detect_and_correct", lambda image, _cfg: image)
|
||||
monkeypatch.setattr(orch_module, "preprocess", lambda image, _cfg: image)
|
||||
monkeypatch.setattr(orch_module, "run_ocr", lambda _image: fake_ocr_page)
|
||||
# No tables in these tests.
|
||||
monkeypatch.setattr(orch_module, "run_table_extraction", lambda _img: [])
|
||||
monkeypatch.setattr(orch_module, "extract_personnel", lambda _tables: [])
|
||||
# Header / signatory / validators come from the real implementation
|
||||
# for `extract_header`, but we override to control gap state.
|
||||
monkeypatch.setattr(orch_module, "extract_header", lambda _text: regex_header)
|
||||
monkeypatch.setattr(orch_module, "find_signatory", lambda _text: Signatory())
|
||||
monkeypatch.setattr(orch_module, "validate_extraction", lambda _result: [])
|
||||
# Keep ingest_module referenced so import isn't dropped.
|
||||
assert ingest_module is not None
|
||||
|
||||
|
||||
def test_orchestrator_skips_llm_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "false")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
_stub_pipeline_stages(
|
||||
monkeypatch,
|
||||
raw_text="dummy",
|
||||
regex_header=HeaderFields(), # all gaps
|
||||
)
|
||||
|
||||
called = {"n": 0}
|
||||
|
||||
def _trip(*_args: object, **_kwargs: object) -> None:
|
||||
called["n"] += 1
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
|
||||
|
||||
result = run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert called["n"] == 0
|
||||
assert ReviewFlag.LLM_FALLBACK not in result.result.review_flags
|
||||
assert ReviewFlag.LLM_UNAVAILABLE not in result.result.review_flags
|
||||
|
||||
|
||||
def test_orchestrator_skips_llm_when_header_complete(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
_stub_pipeline_stages(
|
||||
monkeypatch,
|
||||
raw_text="dummy",
|
||||
regex_header=HeaderFields(
|
||||
nomor_sprint="Sprin/1/I/2025",
|
||||
tanggal=date(2025, 1, 1),
|
||||
satuan_penerbit="Polres X",
|
||||
perihal="ok",
|
||||
dasar=["UU 2/2002"],
|
||||
),
|
||||
)
|
||||
|
||||
called = {"n": 0}
|
||||
|
||||
def _trip(*_args: object, **_kwargs: object) -> None:
|
||||
called["n"] += 1
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
|
||||
|
||||
run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert called["n"] == 0
|
||||
|
||||
|
||||
def test_orchestrator_calls_llm_and_marks_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
regex_partial = HeaderFields(nomor_sprint="Sprin/1/I/2025") # rest missing
|
||||
_stub_pipeline_stages(monkeypatch, raw_text="dummy text", regex_header=regex_partial)
|
||||
|
||||
def _llm(_raw: str, header: HeaderFields, **_: object) -> HeaderFields:
|
||||
return header.model_copy(
|
||||
update={
|
||||
"satuan_penerbit": "Polres Bandung",
|
||||
"perihal": "Penyelidikan",
|
||||
"dasar": ["UU 2/2002"],
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", _llm)
|
||||
|
||||
out = run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert out.result.header.satuan_penerbit == "Polres Bandung"
|
||||
assert out.result.header.perihal == "Penyelidikan"
|
||||
assert ReviewFlag.LLM_FALLBACK in out.result.review_flags
|
||||
assert ReviewFlag.LLM_UNAVAILABLE not in out.result.review_flags
|
||||
|
||||
|
||||
def test_orchestrator_marks_unavailable_when_llm_returns_none(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
_stub_pipeline_stages(monkeypatch, raw_text="dummy", regex_header=HeaderFields())
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", lambda *_a, **_k: None)
|
||||
|
||||
out = run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert ReviewFlag.LLM_UNAVAILABLE in out.result.review_flags
|
||||
assert ReviewFlag.LLM_FALLBACK not in out.result.review_flags
|
||||
Reference in New Issue
Block a user