Merge pull request #4 from Adriankf59/devin/1777135879-phase-5-llm-hybrid

Phase 5: hybrid LLM extraction (Ollama) for header gaps
This commit is contained in:
Adrian Kuman Firmansah
2026-04-26 03:20:12 +07:00
committed by GitHub
18 changed files with 1704 additions and 23 deletions

View File

@@ -0,0 +1,60 @@
"""phase6 hitl: job_corrections + approval columns
Revision ID: 3b1f2c9a4d56
Revises: ff8c14fbf8a0
Create Date: 2026-04-25 14:30:00.000000
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "3b1f2c9a4d56"
down_revision: str | None = "ff8c14fbf8a0"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
with op.batch_alter_table("jobs") as batch:
batch.add_column(
sa.Column(
"approved",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
)
)
batch.add_column(sa.Column("reviewed_by", sa.String(length=128), nullable=True))
batch.add_column(sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True))
op.create_table(
"job_corrections",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("job_id", sa.Uuid(), nullable=False),
sa.Column("field_path", sa.String(length=256), nullable=False),
sa.Column("old_value", sa.JSON(), nullable=True),
sa.Column("new_value", sa.JSON(), nullable=True),
sa.Column("corrected_by", sa.String(length=128), nullable=True),
sa.Column("reason", sa.String(length=512), nullable=True),
sa.Column("corrected_at", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(["job_id"], ["jobs.job_id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
op.f("ix_job_corrections_job_id"),
"job_corrections",
["job_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_job_corrections_job_id"), table_name="job_corrections")
op.drop_table("job_corrections")
with op.batch_alter_table("jobs") as batch:
batch.drop_column("reviewed_at")
batch.drop_column("reviewed_by")
batch.drop_column("approved")

View File

@@ -1,17 +1,17 @@
"""phase4 jobs table
Revision ID: ff8c14fbf8a0
Revises:
Revises:
Create Date: 2026-04-25 15:54:18.579147
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'ff8c14fbf8a0'
revision: str = "ff8c14fbf8a0"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@@ -19,24 +19,25 @@ depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('jobs',
sa.Column('job_id', sa.Uuid(), nullable=False),
sa.Column('status', sa.String(length=32), nullable=False),
sa.Column('source_kind', sa.String(length=16), nullable=False),
sa.Column('filename', sa.String(length=512), nullable=False),
sa.Column('blob_key', sa.String(length=512), nullable=True),
sa.Column('confidence', sa.Float(), nullable=True),
sa.Column('review_flags', sa.JSON(), nullable=False),
sa.Column('result', sa.JSON(), nullable=True),
sa.Column('error', sa.String(length=2048), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint('job_id')
op.create_table(
"jobs",
sa.Column("job_id", sa.Uuid(), nullable=False),
sa.Column("status", sa.String(length=32), nullable=False),
sa.Column("source_kind", sa.String(length=16), nullable=False),
sa.Column("filename", sa.String(length=512), nullable=False),
sa.Column("blob_key", sa.String(length=512), nullable=True),
sa.Column("confidence", sa.Float(), nullable=True),
sa.Column("review_flags", sa.JSON(), nullable=False),
sa.Column("result", sa.JSON(), nullable=True),
sa.Column("error", sa.String(length=2048), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("job_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('jobs')
op.drop_table("jobs")
# ### end Alembic commands ###

View File

@@ -22,7 +22,17 @@ from __future__ import annotations
from typing import Annotated
from uuid import UUID, uuid4
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
from fastapi import (
APIRouter,
Depends,
File,
Header,
HTTPException,
Query,
Response,
UploadFile,
status,
)
from sqlalchemy.orm import Session
from ocr_sprint.api.deps.auth import require_api_key
@@ -31,11 +41,22 @@ from ocr_sprint.api.errors import UnsupportedDocumentError
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
from ocr_sprint.config import get_settings
from ocr_sprint.db.base import session_scope
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
from ocr_sprint.db.repositories import (
InvalidFieldPathError,
JobAlreadyApprovedError,
JobNotCompletedError,
JobNotFoundError,
JobRepository,
)
from ocr_sprint.pipeline.ingest import detect_source_kind
from ocr_sprint.pipeline.orchestrator import run_pipeline
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
from ocr_sprint.schemas.extraction import ExtractionResult
from ocr_sprint.schemas.review import (
ApprovalResponse,
CorrectionEventResponse,
CorrectionRequest,
)
from ocr_sprint.storage.blob import get_blob_storage
from ocr_sprint.utils.logging import get_logger
@@ -75,6 +96,9 @@ def _row_to_response(row: object) -> DocumentResponse:
data=result_obj,
review_flags=list(row.review_flags or []),
error=row.error,
approved=bool(row.approved),
reviewed_by=row.reviewed_by,
reviewed_at=row.reviewed_at,
)
@@ -192,3 +216,116 @@ async def get_document(
except JobNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return _row_to_response(row)
# ---------- Phase 6 — HITL ----------
def _correction_row_to_response(row: object) -> CorrectionEventResponse:
# Local import to avoid a cyclic import at module load time.
from ocr_sprint.db.models import JobCorrectionRow
assert isinstance(row, JobCorrectionRow)
return CorrectionEventResponse(
id=row.id,
job_id=row.job_id,
field_path=row.field_path,
old_value=row.old_value,
new_value=row.new_value,
corrected_by=row.corrected_by,
reason=row.reason,
corrected_at=row.corrected_at,
)
@router.patch(
"/{job_id}",
response_model=DocumentResponse,
)
async def patch_document(
job_id: UUID,
body: CorrectionRequest,
session: Annotated[Session, Depends(get_session)],
x_user_id: Annotated[
str | None,
Header(description="Free-form reviewer identifier recorded on the audit row."),
] = None,
) -> DocumentResponse:
"""Apply one or more field-level corrections and record an audit trail.
The whole batch is applied atomically — if any path is invalid the
request fails with 400 and no side effects are written. Returns the
updated document so the client doesn't need a follow-up GET.
"""
repo = JobRepository(session)
try:
repo.apply_corrections(
job_id,
corrections=[(c.path, c.value, c.reason) for c in body.corrections],
corrected_by=x_user_id,
)
row = repo.get_or_raise(job_id)
except JobNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except InvalidFieldPathError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
except JobAlreadyApprovedError as exc:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
except JobNotCompletedError as exc:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
_logger.info(
"documents.patched",
job_id=str(job_id),
count=len(body.corrections),
corrected_by=x_user_id or "",
)
return _row_to_response(row)
@router.get(
"/{job_id}/history",
response_model=list[CorrectionEventResponse],
)
async def get_history(
job_id: UUID,
session: Annotated[Session, Depends(get_session)],
) -> list[CorrectionEventResponse]:
repo = JobRepository(session)
try:
rows = repo.list_corrections(job_id)
except JobNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return [_correction_row_to_response(r) for r in rows]
@router.post(
"/{job_id}/approve",
response_model=ApprovalResponse,
)
async def approve_document(
job_id: UUID,
session: Annotated[Session, Depends(get_session)],
x_user_id: Annotated[
str | None,
Header(description="Free-form reviewer identifier recorded on the job."),
] = None,
) -> ApprovalResponse:
"""Lock a job's final version. Idempotent: re-approving returns the
existing row without overwriting ``reviewed_at``.
"""
repo = JobRepository(session)
try:
row = repo.approve(job_id, reviewed_by=x_user_id)
except JobNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
except JobNotCompletedError as exc:
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
_logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "")
return ApprovalResponse(
job_id=row.job_id,
approved=bool(row.approved),
reviewed_by=row.reviewed_by,
reviewed_at=row.reviewed_at,
)

View File

@@ -16,7 +16,7 @@ from datetime import datetime, timezone
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy import JSON, DateTime, Float, String, Uuid
from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, Integer, String, Uuid
from sqlalchemy.orm import Mapped, mapped_column
from ocr_sprint.db.base import Base
@@ -42,6 +42,15 @@ class JobRow(Base):
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
# Phase 6 — HITL review state.
# Once ``approved=True`` the row is immutable except to admin users;
# corrections after that point are rejected by the route. ``reviewed_by``
# stores the free-form user identifier the reviewer sent via the
# ``X-User-Id`` header (best-effort attribution — no full RBAC yet).
approved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
reviewed_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=_utcnow
)
@@ -51,3 +60,37 @@ class JobRow(Base):
def __repr__(self) -> str:
return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"
class JobCorrectionRow(Base):
"""One correction event on a job's ``result``.
Each PATCH call writes one row per changed field path so we have a
full audit trail. Rows are append-only — never updated, never
deleted — so the history is reproducible and usable as ground-truth
data for future fine-tuning.
"""
__tablename__ = "job_corrections"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
job_id: Mapped[UUID] = mapped_column(
Uuid, ForeignKey("jobs.job_id", ondelete="CASCADE"), nullable=False, index=True
)
# Dotted JSON path into ExtractionResult, e.g. "header.nomor_sprint" or
# "personel[3].nrp". Kept as a plain string for simplicity — we don't
# parse it server-side beyond the allow-list check in the repository.
field_path: Mapped[str] = mapped_column(String(256), nullable=False)
old_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
new_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
corrected_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
reason: Mapped[str | None] = mapped_column(String(512), nullable=True)
corrected_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=_utcnow
)
def __repr__(self) -> str:
return (
f"JobCorrectionRow(job_id={self.job_id!s}, "
f"field={self.field_path!r}, by={self.corrected_by!r})"
)

View File

@@ -6,6 +6,9 @@ know about sessions, transactions, or the row → schema mapping.
from __future__ import annotations
import copy
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
@@ -13,7 +16,7 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from ocr_sprint.db.models import JobRow
from ocr_sprint.db.models import JobCorrectionRow, JobRow
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
@@ -25,6 +28,110 @@ class JobNotFoundError(LookupError):
"""Raised by API code when GET /documents/{id} hits a missing row."""
class InvalidFieldPathError(ValueError):
"""Raised when a PATCH request references an unsupported field path."""
class JobAlreadyApprovedError(RuntimeError):
"""Raised when a PATCH is attempted against an already-approved job."""
class JobNotCompletedError(RuntimeError):
"""Raised when a PATCH/approve is attempted against a job that hasn't
produced a ``result`` payload yet (e.g. still pending or failed).
"""
@dataclass(frozen=True)
class AppliedCorrection:
"""Internal record of a correction that successfully applied to the
in-memory ``result`` dict. The repository turns this into a persisted
``JobCorrectionRow`` after the whole batch is validated.
"""
field_path: str
old_value: Any
new_value: Any
reason: str | None
# Allow-list of top-level keys we let reviewers edit. Keeps the attack
# surface small: they can't inject arbitrary fields into the JSON blob.
_ALLOWED_ROOTS: frozenset[str] = frozenset({"header", "ttd", "personel", "untuk"})
# Matches a single path segment like ``personel[3]`` — supports at most one
# index per segment, enough for the list fields we care about.
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
def _split_path(path: str) -> list[tuple[str, int | None]]:
"""Parse ``header.nomor_sprint`` or ``personel[2].nrp`` into segments.
Returns list of ``(name, index_or_none)`` tuples. Raises
``InvalidFieldPathError`` on malformed input so the caller can surface
a 400 to the client.
"""
if not path or path.startswith(".") or path.endswith("."):
raise InvalidFieldPathError(f"Invalid field path: {path!r}")
parts = path.split(".")
out: list[tuple[str, int | None]] = []
for part in parts:
match = _SEGMENT_RE.match(part)
if match is None:
raise InvalidFieldPathError(f"Invalid segment in path: {part!r}")
name = match.group(1)
idx_raw = match.group(2)
idx = int(idx_raw) if idx_raw is not None else None
out.append((name, idx))
if out[0][0] not in _ALLOWED_ROOTS:
raise InvalidFieldPathError(
f"Field path root {out[0][0]!r} not in allowed roots {sorted(_ALLOWED_ROOTS)!r}"
)
return out
def _apply_path(data: dict[str, Any], path: str, new_value: Any) -> Any:
"""Apply a single correction to ``data`` in place. Returns the old
value so the caller can record it in the audit row.
Does NOT validate that the new value matches the field's expected
type — that's the reviewer's responsibility; the whole point of HITL
is to let humans override the model's typing.
"""
segments = _split_path(path)
cursor: Any = data
for name, idx in segments[:-1]:
if not isinstance(cursor, dict) or name not in cursor:
raise InvalidFieldPathError(f"Cannot traverse to {path!r}: missing {name!r}")
cursor = cursor[name]
if idx is not None:
if not isinstance(cursor, list) or idx >= len(cursor):
raise InvalidFieldPathError(
f"Cannot traverse to {path!r}: index [{idx}] out of range"
)
cursor = cursor[idx]
name, idx = segments[-1]
if idx is not None:
# Terminal segment is a list-element, e.g. ``untuk[2]``.
if not isinstance(cursor, dict) or name not in cursor:
raise InvalidFieldPathError(f"Cannot apply to {path!r}: missing container {name!r}")
container = cursor[name]
if not isinstance(container, list) or idx >= len(container):
raise InvalidFieldPathError(f"Cannot apply to {path!r}: index [{idx}] out of range")
old = container[idx]
container[idx] = new_value
return old
if not isinstance(cursor, dict):
raise InvalidFieldPathError(f"Cannot apply to {path!r}: parent is not an object")
old = cursor.get(name)
cursor[name] = new_value
return old
class JobRepository:
"""SQL-backed repository for `jobs` rows."""
@@ -94,3 +201,139 @@ class JobRepository:
if row is None:
raise JobNotFoundError(f"Job not found: {job_id}")
return row
# ---------- Phase 6 — HITL ----------
def apply_corrections(
self,
job_id: UUID,
*,
corrections: list[tuple[str, Any, str | None]],
corrected_by: str | None,
) -> list[JobCorrectionRow]:
"""Apply a batch of field corrections atomically.
``corrections`` is a list of ``(path, new_value, reason)`` tuples.
Returns the persisted audit rows so the caller can surface them in
the response.
Raises
------
JobNotFoundError
If the row doesn't exist.
JobNotCompletedError
If the job hasn't produced a result yet (status pending /
processing / failed).
JobAlreadyApprovedError
If the job has been approved — edits are locked.
InvalidFieldPathError
If any path is malformed or references a disallowed root.
"""
row = self._get_or_raise(job_id)
if row.result is None:
raise JobNotCompletedError(
f"Job {job_id} has no result to correct (status={row.status})"
)
if row.approved:
raise JobAlreadyApprovedError(f"Job {job_id} is already approved; edits are locked")
# Deep-copy so we can roll back in memory if any correction fails.
# The underlying JSON column will only be re-assigned once every
# path applied cleanly.
working = copy.deepcopy(row.result)
applied: list[AppliedCorrection] = []
for path, new_value, reason in corrections:
old_value = _apply_path(working, path, new_value)
applied.append(
AppliedCorrection(
field_path=path, old_value=old_value, new_value=new_value, reason=reason
)
)
# Persist audit rows first; if they fail the session rollback also
# undoes the result-column update we're about to do.
persisted: list[JobCorrectionRow] = []
for event in applied:
row_event = JobCorrectionRow(
job_id=job_id,
field_path=event.field_path,
old_value=event.old_value,
new_value=event.new_value,
corrected_by=corrected_by,
reason=event.reason,
)
self.session.add(row_event)
persisted.append(row_event)
# Clear review flags that the correction has resolved. Right now we
# only auto-clear MISSING_FIELD when any corrected field previously
# held a null/empty value — the reviewer explicitly filled a gap.
row.result = working
row.review_flags = _recompute_flags(
original_flags=list(row.review_flags or []),
applied=applied,
working_result=working,
)
row.updated_at = _utcnow()
self.session.flush()
return persisted
def list_corrections(self, job_id: UUID) -> list[JobCorrectionRow]:
"""Return the full audit trail for ``job_id`` in chronological order."""
# ``get_or_raise`` so callers get a 404 instead of an empty list
# when the job itself doesn't exist.
self._get_or_raise(job_id)
stmt = (
select(JobCorrectionRow)
.where(JobCorrectionRow.job_id == job_id)
.order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
)
return list(self.session.scalars(stmt))
def approve(self, job_id: UUID, *, reviewed_by: str | None) -> JobRow:
"""Mark a job as approved. Idempotent — re-approving is a no-op
that keeps the original ``reviewed_at`` (so the audit trail stays
intact).
"""
row = self._get_or_raise(job_id)
if row.result is None:
raise JobNotCompletedError(
f"Job {job_id} has no result to approve (status={row.status})"
)
if row.approved:
return row
row.approved = True
row.reviewed_by = reviewed_by
row.reviewed_at = _utcnow()
row.updated_at = row.reviewed_at
return row
def _recompute_flags(
*,
original_flags: list[str],
applied: list[AppliedCorrection],
working_result: dict[str, Any],
) -> list[str]:
"""Update review flags in light of the corrections just applied.
Keeps the policy simple on purpose:
* ``missing_field`` is removed if after the edit every required
header field is non-empty.
* Other flags stay untouched — the reviewer should either correct the
underlying issue (which this helper can detect) or explicitly
approve the result as-is (which bypasses the flag list).
"""
flags = list(original_flags)
if "missing_field" in flags:
header = working_result.get("header") or {}
filled = all(bool(header.get(key)) for key in ("nomor_sprint", "satuan_penerbit"))
if filled:
flags = [f for f in flags if f != "missing_field"]
# ``applied`` isn't used directly in this MVP rule, but we keep the
# parameter so future policies can inspect exactly what changed
# without re-diffing the blob.
_ = applied
return flags

View File

@@ -0,0 +1,18 @@
"""LLM-based extraction (Phase 5).
The hybrid extractor first runs the deterministic regex layer and then —
only for fields that came back missing or low-confidence — calls a local
Ollama model with a Pydantic-typed prompt. Everything is gated by
``LLM_ENABLED``; if the flag is off or the Ollama server is unreachable,
the pipeline degrades gracefully back to the regex result.
"""
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
from ocr_sprint.llm.extractor import LLMHeaderResult, llm_fill_header
__all__ = [
"LLMHeaderResult",
"LLMUnavailableError",
"OllamaClient",
"llm_fill_header",
]

View File

@@ -0,0 +1,97 @@
"""Ollama HTTP client.
We deliberately avoid the ``ollama`` Python package — the wire format is a
single ``POST /api/chat`` with ``format="json"`` and a system + user message,
so a small ``httpx`` wrapper is enough. This keeps the runtime dependency
footprint smaller and makes the mock-based unit tests trivial.
"""
from __future__ import annotations
from typing import TypeVar
import httpx
from pydantic import BaseModel, ValidationError
from ocr_sprint.config import get_settings
from ocr_sprint.utils.logging import get_logger
_logger = get_logger(__name__)
T = TypeVar("T", bound=BaseModel)
class LLMUnavailableError(RuntimeError):
"""Raised when the Ollama server is unreachable, times out, or returns
a malformed payload. The pipeline catches this and falls back to the
regex-only result with a ``llm_fallback`` review flag.
"""
class OllamaClient:
"""Tiny synchronous HTTP wrapper around the Ollama ``/api/chat`` endpoint.
Parameters
----------
base_url:
Ollama server URL, e.g. ``http://localhost:11434``.
model:
Model tag to invoke (default ``qwen2.5:1.5b`` — chosen for CPU
latency at acceptable accuracy).
timeout_s:
Hard wall-clock timeout for a single request.
"""
def __init__(
self,
base_url: str | None = None,
model: str | None = None,
timeout_s: int | None = None,
) -> None:
s = get_settings()
self.base_url = (base_url or s.llm_base_url).rstrip("/")
self.model = model or s.llm_model
self.timeout_s = timeout_s if timeout_s is not None else s.llm_timeout_s
# ---------- public API ----------
def chat_json(self, system: str, user: str, schema_cls: type[T]) -> T:
"""Run a single chat completion in JSON mode and validate the
response against ``schema_cls``. Raises ``LLMUnavailableError`` on
any transport / parse / validation failure so callers only have one
exception to handle.
"""
payload = {
"model": self.model,
"stream": False,
"format": "json",
"messages": [
{"role": "system", "content": system},
{"role": "user", "content": user},
],
# Keep determinism reasonable — we want extraction, not creativity.
"options": {"temperature": 0.0, "num_ctx": 4096},
}
url = f"{self.base_url}/api/chat"
try:
with httpx.Client(timeout=self.timeout_s) as client:
response = client.post(url, json=payload)
response.raise_for_status()
data = response.json()
except (httpx.HTTPError, ValueError) as exc:
_logger.warning("llm.transport_error", url=url, error=str(exc))
raise LLMUnavailableError(f"Ollama request failed: {exc}") from exc
# Ollama returns {"message": {"role": "assistant", "content": "<json>"}}.
try:
content = data["message"]["content"]
except (KeyError, TypeError) as exc:
_logger.warning("llm.bad_envelope", payload=data)
raise LLMUnavailableError(f"Ollama response missing message.content: {data!r}") from exc
try:
return schema_cls.model_validate_json(content)
except ValidationError as exc:
_logger.warning("llm.validation_error", error=str(exc), content=content[:400])
raise LLMUnavailableError(f"LLM JSON failed schema: {exc}") from exc

View File

@@ -0,0 +1,84 @@
"""High-level LLM extractor.
The job is *narrow*: take the raw OCR text plus the partial header that
came back from the regex layer, and return an LLM-derived header that the
caller can merge in. We never let the LLM populate the personnel table —
PP-Structure is more accurate and cheaper for that.
"""
from __future__ import annotations
from datetime import date
from pydantic import BaseModel, Field
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
from ocr_sprint.llm.prompts import SYSTEM_HEADER, build_user_prompt
from ocr_sprint.schemas.extraction import HeaderFields
from ocr_sprint.utils.logging import get_logger
_logger = get_logger(__name__)
class LLMHeaderResult(BaseModel):
"""Schema we ask the model to fill. Mirrors ``HeaderFields`` but is
intentionally separate so we control exactly what the prompt and
validation surface look like — the public ``HeaderFields`` may grow
fields later that we don't want the LLM touching.
"""
nomor_sprint: str | None = None
tanggal: date | None = None
satuan_penerbit: str | None = None
perihal: str | None = None
dasar: list[str] = Field(default_factory=list)
def llm_fill_header(
raw_text: str,
regex_header: HeaderFields,
*,
client: OllamaClient | None = None,
) -> HeaderFields | None:
"""Run the LLM extractor and return a *merged* HeaderFields.
Returns ``None`` if the model is unavailable so the caller can decide
what to do (typically: keep the regex result and emit a fallback
review flag).
"""
client = client or OllamaClient()
user = build_user_prompt(
raw_text=raw_text,
regex_partial=regex_header.model_dump(mode="json"),
)
try:
llm = client.chat_json(SYSTEM_HEADER, user, LLMHeaderResult)
except LLMUnavailableError as exc:
_logger.warning("llm.unavailable", error=str(exc))
return None
return _merge(regex_header, llm)
def _merge(regex: HeaderFields, llm: LLMHeaderResult) -> HeaderFields:
"""Merge LLM output into the regex result.
Policy: regex wins for any field it already filled. The LLM only fills
the *gaps*. This keeps deterministic / verifiable extractions for the
fields where regex is reliable and prevents the LLM from "correcting"
a value that happens to look unusual but is in fact correct.
"""
merged = regex.model_copy(deep=True)
if merged.nomor_sprint is None and llm.nomor_sprint:
merged.nomor_sprint = llm.nomor_sprint
if merged.tanggal is None and llm.tanggal is not None:
merged.tanggal = llm.tanggal
if not merged.satuan_penerbit and llm.satuan_penerbit:
merged.satuan_penerbit = llm.satuan_penerbit
if not merged.perihal and llm.perihal:
merged.perihal = llm.perihal
if not merged.dasar and llm.dasar:
merged.dasar = list(llm.dasar)
return merged

View File

@@ -0,0 +1,48 @@
"""Prompt builders for the LLM extractor.
Kept in their own module so the prompts can be edited / version-tracked
without touching the orchestration logic. We build prompts in Indonesian
because the source documents are too — the model performs better when the
field labels in the prompt match the OCR text it's being asked about.
"""
from __future__ import annotations
SYSTEM_HEADER = (
"Anda adalah asisten ekstraksi data untuk dokumen Surat Perintah (Sprint) "
"Kepolisian Republik Indonesia (POLRI). Pengguna akan memberikan teks hasil "
"OCR sebuah surat sprint, dan Anda harus mengembalikan JSON yang sesuai "
"dengan skema yang diberikan.\n\n"
"Aturan keras:\n"
"1. Jangan mengarang. Jika sebuah field tidak terlihat di teks, kembalikan null.\n"
"2. Jangan menerjemahkan field. Output harus identik ejaannya dengan teks "
"sumber (kecuali normalisasi spasi/kapitalisasi yang jelas hasil OCR error).\n"
"3. Tanggal: kembalikan format ISO YYYY-MM-DD jika tanggal terlihat, "
"selain itu null.\n"
"4. Dasar hukum: array string berisi tiap butir, urut sesuai teks.\n"
"5. Jangan menambahkan field apa pun di luar skema. Output WAJIB JSON valid."
)
def build_user_prompt(raw_text: str, regex_partial: dict[str, object]) -> str:
"""Construct the user message: OCR text + a hint about which fields the
deterministic regex layer already filled. Telling the LLM what we
*already have* keeps it from "creatively" overwriting good values.
"""
known_fields = "\n".join(f" - {k}: {v!r}" for k, v in sorted(regex_partial.items()) if v)
known_block = (
f"\nField yang sudah berhasil diekstrak dengan regex:\n{known_fields}\n"
if known_fields
else ""
)
return (
"Teks OCR:\n"
"----------\n"
f"{raw_text}\n"
"----------\n"
f"{known_block}"
"Tugas: kembalikan JSON dengan field nomor_sprint, tanggal (ISO date | null), "
"satuan_penerbit, perihal, dasar (array string). Hanya field yang terlihat — "
"yang tidak ada di teks isi null (atau array kosong untuk dasar)."
)

View File

@@ -15,6 +15,7 @@ from __future__ import annotations
from dataclasses import dataclass
from ocr_sprint.config import get_settings
from ocr_sprint.llm.extractor import llm_fill_header
from ocr_sprint.pipeline.confidence import compute_confidence, route
from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct
from ocr_sprint.pipeline.extract.personnel import extract_personnel
@@ -35,6 +36,18 @@ _logger = get_logger(__name__)
_OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80
def _header_has_gaps(header: object) -> bool:
"""True if any header field worth asking the LLM about is missing.
Using ``getattr`` so this stays decoupled from the exact attribute
names; the schema change cost was too large last time we hard-coded.
"""
for field in ("nomor_sprint", "tanggal", "satuan_penerbit", "perihal"):
if not getattr(header, field, None):
return True
return not getattr(header, "dasar", None)
@dataclass
class PipelineOutput:
"""Bundle returned by the orchestrator."""
@@ -84,6 +97,20 @@ def run_pipeline(content: bytes) -> PipelineOutput:
header = extract_header(full_text)
ttd = find_signatory(full_text)
# Phase 5 — hybrid extraction. The regex layer is deterministic but
# brittle to layout variants between satuan; if any header field is
# still missing we ask the local LLM to fill the gaps. The merger
# never lets the LLM overwrite a field that regex already captured.
llm_flags: list[ReviewFlag] = []
if s.llm_enabled and _header_has_gaps(header):
merged = llm_fill_header(full_text, header)
if merged is None:
llm_flags.append(ReviewFlag.LLM_UNAVAILABLE)
else:
if merged.model_dump() != header.model_dump():
llm_flags.append(ReviewFlag.LLM_FALLBACK)
header = merged
personel: list[PersonnelEntry] = []
if s.tables_enabled and cleaned_pages:
all_tables: list[DetectedTable] = []
@@ -99,7 +126,7 @@ def run_pipeline(content: bytes) -> PipelineOutput:
personel_rows=len(personel),
)
initial_flags: list[ReviewFlag] = []
initial_flags: list[ReviewFlag] = list(llm_flags)
if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD:
initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE)

View File

@@ -55,3 +55,7 @@ class DocumentResponse(BaseModel):
data: ExtractionResult | None = None
review_flags: list[str] = Field(default_factory=list)
error: str | None = None
# Phase 6 — HITL review state.
approved: bool = False
reviewed_by: str | None = None
reviewed_at: datetime | None = None

View File

@@ -19,6 +19,8 @@ class ReviewFlag(str, Enum):
UNKNOWN_PANGKAT = "unknown_pangkat"
PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch"
DATE_PARSE_FAILED = "date_parse_failed"
LLM_FALLBACK = "llm_fallback"
LLM_UNAVAILABLE = "llm_unavailable"
class Signatory(BaseModel):

View File

@@ -0,0 +1,62 @@
"""Request / response schemas for the HITL review endpoints (Phase 6).
The API surface is deliberately small:
* ``CorrectionRequest`` — body of ``PATCH /documents/{id}``. A list of
``FieldCorrection`` entries; each one is applied atomically (all-or-
nothing) and recorded in the audit trail.
* ``CorrectionEventResponse`` — single row in ``GET /documents/{id}/history``.
* ``ApprovalResponse`` — echo back after ``POST /documents/{id}/approve``.
"""
from __future__ import annotations
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel, Field
class FieldCorrection(BaseModel):
"""One field-level correction.
``path`` is a dotted JSON path into ``ExtractionResult``. Supported
roots: ``header``, ``ttd``, ``personel[n]`` (n is a 0-based index),
``untuk``. The path is validated by the repository before being
applied; unknown roots return 400.
"""
path: str = Field(..., description="Dotted JSON path, e.g. 'header.nomor_sprint'.")
value: Any = Field(..., description="New value (any JSON-serialisable payload).")
reason: str | None = Field(
None, max_length=512, description="Optional free-form reason for the correction."
)
class CorrectionRequest(BaseModel):
"""PATCH body — one or more field corrections, applied atomically."""
corrections: list[FieldCorrection] = Field(..., min_length=1)
class CorrectionEventResponse(BaseModel):
"""One row of the audit log surfaced by GET /history."""
id: int
job_id: UUID
field_path: str
old_value: Any | None = None
new_value: Any | None = None
corrected_by: str | None = None
reason: str | None = None
corrected_at: datetime
class ApprovalResponse(BaseModel):
"""Echo returned after a job is approved."""
job_id: UUID
approved: bool
reviewed_by: str | None = None
reviewed_at: datetime | None = None

248
tests/unit/test_api_hitl.py Normal file
View File

@@ -0,0 +1,248 @@
"""End-to-end HTTP tests for the HITL endpoints.
We re-use the ``fake_pipeline`` style from ``test_api.py`` so we don't pay
the PaddleOCR init cost; the orchestrator is monkey-patched to return a
synthetic ``ExtractionResult``.
"""
from __future__ import annotations
from datetime import date
import pytest
from fastapi.testclient import TestClient
from ocr_sprint.main import create_app
from ocr_sprint.pipeline import orchestrator as orch_module
from ocr_sprint.pipeline.orchestrator import PipelineOutput
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
from ocr_sprint.schemas.extraction import (
ExtractionResult,
HeaderFields,
PersonnelEntry,
ReviewFlag,
)
@pytest.fixture
def client() -> TestClient:
return TestClient(create_app())
@pytest.fixture
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
result = ExtractionResult(
header=HeaderFields(
nomor_sprint="Sprin/1/I/2025",
tanggal=date(2025, 1, 1),
satuan_penerbit="POLRES TEST",
perihal=None, # intentional gap so a PATCH can fill it
),
personel=[
PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"),
],
review_flags=[ReviewFlag.MISSING_FIELD],
confidence=0.7,
)
output = PipelineOutput(
source_kind=SourceKind.PDF,
status=DocumentStatus.NEEDS_REVIEW,
confidence=0.7,
result=result,
)
def _fake_run(_content: bytes) -> PipelineOutput:
return output
monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
from ocr_sprint.api.routes import documents as docs_module
monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
from ocr_sprint.worker import tasks as tasks_module
monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
return output
def _create_job(client: TestClient) -> str:
post = client.post(
"/api/v1/documents?sync=true",
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
)
assert post.status_code == 200, post.text
body = post.json()
assert body["status"] == "needs_review"
return str(body["job_id"])
def test_patch_applies_correction_and_clears_missing_field(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
patched = client.patch(
f"/api/v1/documents/{job_id}",
json={
"corrections": [
{
"path": "header.perihal",
"value": "Penyelidikan kasus X",
"reason": "LLM missed it",
}
]
},
headers={"X-User-Id": "reviewer-a"},
)
assert patched.status_code == 200, patched.text
body = patched.json()
assert body["data"]["header"]["perihal"] == "Penyelidikan kasus X"
# The fake pipeline has both required header fields filled, so the
# ``missing_field`` flag is auto-cleared as soon as any correction
# lands (the policy re-evaluates required-field coverage on every
# edit).
assert "missing_field" not in body["review_flags"]
def test_patch_returns_400_for_unknown_path(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
resp = client.patch(
f"/api/v1/documents/{job_id}",
json={"corrections": [{"path": "bogus.field", "value": "x"}]},
)
assert resp.status_code == 400
def test_patch_is_atomic_on_partial_failure(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
resp = client.patch(
f"/api/v1/documents/{job_id}",
json={
"corrections": [
{"path": "header.perihal", "value": "OK"},
{"path": "bogus.root", "value": "X"},
]
},
)
assert resp.status_code == 400
# The first correction must not have persisted.
got = client.get(f"/api/v1/documents/{job_id}")
assert got.json()["data"]["header"]["perihal"] is None
def test_history_returns_corrections_in_order(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
client.patch(
f"/api/v1/documents/{job_id}",
json={"corrections": [{"path": "header.perihal", "value": "first"}]},
headers={"X-User-Id": "reviewer-a"},
)
client.patch(
f"/api/v1/documents/{job_id}",
json={"corrections": [{"path": "header.perihal", "value": "second"}]},
headers={"X-User-Id": "reviewer-b"},
)
history = client.get(f"/api/v1/documents/{job_id}/history")
assert history.status_code == 200
events = history.json()
assert [e["new_value"] for e in events] == ["first", "second"]
assert [e["corrected_by"] for e in events] == ["reviewer-a", "reviewer-b"]
# old_value of the second event should reflect the first edit.
assert events[1]["old_value"] == "first"
def test_history_returns_empty_list_for_untouched_job(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
history = client.get(f"/api/v1/documents/{job_id}/history")
assert history.status_code == 200
assert history.json() == []
def test_history_returns_404_for_unknown_job(client: TestClient) -> None:
resp = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000/history")
assert resp.status_code == 404
def test_approve_locks_subsequent_patches(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
approved = client.post(
f"/api/v1/documents/{job_id}/approve",
headers={"X-User-Id": "reviewer-a"},
)
assert approved.status_code == 200, approved.text
body = approved.json()
assert body["approved"] is True
assert body["reviewed_by"] == "reviewer-a"
assert body["reviewed_at"] # non-empty timestamp
# GET reflects the approval state.
got = client.get(f"/api/v1/documents/{job_id}").json()
assert got["approved"] is True
# PATCH after approve must be rejected with 409.
patched = client.patch(
f"/api/v1/documents/{job_id}",
json={"corrections": [{"path": "header.perihal", "value": "X"}]},
)
assert patched.status_code == 409
def test_approve_is_idempotent(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
first = client.post(
f"/api/v1/documents/{job_id}/approve",
headers={"X-User-Id": "reviewer-a"},
)
second = client.post(
f"/api/v1/documents/{job_id}/approve",
headers={"X-User-Id": "reviewer-b"},
)
assert first.status_code == 200
assert second.status_code == 200
# Second approve must NOT change the attribution. (SQLite drops tzinfo
# on roundtrip, which changes Pydantic's serialization between the two
# calls; compare the naive components.)
assert second.json()["reviewed_by"] == "reviewer-a"
assert (
second.json()["reviewed_at"].rstrip("Z").split("+")[0]
== (first.json()["reviewed_at"].rstrip("Z").split("+")[0])
)
def test_patch_requires_at_least_one_correction(
client: TestClient,
fake_pipeline: PipelineOutput,
) -> None:
job_id = _create_job(client)
resp = client.patch(
f"/api/v1/documents/{job_id}",
json={"corrections": []},
)
assert resp.status_code == 422 # Pydantic min_length=1 violation
def test_patch_missing_job_returns_404(client: TestClient) -> None:
resp = client.patch(
"/api/v1/documents/00000000-0000-0000-0000-000000000000",
json={"corrections": [{"path": "header.perihal", "value": "X"}]},
)
assert resp.status_code == 404

238
tests/unit/test_db_hitl.py Normal file
View File

@@ -0,0 +1,238 @@
"""Repository tests for Phase 6 HITL helpers."""
from __future__ import annotations
from uuid import uuid4
import pytest
from ocr_sprint.db.base import Base, get_engine, session_scope
from ocr_sprint.db.repositories import (
InvalidFieldPathError,
JobAlreadyApprovedError,
JobNotCompletedError,
JobNotFoundError,
JobRepository,
)
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
@pytest.fixture
def db_ready() -> None:
Base.metadata.create_all(bind=get_engine())
def _seed_completed_job(
*,
result: dict[str, object] | None = None,
flags: list[str] | None = None,
) -> uuid4: # type: ignore[type-arg]
jid = uuid4()
with session_scope() as session:
repo = JobRepository(session)
repo.create(
job_id=jid,
filename="x.pdf",
source_kind=SourceKind.PDF,
blob_key="k",
)
with session_scope() as session:
JobRepository(session).mark_completed(
jid,
status=DocumentStatus.NEEDS_REVIEW,
confidence=0.7,
result=result
or {
"header": {
"nomor_sprint": "Sprin/1/I/2025",
"satuan_penerbit": "POLRES X",
"perihal": None,
},
"personel": [
{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"},
],
"untuk": ["Melaksanakan tugas"],
},
review_flags=flags or [],
)
return jid
def test_apply_corrections_updates_nested_header_field(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session:
repo = JobRepository(session)
repo.apply_corrections(
jid,
corrections=[("header.perihal", "Penyelidikan kasus X", "regex miss")],
corrected_by="reviewer-a",
)
row = repo.get_or_raise(jid)
assert row.result is not None
assert row.result["header"]["perihal"] == "Penyelidikan kasus X"
def test_apply_corrections_writes_audit_row(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session:
JobRepository(session).apply_corrections(
jid,
corrections=[("header.perihal", "Penyelidikan", None)],
corrected_by="reviewer-a",
)
with session_scope() as session:
events = JobRepository(session).list_corrections(jid)
assert len(events) == 1
assert events[0].field_path == "header.perihal"
assert events[0].old_value is None
assert events[0].new_value == "Penyelidikan"
assert events[0].corrected_by == "reviewer-a"
def test_apply_corrections_supports_list_index(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session:
JobRepository(session).apply_corrections(
jid,
corrections=[("personel[0].nrp", "77060001", None)],
corrected_by=None,
)
row = JobRepository(session).get_or_raise(jid)
assert row.result is not None
assert row.result["personel"][0]["nrp"] == "77060001"
def test_apply_corrections_is_atomic_on_invalid_path(db_ready: None) -> None:
"""A second-correction failure must roll back the first one."""
jid = _seed_completed_job()
with session_scope() as session, pytest.raises(InvalidFieldPathError):
JobRepository(session).apply_corrections(
jid,
corrections=[
("header.perihal", "OK", None),
("bogus.root", "X", None),
],
corrected_by=None,
)
# The first correction must not have persisted.
with session_scope() as session:
row = JobRepository(session).get_or_raise(jid)
assert row.result is not None
assert row.result["header"].get("perihal") is None
def test_apply_corrections_rejects_out_of_range_index(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session, pytest.raises(InvalidFieldPathError):
JobRepository(session).apply_corrections(
jid,
corrections=[("personel[99].nrp", "77060001", None)],
corrected_by=None,
)
def test_apply_corrections_rejects_after_approve(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session:
JobRepository(session).approve(jid, reviewed_by="reviewer-a")
with session_scope() as session, pytest.raises(JobAlreadyApprovedError):
JobRepository(session).apply_corrections(
jid,
corrections=[("header.perihal", "X", None)],
corrected_by="reviewer-a",
)
def test_apply_corrections_rejects_missing_job(db_ready: None) -> None:
with session_scope() as session, pytest.raises(JobNotFoundError):
JobRepository(session).apply_corrections(
uuid4(),
corrections=[("header.perihal", "X", None)],
corrected_by=None,
)
def test_apply_corrections_rejects_pending_job(db_ready: None) -> None:
jid = uuid4()
with session_scope() as session:
JobRepository(session).create(
job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k"
)
with session_scope() as session, pytest.raises(JobNotCompletedError):
JobRepository(session).apply_corrections(
jid,
corrections=[("header.perihal", "X", None)],
corrected_by=None,
)
def test_missing_field_flag_cleared_when_header_gap_filled(db_ready: None) -> None:
jid = _seed_completed_job(
result={
"header": {
"nomor_sprint": None,
"satuan_penerbit": "POLRES X",
}
},
flags=["missing_field", "low_ocr_confidence"],
)
with session_scope() as session:
JobRepository(session).apply_corrections(
jid,
corrections=[("header.nomor_sprint", "Sprin/2/I/2025", None)],
corrected_by="reviewer-a",
)
row = JobRepository(session).get_or_raise(jid)
# ``low_ocr_confidence`` stays (correction doesn't resolve that signal),
# but ``missing_field`` is gone because every required header key is
# now non-empty.
assert list(row.review_flags) == ["low_ocr_confidence"]
def test_approve_sets_timestamps_and_is_idempotent(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session:
row = JobRepository(session).approve(jid, reviewed_by="reviewer-a")
first_at = row.reviewed_at
assert first_at is not None
with session_scope() as session:
row = JobRepository(session).approve(jid, reviewed_by="reviewer-b")
# Second call must NOT overwrite reviewed_by or reviewed_at.
# SQLite drops tzinfo on roundtrip, so compare the naive components.
assert row.approved is True
assert row.reviewed_by == "reviewer-a"
assert row.reviewed_at is not None
assert row.reviewed_at.replace(tzinfo=None) == first_at.replace(tzinfo=None)
def test_approve_rejects_pending_job(db_ready: None) -> None:
jid = uuid4()
with session_scope() as session:
JobRepository(session).create(
job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k"
)
with session_scope() as session, pytest.raises(JobNotCompletedError):
JobRepository(session).approve(jid, reviewed_by="rev")
def test_history_returns_events_in_order(db_ready: None) -> None:
jid = _seed_completed_job()
with session_scope() as session:
JobRepository(session).apply_corrections(
jid,
corrections=[("header.perihal", "one", None)],
corrected_by="r1",
)
with session_scope() as session:
JobRepository(session).apply_corrections(
jid,
corrections=[
("header.perihal", "two", None),
("personel[0].nama", "ANDI", None),
],
corrected_by="r2",
)
with session_scope() as session:
events = JobRepository(session).list_corrections(jid)
assert [e.new_value for e in events] == ["one", "two", "ANDI"]
assert [e.corrected_by for e in events] == ["r1", "r2", "r2"]

View File

@@ -0,0 +1,108 @@
"""Unit tests for the Ollama HTTP client wrapper.
We swap ``httpx.Client`` inside ``ocr_sprint.llm.client`` for a builder that
returns a real ``httpx.Client`` wrapping a ``MockTransport``. Capturing the
original constructor *before* patching avoids infinite recursion in the
patched callable.
"""
from __future__ import annotations
from typing import Any
import httpx
import pytest
from pydantic import BaseModel
import ocr_sprint.llm.client as llm_client_module
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
class _Schema(BaseModel):
foo: str
bar: int
def _ollama_envelope(content: str) -> dict[str, object]:
"""Mimic the shape Ollama's /api/chat returns."""
return {"message": {"role": "assistant", "content": content}, "done": True}
def _patch_transport(
monkeypatch: pytest.MonkeyPatch,
handler: Any,
) -> None:
transport = httpx.MockTransport(handler)
real_client = httpx.Client # capture before patching
def _factory(*_args: object, **kwargs: object) -> httpx.Client:
# Strip any caller-provided transport kwarg; we always inject ours.
kwargs.pop("transport", None)
return real_client(transport=transport, **kwargs)
monkeypatch.setattr(llm_client_module.httpx, "Client", _factory)
def test_chat_json_returns_validated_model(monkeypatch: pytest.MonkeyPatch) -> None:
captured: dict[str, object] = {}
def _handler(request: httpx.Request) -> httpx.Response:
captured["url"] = str(request.url)
captured["body"] = request.read()
return httpx.Response(200, json=_ollama_envelope('{"foo": "x", "bar": 7}'))
_patch_transport(monkeypatch, _handler)
client = OllamaClient(base_url="http://ollama:11434", model="m", timeout_s=5)
out = client.chat_json("system msg", "user msg", _Schema)
assert out == _Schema(foo="x", bar=7)
assert captured["url"] == "http://ollama:11434/api/chat"
body = captured["body"]
assert isinstance(body, bytes)
assert b'"format":"json"' in body
assert b'"system msg"' in body
def test_chat_json_raises_on_http_error(monkeypatch: pytest.MonkeyPatch) -> None:
def _handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(500, text="boom")
_patch_transport(monkeypatch, _handler)
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
with pytest.raises(LLMUnavailableError, match="Ollama request failed"):
client.chat_json("s", "u", _Schema)
def test_chat_json_raises_on_invalid_json(monkeypatch: pytest.MonkeyPatch) -> None:
def _handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json=_ollama_envelope("this is not json"))
_patch_transport(monkeypatch, _handler)
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
with pytest.raises(LLMUnavailableError, match="schema"):
client.chat_json("s", "u", _Schema)
def test_chat_json_raises_on_missing_envelope(monkeypatch: pytest.MonkeyPatch) -> None:
def _handler(_request: httpx.Request) -> httpx.Response:
return httpx.Response(200, json={"oops": True})
_patch_transport(monkeypatch, _handler)
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
with pytest.raises(LLMUnavailableError, match=r"message\.content"):
client.chat_json("s", "u", _Schema)
def test_chat_json_raises_on_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:
def _handler(request: httpx.Request) -> httpx.Response:
raise httpx.ConnectError("nobody home", request=request)
_patch_transport(monkeypatch, _handler)
client = OllamaClient(base_url="http://x", model="m", timeout_s=1)
with pytest.raises(LLMUnavailableError):
client.chat_json("s", "u", _Schema)

View File

@@ -0,0 +1,90 @@
"""Unit tests for the hybrid LLM header extractor / merger."""
from __future__ import annotations
from datetime import date
import pytest
from pydantic import BaseModel
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
from ocr_sprint.llm.extractor import LLMHeaderResult, _merge, llm_fill_header
from ocr_sprint.schemas.extraction import HeaderFields
class _StubClient(OllamaClient):
"""Test double that bypasses HTTP entirely."""
def __init__(self, payload: LLMHeaderResult | Exception) -> None:
# Skip the real __init__ — we don't need any real config.
self._payload = payload
def chat_json( # type: ignore[override]
self, system: str, user: str, schema_cls: type[BaseModel]
) -> BaseModel:
if isinstance(self._payload, Exception):
raise self._payload
return self._payload
def test_merge_keeps_regex_when_present() -> None:
regex = HeaderFields(nomor_sprint="Sprin/123/IV/2025/Reskrim", tanggal=date(2025, 4, 21))
llm = LLMHeaderResult(nomor_sprint="HALLUCINATED", tanggal=date(1999, 1, 1), perihal="ok")
out = _merge(regex, llm)
assert out.nomor_sprint == "Sprin/123/IV/2025/Reskrim"
assert out.tanggal == date(2025, 4, 21)
# Gaps get filled.
assert out.perihal == "ok"
def test_merge_fills_gaps() -> None:
regex = HeaderFields() # all None
llm = LLMHeaderResult(
nomor_sprint="Sprin/9/IX/2024",
tanggal=date(2024, 9, 1),
satuan_penerbit="Polres Bandung",
perihal="Penyelidikan",
dasar=["UU 2/2002", "Perkap 6/2017"],
)
out = _merge(regex, llm)
assert out.nomor_sprint == "Sprin/9/IX/2024"
assert out.tanggal == date(2024, 9, 1)
assert out.satuan_penerbit == "Polres Bandung"
assert out.perihal == "Penyelidikan"
assert out.dasar == ["UU 2/2002", "Perkap 6/2017"]
def test_llm_fill_header_returns_merged_when_client_succeeds() -> None:
regex = HeaderFields(nomor_sprint="Sprin/1/I/2025") # has nomor, missing rest
stub = _StubClient(
LLMHeaderResult(
satuan_penerbit="Polres Bandung",
perihal="Penyelidikan",
dasar=["UU 2/2002"],
)
)
out = llm_fill_header(raw_text="...", regex_header=regex, client=stub)
assert out is not None
assert out.nomor_sprint == "Sprin/1/I/2025"
assert out.satuan_penerbit == "Polres Bandung"
assert out.perihal == "Penyelidikan"
assert out.dasar == ["UU 2/2002"]
def test_llm_fill_header_returns_none_when_unavailable() -> None:
stub = _StubClient(LLMUnavailableError("server down"))
out = llm_fill_header(raw_text="...", regex_header=HeaderFields(), client=stub)
assert out is None
def test_merge_does_not_overwrite_dasar_when_regex_has_it() -> None:
regex = HeaderFields(dasar=["UU 2/2002"])
llm = LLMHeaderResult(dasar=["something else", "more"])
out = _merge(regex, llm)
assert out.dasar == ["UU 2/2002"]
def test_llm_extractor_unused_argument_kept_silent() -> None:
# A trivial sanity check that the public function signature accepts
# keyword-only `client` — this matches how the orchestrator calls it.
pytest.importorskip("ocr_sprint.llm.extractor")

View File

@@ -0,0 +1,171 @@
"""Orchestrator-level tests for the Phase 5 hybrid LLM wiring.
These tests stub out the heavy stages (ingest / preprocess / OCR / table)
so we can verify the *branching* behaviour around the LLM step without
booting Paddle.
"""
from __future__ import annotations
from datetime import date
import pytest
from ocr_sprint.pipeline import orchestrator as orch_module
from ocr_sprint.pipeline.orchestrator import _header_has_gaps, run_pipeline
from ocr_sprint.schemas.document import SourceKind
from ocr_sprint.schemas.extraction import HeaderFields, ReviewFlag, Signatory
def test_header_has_gaps_detects_missing_fields() -> None:
full = HeaderFields(
nomor_sprint="Sprin/1/I/2025",
tanggal=date(2025, 1, 1),
satuan_penerbit="Polres X",
perihal="ok",
dasar=["UU 2/2002"],
)
assert _header_has_gaps(full) is False
assert _header_has_gaps(HeaderFields()) is True
assert _header_has_gaps(full.model_copy(update={"perihal": None})) is True
assert _header_has_gaps(full.model_copy(update={"dasar": []})) is True
def _stub_pipeline_stages(
monkeypatch: pytest.MonkeyPatch,
*,
raw_text: str,
regex_header: HeaderFields,
) -> None:
"""Replace ingest -> ocr -> tables with cheap fakes so the orchestrator
runs without Paddle / PyMuPDF.
"""
import numpy as np
from ocr_sprint.pipeline import ingest as ingest_module
from ocr_sprint.pipeline import ocr as ocr_module
from ocr_sprint.pipeline.ingest import IngestedPage
img = np.full((100, 100, 3), 255, dtype=np.uint8)
fake_page = IngestedPage(image=img, page_index=0)
fake_ocr_page = ocr_module.OCRPage(
lines=[
ocr_module.OCRLine(text=raw_text, confidence=0.95, box=((0, 0), (1, 0), (1, 1), (0, 1)))
],
)
monkeypatch.setattr(orch_module, "detect_source_kind", lambda _: SourceKind.PDF)
monkeypatch.setattr(orch_module, "ingest", lambda *a, **k: [fake_page])
monkeypatch.setattr(orch_module, "detect_and_correct", lambda image, _cfg: image)
monkeypatch.setattr(orch_module, "preprocess", lambda image, _cfg: image)
monkeypatch.setattr(orch_module, "run_ocr", lambda _image: fake_ocr_page)
# No tables in these tests.
monkeypatch.setattr(orch_module, "run_table_extraction", lambda _img: [])
monkeypatch.setattr(orch_module, "extract_personnel", lambda _tables: [])
# Header / signatory / validators come from the real implementation
# for `extract_header`, but we override to control gap state.
monkeypatch.setattr(orch_module, "extract_header", lambda _text: regex_header)
monkeypatch.setattr(orch_module, "find_signatory", lambda _text: Signatory())
monkeypatch.setattr(orch_module, "validate_extraction", lambda _result: [])
# Keep ingest_module referenced so import isn't dropped.
assert ingest_module is not None
def test_orchestrator_skips_llm_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("LLM_ENABLED", "false")
from ocr_sprint.config import get_settings
get_settings.cache_clear()
_stub_pipeline_stages(
monkeypatch,
raw_text="dummy",
regex_header=HeaderFields(), # all gaps
)
called = {"n": 0}
def _trip(*_args: object, **_kwargs: object) -> None:
called["n"] += 1
return None
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
result = run_pipeline(b"%PDF-1.4\n%fake")
assert called["n"] == 0
assert ReviewFlag.LLM_FALLBACK not in result.result.review_flags
assert ReviewFlag.LLM_UNAVAILABLE not in result.result.review_flags
def test_orchestrator_skips_llm_when_header_complete(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("LLM_ENABLED", "true")
from ocr_sprint.config import get_settings
get_settings.cache_clear()
_stub_pipeline_stages(
monkeypatch,
raw_text="dummy",
regex_header=HeaderFields(
nomor_sprint="Sprin/1/I/2025",
tanggal=date(2025, 1, 1),
satuan_penerbit="Polres X",
perihal="ok",
dasar=["UU 2/2002"],
),
)
called = {"n": 0}
def _trip(*_args: object, **_kwargs: object) -> None:
called["n"] += 1
return None
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
run_pipeline(b"%PDF-1.4\n%fake")
assert called["n"] == 0
def test_orchestrator_calls_llm_and_marks_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("LLM_ENABLED", "true")
from ocr_sprint.config import get_settings
get_settings.cache_clear()
regex_partial = HeaderFields(nomor_sprint="Sprin/1/I/2025") # rest missing
_stub_pipeline_stages(monkeypatch, raw_text="dummy text", regex_header=regex_partial)
def _llm(_raw: str, header: HeaderFields, **_: object) -> HeaderFields:
return header.model_copy(
update={
"satuan_penerbit": "Polres Bandung",
"perihal": "Penyelidikan",
"dasar": ["UU 2/2002"],
}
)
monkeypatch.setattr(orch_module, "llm_fill_header", _llm)
out = run_pipeline(b"%PDF-1.4\n%fake")
assert out.result.header.satuan_penerbit == "Polres Bandung"
assert out.result.header.perihal == "Penyelidikan"
assert ReviewFlag.LLM_FALLBACK in out.result.review_flags
assert ReviewFlag.LLM_UNAVAILABLE not in out.result.review_flags
def test_orchestrator_marks_unavailable_when_llm_returns_none(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setenv("LLM_ENABLED", "true")
from ocr_sprint.config import get_settings
get_settings.cache_clear()
_stub_pipeline_stages(monkeypatch, raw_text="dummy", regex_header=HeaderFields())
monkeypatch.setattr(orch_module, "llm_fill_header", lambda *_a, **_k: None)
out = run_pipeline(b"%PDF-1.4\n%fake")
assert ReviewFlag.LLM_UNAVAILABLE in out.result.review_flags
assert ReviewFlag.LLM_FALLBACK not in out.result.review_flags