diff --git a/alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py b/alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py new file mode 100644 index 0000000..420db9f --- /dev/null +++ b/alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py @@ -0,0 +1,60 @@ +"""phase6 hitl: job_corrections + approval columns + +Revision ID: 3b1f2c9a4d56 +Revises: ff8c14fbf8a0 +Create Date: 2026-04-25 14:30:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "3b1f2c9a4d56" +down_revision: str | None = "ff8c14fbf8a0" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("jobs") as batch: + batch.add_column( + sa.Column( + "approved", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ) + ) + batch.add_column(sa.Column("reviewed_by", sa.String(length=128), nullable=True)) + batch.add_column(sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True)) + + op.create_table( + "job_corrections", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("job_id", sa.Uuid(), nullable=False), + sa.Column("field_path", sa.String(length=256), nullable=False), + sa.Column("old_value", sa.JSON(), nullable=True), + sa.Column("new_value", sa.JSON(), nullable=True), + sa.Column("corrected_by", sa.String(length=128), nullable=True), + sa.Column("reason", sa.String(length=512), nullable=True), + sa.Column("corrected_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["job_id"], ["jobs.job_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_job_corrections_job_id"), + "job_corrections", + ["job_id"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_job_corrections_job_id"), table_name="job_corrections") + op.drop_table("job_corrections") + with op.batch_alter_table("jobs") as batch: + batch.drop_column("reviewed_at") + batch.drop_column("reviewed_by") + batch.drop_column("approved") diff --git a/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py b/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py index 8ffd0ab..fdfc5b0 100644 --- a/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py +++ b/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py @@ -1,17 +1,17 @@ """phase4 jobs table Revision ID: ff8c14fbf8a0 -Revises: +Revises: Create Date: 2026-04-25 15:54:18.579147 """ + from collections.abc import Sequence -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'ff8c14fbf8a0' +revision: str = "ff8c14fbf8a0" down_revision: str | None = None branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -19,24 +19,25 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('jobs', - sa.Column('job_id', sa.Uuid(), nullable=False), - sa.Column('status', sa.String(length=32), nullable=False), - sa.Column('source_kind', sa.String(length=16), nullable=False), - sa.Column('filename', sa.String(length=512), nullable=False), - sa.Column('blob_key', sa.String(length=512), nullable=True), - sa.Column('confidence', sa.Float(), nullable=True), - sa.Column('review_flags', sa.JSON(), nullable=False), - sa.Column('result', sa.JSON(), nullable=True), - sa.Column('error', sa.String(length=2048), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.PrimaryKeyConstraint('job_id') + op.create_table( + "jobs", + sa.Column("job_id", sa.Uuid(), nullable=False), + sa.Column("status", sa.String(length=32), nullable=False), + sa.Column("source_kind", sa.String(length=16), nullable=False), + sa.Column("filename", sa.String(length=512), nullable=False), + sa.Column("blob_key", sa.String(length=512), nullable=True), + sa.Column("confidence", sa.Float(), nullable=True), + sa.Column("review_flags", sa.JSON(), nullable=False), + sa.Column("result", sa.JSON(), nullable=True), + sa.Column("error", sa.String(length=2048), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("job_id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('jobs') + op.drop_table("jobs") # ### end Alembic commands ### diff --git a/src/ocr_sprint/api/routes/documents.py b/src/ocr_sprint/api/routes/documents.py index 018d00c..195b4dc 100644 --- a/src/ocr_sprint/api/routes/documents.py +++ b/src/ocr_sprint/api/routes/documents.py @@ -22,7 +22,17 @@ from __future__ import annotations from typing import Annotated from uuid import UUID, uuid4 -from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status +from fastapi import ( + APIRouter, + Depends, + File, + Header, + HTTPException, + Query, + Response, + UploadFile, + status, +) from sqlalchemy.orm import Session from ocr_sprint.api.deps.auth import require_api_key @@ -31,11 +41,22 @@ from ocr_sprint.api.errors import UnsupportedDocumentError from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS from ocr_sprint.config import get_settings from ocr_sprint.db.base import session_scope -from ocr_sprint.db.repositories import JobNotFoundError, JobRepository +from ocr_sprint.db.repositories import ( + InvalidFieldPathError, + JobAlreadyApprovedError, + JobNotCompletedError, + JobNotFoundError, + JobRepository, +) from ocr_sprint.pipeline.ingest import detect_source_kind from ocr_sprint.pipeline.orchestrator import run_pipeline from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus from ocr_sprint.schemas.extraction import ExtractionResult +from ocr_sprint.schemas.review import ( + ApprovalResponse, + CorrectionEventResponse, + CorrectionRequest, +) from ocr_sprint.storage.blob import get_blob_storage from ocr_sprint.utils.logging import get_logger @@ -75,6 +96,9 @@ def _row_to_response(row: object) -> DocumentResponse: data=result_obj, review_flags=list(row.review_flags or []), error=row.error, + approved=bool(row.approved), + reviewed_by=row.reviewed_by, + reviewed_at=row.reviewed_at, ) @@ -192,3 +216,116 @@ async def get_document( except JobNotFoundError as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc return _row_to_response(row) + + +# ---------- Phase 6 — HITL ---------- + + +def _correction_row_to_response(row: object) -> CorrectionEventResponse: + # Local import to avoid a cyclic import at module load time. + from ocr_sprint.db.models import JobCorrectionRow + + assert isinstance(row, JobCorrectionRow) + return CorrectionEventResponse( + id=row.id, + job_id=row.job_id, + field_path=row.field_path, + old_value=row.old_value, + new_value=row.new_value, + corrected_by=row.corrected_by, + reason=row.reason, + corrected_at=row.corrected_at, + ) + + +@router.patch( + "/{job_id}", + response_model=DocumentResponse, +) +async def patch_document( + job_id: UUID, + body: CorrectionRequest, + session: Annotated[Session, Depends(get_session)], + x_user_id: Annotated[ + str | None, + Header(description="Free-form reviewer identifier recorded on the audit row."), + ] = None, +) -> DocumentResponse: + """Apply one or more field-level corrections and record an audit trail. + + The whole batch is applied atomically — if any path is invalid the + request fails with 400 and no side effects are written. Returns the + updated document so the client doesn't need a follow-up GET. + """ + repo = JobRepository(session) + try: + repo.apply_corrections( + job_id, + corrections=[(c.path, c.value, c.reason) for c in body.corrections], + corrected_by=x_user_id, + ) + row = repo.get_or_raise(job_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + except InvalidFieldPathError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + except JobAlreadyApprovedError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + except JobNotCompletedError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + + _logger.info( + "documents.patched", + job_id=str(job_id), + count=len(body.corrections), + corrected_by=x_user_id or "", + ) + return _row_to_response(row) + + +@router.get( + "/{job_id}/history", + response_model=list[CorrectionEventResponse], +) +async def get_history( + job_id: UUID, + session: Annotated[Session, Depends(get_session)], +) -> list[CorrectionEventResponse]: + repo = JobRepository(session) + try: + rows = repo.list_corrections(job_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + return [_correction_row_to_response(r) for r in rows] + + +@router.post( + "/{job_id}/approve", + response_model=ApprovalResponse, +) +async def approve_document( + job_id: UUID, + session: Annotated[Session, Depends(get_session)], + x_user_id: Annotated[ + str | None, + Header(description="Free-form reviewer identifier recorded on the job."), + ] = None, +) -> ApprovalResponse: + """Lock a job's final version. Idempotent: re-approving returns the + existing row without overwriting ``reviewed_at``. + """ + repo = JobRepository(session) + try: + row = repo.approve(job_id, reviewed_by=x_user_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + except JobNotCompletedError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + + _logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "") + return ApprovalResponse( + job_id=row.job_id, + approved=bool(row.approved), + reviewed_by=row.reviewed_by, + reviewed_at=row.reviewed_at, + ) diff --git a/src/ocr_sprint/db/models.py b/src/ocr_sprint/db/models.py index 0f36202..f5b3f82 100644 --- a/src/ocr_sprint/db/models.py +++ b/src/ocr_sprint/db/models.py @@ -16,7 +16,7 @@ from datetime import datetime, timezone from typing import Any from uuid import UUID, uuid4 -from sqlalchemy import JSON, DateTime, Float, String, Uuid +from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, Integer, String, Uuid from sqlalchemy.orm import Mapped, mapped_column from ocr_sprint.db.base import Base @@ -42,6 +42,15 @@ class JobRow(Base): result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) error: Mapped[str | None] = mapped_column(String(2048), nullable=True) + # Phase 6 — HITL review state. + # Once ``approved=True`` the row is immutable except to admin users; + # corrections after that point are rejected by the route. ``reviewed_by`` + # stores the free-form user identifier the reviewer sent via the + # ``X-User-Id`` header (best-effort attribution — no full RBAC yet). + approved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + reviewed_by: Mapped[str | None] = mapped_column(String(128), nullable=True) + reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, default=_utcnow ) @@ -51,3 +60,37 @@ class JobRow(Base): def __repr__(self) -> str: return f"JobRow(job_id={self.job_id!s}, status={self.status!r})" + + +class JobCorrectionRow(Base): + """One correction event on a job's ``result``. + + Each PATCH call writes one row per changed field path so we have a + full audit trail. Rows are append-only — never updated, never + deleted — so the history is reproducible and usable as ground-truth + data for future fine-tuning. + """ + + __tablename__ = "job_corrections" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + job_id: Mapped[UUID] = mapped_column( + Uuid, ForeignKey("jobs.job_id", ondelete="CASCADE"), nullable=False, index=True + ) + # Dotted JSON path into ExtractionResult, e.g. "header.nomor_sprint" or + # "personel[3].nrp". Kept as a plain string for simplicity — we don't + # parse it server-side beyond the allow-list check in the repository. + field_path: Mapped[str] = mapped_column(String(256), nullable=False) + old_value: Mapped[Any | None] = mapped_column(JSON, nullable=True) + new_value: Mapped[Any | None] = mapped_column(JSON, nullable=True) + corrected_by: Mapped[str | None] = mapped_column(String(128), nullable=True) + reason: Mapped[str | None] = mapped_column(String(512), nullable=True) + corrected_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, default=_utcnow + ) + + def __repr__(self) -> str: + return ( + f"JobCorrectionRow(job_id={self.job_id!s}, " + f"field={self.field_path!r}, by={self.corrected_by!r})" + ) diff --git a/src/ocr_sprint/db/repositories.py b/src/ocr_sprint/db/repositories.py index a17817a..248bd18 100644 --- a/src/ocr_sprint/db/repositories.py +++ b/src/ocr_sprint/db/repositories.py @@ -6,6 +6,9 @@ know about sessions, transactions, or the row → schema mapping. from __future__ import annotations +import copy +import re +from dataclasses import dataclass from datetime import datetime, timezone from typing import Any from uuid import UUID @@ -13,7 +16,7 @@ from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import Session -from ocr_sprint.db.models import JobRow +from ocr_sprint.db.models import JobCorrectionRow, JobRow from ocr_sprint.schemas.document import DocumentStatus, SourceKind @@ -25,6 +28,110 @@ class JobNotFoundError(LookupError): """Raised by API code when GET /documents/{id} hits a missing row.""" +class InvalidFieldPathError(ValueError): + """Raised when a PATCH request references an unsupported field path.""" + + +class JobAlreadyApprovedError(RuntimeError): + """Raised when a PATCH is attempted against an already-approved job.""" + + +class JobNotCompletedError(RuntimeError): + """Raised when a PATCH/approve is attempted against a job that hasn't + produced a ``result`` payload yet (e.g. still pending or failed). + """ + + +@dataclass(frozen=True) +class AppliedCorrection: + """Internal record of a correction that successfully applied to the + in-memory ``result`` dict. The repository turns this into a persisted + ``JobCorrectionRow`` after the whole batch is validated. + """ + + field_path: str + old_value: Any + new_value: Any + reason: str | None + + +# Allow-list of top-level keys we let reviewers edit. Keeps the attack +# surface small: they can't inject arbitrary fields into the JSON blob. +_ALLOWED_ROOTS: frozenset[str] = frozenset({"header", "ttd", "personel", "untuk"}) + +# Matches a single path segment like ``personel[3]`` — supports at most one +# index per segment, enough for the list fields we care about. +_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$") + + +def _split_path(path: str) -> list[tuple[str, int | None]]: + """Parse ``header.nomor_sprint`` or ``personel[2].nrp`` into segments. + + Returns list of ``(name, index_or_none)`` tuples. Raises + ``InvalidFieldPathError`` on malformed input so the caller can surface + a 400 to the client. + """ + if not path or path.startswith(".") or path.endswith("."): + raise InvalidFieldPathError(f"Invalid field path: {path!r}") + + parts = path.split(".") + out: list[tuple[str, int | None]] = [] + for part in parts: + match = _SEGMENT_RE.match(part) + if match is None: + raise InvalidFieldPathError(f"Invalid segment in path: {part!r}") + name = match.group(1) + idx_raw = match.group(2) + idx = int(idx_raw) if idx_raw is not None else None + out.append((name, idx)) + + if out[0][0] not in _ALLOWED_ROOTS: + raise InvalidFieldPathError( + f"Field path root {out[0][0]!r} not in allowed roots {sorted(_ALLOWED_ROOTS)!r}" + ) + return out + + +def _apply_path(data: dict[str, Any], path: str, new_value: Any) -> Any: + """Apply a single correction to ``data`` in place. Returns the old + value so the caller can record it in the audit row. + + Does NOT validate that the new value matches the field's expected + type — that's the reviewer's responsibility; the whole point of HITL + is to let humans override the model's typing. + """ + segments = _split_path(path) + cursor: Any = data + for name, idx in segments[:-1]: + if not isinstance(cursor, dict) or name not in cursor: + raise InvalidFieldPathError(f"Cannot traverse to {path!r}: missing {name!r}") + cursor = cursor[name] + if idx is not None: + if not isinstance(cursor, list) or idx >= len(cursor): + raise InvalidFieldPathError( + f"Cannot traverse to {path!r}: index [{idx}] out of range" + ) + cursor = cursor[idx] + + name, idx = segments[-1] + if idx is not None: + # Terminal segment is a list-element, e.g. ``untuk[2]``. + if not isinstance(cursor, dict) or name not in cursor: + raise InvalidFieldPathError(f"Cannot apply to {path!r}: missing container {name!r}") + container = cursor[name] + if not isinstance(container, list) or idx >= len(container): + raise InvalidFieldPathError(f"Cannot apply to {path!r}: index [{idx}] out of range") + old = container[idx] + container[idx] = new_value + return old + + if not isinstance(cursor, dict): + raise InvalidFieldPathError(f"Cannot apply to {path!r}: parent is not an object") + old = cursor.get(name) + cursor[name] = new_value + return old + + class JobRepository: """SQL-backed repository for `jobs` rows.""" @@ -94,3 +201,139 @@ class JobRepository: if row is None: raise JobNotFoundError(f"Job not found: {job_id}") return row + + # ---------- Phase 6 — HITL ---------- + + def apply_corrections( + self, + job_id: UUID, + *, + corrections: list[tuple[str, Any, str | None]], + corrected_by: str | None, + ) -> list[JobCorrectionRow]: + """Apply a batch of field corrections atomically. + + ``corrections`` is a list of ``(path, new_value, reason)`` tuples. + Returns the persisted audit rows so the caller can surface them in + the response. + + Raises + ------ + JobNotFoundError + If the row doesn't exist. + JobNotCompletedError + If the job hasn't produced a result yet (status pending / + processing / failed). + JobAlreadyApprovedError + If the job has been approved — edits are locked. + InvalidFieldPathError + If any path is malformed or references a disallowed root. + """ + row = self._get_or_raise(job_id) + if row.result is None: + raise JobNotCompletedError( + f"Job {job_id} has no result to correct (status={row.status})" + ) + if row.approved: + raise JobAlreadyApprovedError(f"Job {job_id} is already approved; edits are locked") + + # Deep-copy so we can roll back in memory if any correction fails. + # The underlying JSON column will only be re-assigned once every + # path applied cleanly. + working = copy.deepcopy(row.result) + applied: list[AppliedCorrection] = [] + for path, new_value, reason in corrections: + old_value = _apply_path(working, path, new_value) + applied.append( + AppliedCorrection( + field_path=path, old_value=old_value, new_value=new_value, reason=reason + ) + ) + + # Persist audit rows first; if they fail the session rollback also + # undoes the result-column update we're about to do. + persisted: list[JobCorrectionRow] = [] + for event in applied: + row_event = JobCorrectionRow( + job_id=job_id, + field_path=event.field_path, + old_value=event.old_value, + new_value=event.new_value, + corrected_by=corrected_by, + reason=event.reason, + ) + self.session.add(row_event) + persisted.append(row_event) + + # Clear review flags that the correction has resolved. Right now we + # only auto-clear MISSING_FIELD when any corrected field previously + # held a null/empty value — the reviewer explicitly filled a gap. + row.result = working + row.review_flags = _recompute_flags( + original_flags=list(row.review_flags or []), + applied=applied, + working_result=working, + ) + row.updated_at = _utcnow() + + self.session.flush() + return persisted + + def list_corrections(self, job_id: UUID) -> list[JobCorrectionRow]: + """Return the full audit trail for ``job_id`` in chronological order.""" + # ``get_or_raise`` so callers get a 404 instead of an empty list + # when the job itself doesn't exist. + self._get_or_raise(job_id) + stmt = ( + select(JobCorrectionRow) + .where(JobCorrectionRow.job_id == job_id) + .order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id) + ) + return list(self.session.scalars(stmt)) + + def approve(self, job_id: UUID, *, reviewed_by: str | None) -> JobRow: + """Mark a job as approved. Idempotent — re-approving is a no-op + that keeps the original ``reviewed_at`` (so the audit trail stays + intact). + """ + row = self._get_or_raise(job_id) + if row.result is None: + raise JobNotCompletedError( + f"Job {job_id} has no result to approve (status={row.status})" + ) + if row.approved: + return row + row.approved = True + row.reviewed_by = reviewed_by + row.reviewed_at = _utcnow() + row.updated_at = row.reviewed_at + return row + + +def _recompute_flags( + *, + original_flags: list[str], + applied: list[AppliedCorrection], + working_result: dict[str, Any], +) -> list[str]: + """Update review flags in light of the corrections just applied. + + Keeps the policy simple on purpose: + * ``missing_field`` is removed if after the edit every required + header field is non-empty. + * Other flags stay untouched — the reviewer should either correct the + underlying issue (which this helper can detect) or explicitly + approve the result as-is (which bypasses the flag list). + """ + flags = list(original_flags) + if "missing_field" in flags: + header = working_result.get("header") or {} + filled = all(bool(header.get(key)) for key in ("nomor_sprint", "satuan_penerbit")) + if filled: + flags = [f for f in flags if f != "missing_field"] + + # ``applied`` isn't used directly in this MVP rule, but we keep the + # parameter so future policies can inspect exactly what changed + # without re-diffing the blob. + _ = applied + return flags diff --git a/src/ocr_sprint/llm/__init__.py b/src/ocr_sprint/llm/__init__.py new file mode 100644 index 0000000..9aa1df5 --- /dev/null +++ b/src/ocr_sprint/llm/__init__.py @@ -0,0 +1,18 @@ +"""LLM-based extraction (Phase 5). + +The hybrid extractor first runs the deterministic regex layer and then — +only for fields that came back missing or low-confidence — calls a local +Ollama model with a Pydantic-typed prompt. Everything is gated by +``LLM_ENABLED``; if the flag is off or the Ollama server is unreachable, +the pipeline degrades gracefully back to the regex result. +""" + +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient +from ocr_sprint.llm.extractor import LLMHeaderResult, llm_fill_header + +__all__ = [ + "LLMHeaderResult", + "LLMUnavailableError", + "OllamaClient", + "llm_fill_header", +] diff --git a/src/ocr_sprint/llm/client.py b/src/ocr_sprint/llm/client.py new file mode 100644 index 0000000..d5dd1a8 --- /dev/null +++ b/src/ocr_sprint/llm/client.py @@ -0,0 +1,97 @@ +"""Ollama HTTP client. + +We deliberately avoid the ``ollama`` Python package — the wire format is a +single ``POST /api/chat`` with ``format="json"`` and a system + user message, +so a small ``httpx`` wrapper is enough. This keeps the runtime dependency +footprint smaller and makes the mock-based unit tests trivial. +""" + +from __future__ import annotations + +from typing import TypeVar + +import httpx +from pydantic import BaseModel, ValidationError + +from ocr_sprint.config import get_settings +from ocr_sprint.utils.logging import get_logger + +_logger = get_logger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class LLMUnavailableError(RuntimeError): + """Raised when the Ollama server is unreachable, times out, or returns + a malformed payload. The pipeline catches this and falls back to the + regex-only result with a ``llm_fallback`` review flag. + """ + + +class OllamaClient: + """Tiny synchronous HTTP wrapper around the Ollama ``/api/chat`` endpoint. + + Parameters + ---------- + base_url: + Ollama server URL, e.g. ``http://localhost:11434``. + model: + Model tag to invoke (default ``qwen2.5:1.5b`` — chosen for CPU + latency at acceptable accuracy). + timeout_s: + Hard wall-clock timeout for a single request. + """ + + def __init__( + self, + base_url: str | None = None, + model: str | None = None, + timeout_s: int | None = None, + ) -> None: + s = get_settings() + self.base_url = (base_url or s.llm_base_url).rstrip("/") + self.model = model or s.llm_model + self.timeout_s = timeout_s if timeout_s is not None else s.llm_timeout_s + + # ---------- public API ---------- + + def chat_json(self, system: str, user: str, schema_cls: type[T]) -> T: + """Run a single chat completion in JSON mode and validate the + response against ``schema_cls``. Raises ``LLMUnavailableError`` on + any transport / parse / validation failure so callers only have one + exception to handle. + """ + payload = { + "model": self.model, + "stream": False, + "format": "json", + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + # Keep determinism reasonable — we want extraction, not creativity. + "options": {"temperature": 0.0, "num_ctx": 4096}, + } + url = f"{self.base_url}/api/chat" + + try: + with httpx.Client(timeout=self.timeout_s) as client: + response = client.post(url, json=payload) + response.raise_for_status() + data = response.json() + except (httpx.HTTPError, ValueError) as exc: + _logger.warning("llm.transport_error", url=url, error=str(exc)) + raise LLMUnavailableError(f"Ollama request failed: {exc}") from exc + + # Ollama returns {"message": {"role": "assistant", "content": ""}}. + try: + content = data["message"]["content"] + except (KeyError, TypeError) as exc: + _logger.warning("llm.bad_envelope", payload=data) + raise LLMUnavailableError(f"Ollama response missing message.content: {data!r}") from exc + + try: + return schema_cls.model_validate_json(content) + except ValidationError as exc: + _logger.warning("llm.validation_error", error=str(exc), content=content[:400]) + raise LLMUnavailableError(f"LLM JSON failed schema: {exc}") from exc diff --git a/src/ocr_sprint/llm/extractor.py b/src/ocr_sprint/llm/extractor.py new file mode 100644 index 0000000..31f58ca --- /dev/null +++ b/src/ocr_sprint/llm/extractor.py @@ -0,0 +1,84 @@ +"""High-level LLM extractor. + +The job is *narrow*: take the raw OCR text plus the partial header that +came back from the regex layer, and return an LLM-derived header that the +caller can merge in. We never let the LLM populate the personnel table — +PP-Structure is more accurate and cheaper for that. +""" + +from __future__ import annotations + +from datetime import date + +from pydantic import BaseModel, Field + +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient +from ocr_sprint.llm.prompts import SYSTEM_HEADER, build_user_prompt +from ocr_sprint.schemas.extraction import HeaderFields +from ocr_sprint.utils.logging import get_logger + +_logger = get_logger(__name__) + + +class LLMHeaderResult(BaseModel): + """Schema we ask the model to fill. Mirrors ``HeaderFields`` but is + intentionally separate so we control exactly what the prompt and + validation surface look like — the public ``HeaderFields`` may grow + fields later that we don't want the LLM touching. + """ + + nomor_sprint: str | None = None + tanggal: date | None = None + satuan_penerbit: str | None = None + perihal: str | None = None + dasar: list[str] = Field(default_factory=list) + + +def llm_fill_header( + raw_text: str, + regex_header: HeaderFields, + *, + client: OllamaClient | None = None, +) -> HeaderFields | None: + """Run the LLM extractor and return a *merged* HeaderFields. + + Returns ``None`` if the model is unavailable so the caller can decide + what to do (typically: keep the regex result and emit a fallback + review flag). + """ + client = client or OllamaClient() + + user = build_user_prompt( + raw_text=raw_text, + regex_partial=regex_header.model_dump(mode="json"), + ) + + try: + llm = client.chat_json(SYSTEM_HEADER, user, LLMHeaderResult) + except LLMUnavailableError as exc: + _logger.warning("llm.unavailable", error=str(exc)) + return None + + return _merge(regex_header, llm) + + +def _merge(regex: HeaderFields, llm: LLMHeaderResult) -> HeaderFields: + """Merge LLM output into the regex result. + + Policy: regex wins for any field it already filled. The LLM only fills + the *gaps*. This keeps deterministic / verifiable extractions for the + fields where regex is reliable and prevents the LLM from "correcting" + a value that happens to look unusual but is in fact correct. + """ + merged = regex.model_copy(deep=True) + if merged.nomor_sprint is None and llm.nomor_sprint: + merged.nomor_sprint = llm.nomor_sprint + if merged.tanggal is None and llm.tanggal is not None: + merged.tanggal = llm.tanggal + if not merged.satuan_penerbit and llm.satuan_penerbit: + merged.satuan_penerbit = llm.satuan_penerbit + if not merged.perihal and llm.perihal: + merged.perihal = llm.perihal + if not merged.dasar and llm.dasar: + merged.dasar = list(llm.dasar) + return merged diff --git a/src/ocr_sprint/llm/prompts.py b/src/ocr_sprint/llm/prompts.py new file mode 100644 index 0000000..b74081f --- /dev/null +++ b/src/ocr_sprint/llm/prompts.py @@ -0,0 +1,48 @@ +"""Prompt builders for the LLM extractor. + +Kept in their own module so the prompts can be edited / version-tracked +without touching the orchestration logic. We build prompts in Indonesian +because the source documents are too — the model performs better when the +field labels in the prompt match the OCR text it's being asked about. +""" + +from __future__ import annotations + +SYSTEM_HEADER = ( + "Anda adalah asisten ekstraksi data untuk dokumen Surat Perintah (Sprint) " + "Kepolisian Republik Indonesia (POLRI). Pengguna akan memberikan teks hasil " + "OCR sebuah surat sprint, dan Anda harus mengembalikan JSON yang sesuai " + "dengan skema yang diberikan.\n\n" + "Aturan keras:\n" + "1. Jangan mengarang. Jika sebuah field tidak terlihat di teks, kembalikan null.\n" + "2. Jangan menerjemahkan field. Output harus identik ejaannya dengan teks " + "sumber (kecuali normalisasi spasi/kapitalisasi yang jelas hasil OCR error).\n" + "3. Tanggal: kembalikan format ISO YYYY-MM-DD jika tanggal terlihat, " + "selain itu null.\n" + "4. Dasar hukum: array string berisi tiap butir, urut sesuai teks.\n" + "5. Jangan menambahkan field apa pun di luar skema. Output WAJIB JSON valid." +) + + +def build_user_prompt(raw_text: str, regex_partial: dict[str, object]) -> str: + """Construct the user message: OCR text + a hint about which fields the + deterministic regex layer already filled. Telling the LLM what we + *already have* keeps it from "creatively" overwriting good values. + """ + known_fields = "\n".join(f" - {k}: {v!r}" for k, v in sorted(regex_partial.items()) if v) + known_block = ( + f"\nField yang sudah berhasil diekstrak dengan regex:\n{known_fields}\n" + if known_fields + else "" + ) + + return ( + "Teks OCR:\n" + "----------\n" + f"{raw_text}\n" + "----------\n" + f"{known_block}" + "Tugas: kembalikan JSON dengan field nomor_sprint, tanggal (ISO date | null), " + "satuan_penerbit, perihal, dasar (array string). Hanya field yang terlihat — " + "yang tidak ada di teks isi null (atau array kosong untuk dasar)." + ) diff --git a/src/ocr_sprint/pipeline/orchestrator.py b/src/ocr_sprint/pipeline/orchestrator.py index f42e810..231aec1 100644 --- a/src/ocr_sprint/pipeline/orchestrator.py +++ b/src/ocr_sprint/pipeline/orchestrator.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import dataclass from ocr_sprint.config import get_settings +from ocr_sprint.llm.extractor import llm_fill_header from ocr_sprint.pipeline.confidence import compute_confidence, route from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct from ocr_sprint.pipeline.extract.personnel import extract_personnel @@ -35,6 +36,18 @@ _logger = get_logger(__name__) _OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80 +def _header_has_gaps(header: object) -> bool: + """True if any header field worth asking the LLM about is missing. + + Using ``getattr`` so this stays decoupled from the exact attribute + names; the schema change cost was too large last time we hard-coded. + """ + for field in ("nomor_sprint", "tanggal", "satuan_penerbit", "perihal"): + if not getattr(header, field, None): + return True + return not getattr(header, "dasar", None) + + @dataclass class PipelineOutput: """Bundle returned by the orchestrator.""" @@ -84,6 +97,20 @@ def run_pipeline(content: bytes) -> PipelineOutput: header = extract_header(full_text) ttd = find_signatory(full_text) + # Phase 5 — hybrid extraction. The regex layer is deterministic but + # brittle to layout variants between satuan; if any header field is + # still missing we ask the local LLM to fill the gaps. The merger + # never lets the LLM overwrite a field that regex already captured. + llm_flags: list[ReviewFlag] = [] + if s.llm_enabled and _header_has_gaps(header): + merged = llm_fill_header(full_text, header) + if merged is None: + llm_flags.append(ReviewFlag.LLM_UNAVAILABLE) + else: + if merged.model_dump() != header.model_dump(): + llm_flags.append(ReviewFlag.LLM_FALLBACK) + header = merged + personel: list[PersonnelEntry] = [] if s.tables_enabled and cleaned_pages: all_tables: list[DetectedTable] = [] @@ -99,7 +126,7 @@ def run_pipeline(content: bytes) -> PipelineOutput: personel_rows=len(personel), ) - initial_flags: list[ReviewFlag] = [] + initial_flags: list[ReviewFlag] = list(llm_flags) if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD: initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE) diff --git a/src/ocr_sprint/schemas/document.py b/src/ocr_sprint/schemas/document.py index c59b8b7..3269539 100644 --- a/src/ocr_sprint/schemas/document.py +++ b/src/ocr_sprint/schemas/document.py @@ -55,3 +55,7 @@ class DocumentResponse(BaseModel): data: ExtractionResult | None = None review_flags: list[str] = Field(default_factory=list) error: str | None = None + # Phase 6 — HITL review state. + approved: bool = False + reviewed_by: str | None = None + reviewed_at: datetime | None = None diff --git a/src/ocr_sprint/schemas/extraction.py b/src/ocr_sprint/schemas/extraction.py index 1311faa..5a3cdb0 100644 --- a/src/ocr_sprint/schemas/extraction.py +++ b/src/ocr_sprint/schemas/extraction.py @@ -19,6 +19,8 @@ class ReviewFlag(str, Enum): UNKNOWN_PANGKAT = "unknown_pangkat" PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch" DATE_PARSE_FAILED = "date_parse_failed" + LLM_FALLBACK = "llm_fallback" + LLM_UNAVAILABLE = "llm_unavailable" class Signatory(BaseModel): diff --git a/src/ocr_sprint/schemas/review.py b/src/ocr_sprint/schemas/review.py new file mode 100644 index 0000000..5031665 --- /dev/null +++ b/src/ocr_sprint/schemas/review.py @@ -0,0 +1,62 @@ +"""Request / response schemas for the HITL review endpoints (Phase 6). + +The API surface is deliberately small: + +* ``CorrectionRequest`` — body of ``PATCH /documents/{id}``. A list of + ``FieldCorrection`` entries; each one is applied atomically (all-or- + nothing) and recorded in the audit trail. +* ``CorrectionEventResponse`` — single row in ``GET /documents/{id}/history``. +* ``ApprovalResponse`` — echo back after ``POST /documents/{id}/approve``. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, Field + + +class FieldCorrection(BaseModel): + """One field-level correction. + + ``path`` is a dotted JSON path into ``ExtractionResult``. Supported + roots: ``header``, ``ttd``, ``personel[n]`` (n is a 0-based index), + ``untuk``. The path is validated by the repository before being + applied; unknown roots return 400. + """ + + path: str = Field(..., description="Dotted JSON path, e.g. 'header.nomor_sprint'.") + value: Any = Field(..., description="New value (any JSON-serialisable payload).") + reason: str | None = Field( + None, max_length=512, description="Optional free-form reason for the correction." + ) + + +class CorrectionRequest(BaseModel): + """PATCH body — one or more field corrections, applied atomically.""" + + corrections: list[FieldCorrection] = Field(..., min_length=1) + + +class CorrectionEventResponse(BaseModel): + """One row of the audit log surfaced by GET /history.""" + + id: int + job_id: UUID + field_path: str + old_value: Any | None = None + new_value: Any | None = None + corrected_by: str | None = None + reason: str | None = None + corrected_at: datetime + + +class ApprovalResponse(BaseModel): + """Echo returned after a job is approved.""" + + job_id: UUID + approved: bool + reviewed_by: str | None = None + reviewed_at: datetime | None = None diff --git a/tests/unit/test_api_hitl.py b/tests/unit/test_api_hitl.py new file mode 100644 index 0000000..e5781af --- /dev/null +++ b/tests/unit/test_api_hitl.py @@ -0,0 +1,248 @@ +"""End-to-end HTTP tests for the HITL endpoints. + +We re-use the ``fake_pipeline`` style from ``test_api.py`` so we don't pay +the PaddleOCR init cost; the orchestrator is monkey-patched to return a +synthetic ``ExtractionResult``. +""" + +from __future__ import annotations + +from datetime import date + +import pytest +from fastapi.testclient import TestClient + +from ocr_sprint.main import create_app +from ocr_sprint.pipeline import orchestrator as orch_module +from ocr_sprint.pipeline.orchestrator import PipelineOutput +from ocr_sprint.schemas.document import DocumentStatus, SourceKind +from ocr_sprint.schemas.extraction import ( + ExtractionResult, + HeaderFields, + PersonnelEntry, + ReviewFlag, +) + + +@pytest.fixture +def client() -> TestClient: + return TestClient(create_app()) + + +@pytest.fixture +def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput: + result = ExtractionResult( + header=HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="POLRES TEST", + perihal=None, # intentional gap so a PATCH can fill it + ), + personel=[ + PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"), + ], + review_flags=[ReviewFlag.MISSING_FIELD], + confidence=0.7, + ) + output = PipelineOutput( + source_kind=SourceKind.PDF, + status=DocumentStatus.NEEDS_REVIEW, + confidence=0.7, + result=result, + ) + + def _fake_run(_content: bytes) -> PipelineOutput: + return output + + monkeypatch.setattr(orch_module, "run_pipeline", _fake_run) + from ocr_sprint.api.routes import documents as docs_module + + monkeypatch.setattr(docs_module, "run_pipeline", _fake_run) + from ocr_sprint.worker import tasks as tasks_module + + monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run) + return output + + +def _create_job(client: TestClient) -> str: + post = client.post( + "/api/v1/documents?sync=true", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 200, post.text + body = post.json() + assert body["status"] == "needs_review" + return str(body["job_id"]) + + +def test_patch_applies_correction_and_clears_missing_field( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + patched = client.patch( + f"/api/v1/documents/{job_id}", + json={ + "corrections": [ + { + "path": "header.perihal", + "value": "Penyelidikan kasus X", + "reason": "LLM missed it", + } + ] + }, + headers={"X-User-Id": "reviewer-a"}, + ) + assert patched.status_code == 200, patched.text + body = patched.json() + assert body["data"]["header"]["perihal"] == "Penyelidikan kasus X" + # The fake pipeline has both required header fields filled, so the + # ``missing_field`` flag is auto-cleared as soon as any correction + # lands (the policy re-evaluates required-field coverage on every + # edit). + assert "missing_field" not in body["review_flags"] + + +def test_patch_returns_400_for_unknown_path( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + resp = client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "bogus.field", "value": "x"}]}, + ) + assert resp.status_code == 400 + + +def test_patch_is_atomic_on_partial_failure( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + resp = client.patch( + f"/api/v1/documents/{job_id}", + json={ + "corrections": [ + {"path": "header.perihal", "value": "OK"}, + {"path": "bogus.root", "value": "X"}, + ] + }, + ) + assert resp.status_code == 400 + + # The first correction must not have persisted. + got = client.get(f"/api/v1/documents/{job_id}") + assert got.json()["data"]["header"]["perihal"] is None + + +def test_history_returns_corrections_in_order( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "header.perihal", "value": "first"}]}, + headers={"X-User-Id": "reviewer-a"}, + ) + client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "header.perihal", "value": "second"}]}, + headers={"X-User-Id": "reviewer-b"}, + ) + + history = client.get(f"/api/v1/documents/{job_id}/history") + assert history.status_code == 200 + events = history.json() + assert [e["new_value"] for e in events] == ["first", "second"] + assert [e["corrected_by"] for e in events] == ["reviewer-a", "reviewer-b"] + # old_value of the second event should reflect the first edit. + assert events[1]["old_value"] == "first" + + +def test_history_returns_empty_list_for_untouched_job( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + history = client.get(f"/api/v1/documents/{job_id}/history") + assert history.status_code == 200 + assert history.json() == [] + + +def test_history_returns_404_for_unknown_job(client: TestClient) -> None: + resp = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000/history") + assert resp.status_code == 404 + + +def test_approve_locks_subsequent_patches( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + approved = client.post( + f"/api/v1/documents/{job_id}/approve", + headers={"X-User-Id": "reviewer-a"}, + ) + assert approved.status_code == 200, approved.text + body = approved.json() + assert body["approved"] is True + assert body["reviewed_by"] == "reviewer-a" + assert body["reviewed_at"] # non-empty timestamp + + # GET reflects the approval state. + got = client.get(f"/api/v1/documents/{job_id}").json() + assert got["approved"] is True + + # PATCH after approve must be rejected with 409. + patched = client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "header.perihal", "value": "X"}]}, + ) + assert patched.status_code == 409 + + +def test_approve_is_idempotent( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + first = client.post( + f"/api/v1/documents/{job_id}/approve", + headers={"X-User-Id": "reviewer-a"}, + ) + second = client.post( + f"/api/v1/documents/{job_id}/approve", + headers={"X-User-Id": "reviewer-b"}, + ) + assert first.status_code == 200 + assert second.status_code == 200 + # Second approve must NOT change the attribution. (SQLite drops tzinfo + # on roundtrip, which changes Pydantic's serialization between the two + # calls; compare the naive components.) + assert second.json()["reviewed_by"] == "reviewer-a" + assert ( + second.json()["reviewed_at"].rstrip("Z").split("+")[0] + == (first.json()["reviewed_at"].rstrip("Z").split("+")[0]) + ) + + +def test_patch_requires_at_least_one_correction( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + resp = client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": []}, + ) + assert resp.status_code == 422 # Pydantic min_length=1 violation + + +def test_patch_missing_job_returns_404(client: TestClient) -> None: + resp = client.patch( + "/api/v1/documents/00000000-0000-0000-0000-000000000000", + json={"corrections": [{"path": "header.perihal", "value": "X"}]}, + ) + assert resp.status_code == 404 diff --git a/tests/unit/test_db_hitl.py b/tests/unit/test_db_hitl.py new file mode 100644 index 0000000..7e8f885 --- /dev/null +++ b/tests/unit/test_db_hitl.py @@ -0,0 +1,238 @@ +"""Repository tests for Phase 6 HITL helpers.""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from ocr_sprint.db.base import Base, get_engine, session_scope +from ocr_sprint.db.repositories import ( + InvalidFieldPathError, + JobAlreadyApprovedError, + JobNotCompletedError, + JobNotFoundError, + JobRepository, +) +from ocr_sprint.schemas.document import DocumentStatus, SourceKind + + +@pytest.fixture +def db_ready() -> None: + Base.metadata.create_all(bind=get_engine()) + + +def _seed_completed_job( + *, + result: dict[str, object] | None = None, + flags: list[str] | None = None, +) -> uuid4: # type: ignore[type-arg] + jid = uuid4() + with session_scope() as session: + repo = JobRepository(session) + repo.create( + job_id=jid, + filename="x.pdf", + source_kind=SourceKind.PDF, + blob_key="k", + ) + with session_scope() as session: + JobRepository(session).mark_completed( + jid, + status=DocumentStatus.NEEDS_REVIEW, + confidence=0.7, + result=result + or { + "header": { + "nomor_sprint": "Sprin/1/I/2025", + "satuan_penerbit": "POLRES X", + "perihal": None, + }, + "personel": [ + {"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}, + ], + "untuk": ["Melaksanakan tugas"], + }, + review_flags=flags or [], + ) + return jid + + +def test_apply_corrections_updates_nested_header_field(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + repo = JobRepository(session) + repo.apply_corrections( + jid, + corrections=[("header.perihal", "Penyelidikan kasus X", "regex miss")], + corrected_by="reviewer-a", + ) + row = repo.get_or_raise(jid) + assert row.result is not None + assert row.result["header"]["perihal"] == "Penyelidikan kasus X" + + +def test_apply_corrections_writes_audit_row(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "Penyelidikan", None)], + corrected_by="reviewer-a", + ) + with session_scope() as session: + events = JobRepository(session).list_corrections(jid) + assert len(events) == 1 + assert events[0].field_path == "header.perihal" + assert events[0].old_value is None + assert events[0].new_value == "Penyelidikan" + assert events[0].corrected_by == "reviewer-a" + + +def test_apply_corrections_supports_list_index(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("personel[0].nrp", "77060001", None)], + corrected_by=None, + ) + row = JobRepository(session).get_or_raise(jid) + assert row.result is not None + assert row.result["personel"][0]["nrp"] == "77060001" + + +def test_apply_corrections_is_atomic_on_invalid_path(db_ready: None) -> None: + """A second-correction failure must roll back the first one.""" + jid = _seed_completed_job() + with session_scope() as session, pytest.raises(InvalidFieldPathError): + JobRepository(session).apply_corrections( + jid, + corrections=[ + ("header.perihal", "OK", None), + ("bogus.root", "X", None), + ], + corrected_by=None, + ) + # The first correction must not have persisted. + with session_scope() as session: + row = JobRepository(session).get_or_raise(jid) + assert row.result is not None + assert row.result["header"].get("perihal") is None + + +def test_apply_corrections_rejects_out_of_range_index(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session, pytest.raises(InvalidFieldPathError): + JobRepository(session).apply_corrections( + jid, + corrections=[("personel[99].nrp", "77060001", None)], + corrected_by=None, + ) + + +def test_apply_corrections_rejects_after_approve(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).approve(jid, reviewed_by="reviewer-a") + with session_scope() as session, pytest.raises(JobAlreadyApprovedError): + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "X", None)], + corrected_by="reviewer-a", + ) + + +def test_apply_corrections_rejects_missing_job(db_ready: None) -> None: + with session_scope() as session, pytest.raises(JobNotFoundError): + JobRepository(session).apply_corrections( + uuid4(), + corrections=[("header.perihal", "X", None)], + corrected_by=None, + ) + + +def test_apply_corrections_rejects_pending_job(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k" + ) + with session_scope() as session, pytest.raises(JobNotCompletedError): + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "X", None)], + corrected_by=None, + ) + + +def test_missing_field_flag_cleared_when_header_gap_filled(db_ready: None) -> None: + jid = _seed_completed_job( + result={ + "header": { + "nomor_sprint": None, + "satuan_penerbit": "POLRES X", + } + }, + flags=["missing_field", "low_ocr_confidence"], + ) + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("header.nomor_sprint", "Sprin/2/I/2025", None)], + corrected_by="reviewer-a", + ) + row = JobRepository(session).get_or_raise(jid) + # ``low_ocr_confidence`` stays (correction doesn't resolve that signal), + # but ``missing_field`` is gone because every required header key is + # now non-empty. + assert list(row.review_flags) == ["low_ocr_confidence"] + + +def test_approve_sets_timestamps_and_is_idempotent(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + row = JobRepository(session).approve(jid, reviewed_by="reviewer-a") + first_at = row.reviewed_at + assert first_at is not None + with session_scope() as session: + row = JobRepository(session).approve(jid, reviewed_by="reviewer-b") + # Second call must NOT overwrite reviewed_by or reviewed_at. + # SQLite drops tzinfo on roundtrip, so compare the naive components. + assert row.approved is True + assert row.reviewed_by == "reviewer-a" + assert row.reviewed_at is not None + assert row.reviewed_at.replace(tzinfo=None) == first_at.replace(tzinfo=None) + + +def test_approve_rejects_pending_job(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k" + ) + with session_scope() as session, pytest.raises(JobNotCompletedError): + JobRepository(session).approve(jid, reviewed_by="rev") + + +def test_history_returns_events_in_order(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "one", None)], + corrected_by="r1", + ) + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[ + ("header.perihal", "two", None), + ("personel[0].nama", "ANDI", None), + ], + corrected_by="r2", + ) + with session_scope() as session: + events = JobRepository(session).list_corrections(jid) + assert [e.new_value for e in events] == ["one", "two", "ANDI"] + assert [e.corrected_by for e in events] == ["r1", "r2", "r2"] diff --git a/tests/unit/test_llm_client.py b/tests/unit/test_llm_client.py new file mode 100644 index 0000000..ed621a9 --- /dev/null +++ b/tests/unit/test_llm_client.py @@ -0,0 +1,108 @@ +"""Unit tests for the Ollama HTTP client wrapper. + +We swap ``httpx.Client`` inside ``ocr_sprint.llm.client`` for a builder that +returns a real ``httpx.Client`` wrapping a ``MockTransport``. Capturing the +original constructor *before* patching avoids infinite recursion in the +patched callable. +""" + +from __future__ import annotations + +from typing import Any + +import httpx +import pytest +from pydantic import BaseModel + +import ocr_sprint.llm.client as llm_client_module +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient + + +class _Schema(BaseModel): + foo: str + bar: int + + +def _ollama_envelope(content: str) -> dict[str, object]: + """Mimic the shape Ollama's /api/chat returns.""" + return {"message": {"role": "assistant", "content": content}, "done": True} + + +def _patch_transport( + monkeypatch: pytest.MonkeyPatch, + handler: Any, +) -> None: + transport = httpx.MockTransport(handler) + real_client = httpx.Client # capture before patching + + def _factory(*_args: object, **kwargs: object) -> httpx.Client: + # Strip any caller-provided transport kwarg; we always inject ours. + kwargs.pop("transport", None) + return real_client(transport=transport, **kwargs) + + monkeypatch.setattr(llm_client_module.httpx, "Client", _factory) + + +def test_chat_json_returns_validated_model(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, object] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["body"] = request.read() + return httpx.Response(200, json=_ollama_envelope('{"foo": "x", "bar": 7}')) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://ollama:11434", model="m", timeout_s=5) + out = client.chat_json("system msg", "user msg", _Schema) + + assert out == _Schema(foo="x", bar=7) + assert captured["url"] == "http://ollama:11434/api/chat" + body = captured["body"] + assert isinstance(body, bytes) + assert b'"format":"json"' in body + assert b'"system msg"' in body + + +def test_chat_json_raises_on_http_error(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="boom") + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=5) + with pytest.raises(LLMUnavailableError, match="Ollama request failed"): + client.chat_json("s", "u", _Schema) + + +def test_chat_json_raises_on_invalid_json(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json=_ollama_envelope("this is not json")) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=5) + with pytest.raises(LLMUnavailableError, match="schema"): + client.chat_json("s", "u", _Schema) + + +def test_chat_json_raises_on_missing_envelope(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"oops": True}) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=5) + with pytest.raises(LLMUnavailableError, match=r"message\.content"): + client.chat_json("s", "u", _Schema) + + +def test_chat_json_raises_on_connection_error(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("nobody home", request=request) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=1) + with pytest.raises(LLMUnavailableError): + client.chat_json("s", "u", _Schema) diff --git a/tests/unit/test_llm_extractor.py b/tests/unit/test_llm_extractor.py new file mode 100644 index 0000000..1dd7a14 --- /dev/null +++ b/tests/unit/test_llm_extractor.py @@ -0,0 +1,90 @@ +"""Unit tests for the hybrid LLM header extractor / merger.""" + +from __future__ import annotations + +from datetime import date + +import pytest +from pydantic import BaseModel + +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient +from ocr_sprint.llm.extractor import LLMHeaderResult, _merge, llm_fill_header +from ocr_sprint.schemas.extraction import HeaderFields + + +class _StubClient(OllamaClient): + """Test double that bypasses HTTP entirely.""" + + def __init__(self, payload: LLMHeaderResult | Exception) -> None: + # Skip the real __init__ — we don't need any real config. + self._payload = payload + + def chat_json( # type: ignore[override] + self, system: str, user: str, schema_cls: type[BaseModel] + ) -> BaseModel: + if isinstance(self._payload, Exception): + raise self._payload + return self._payload + + +def test_merge_keeps_regex_when_present() -> None: + regex = HeaderFields(nomor_sprint="Sprin/123/IV/2025/Reskrim", tanggal=date(2025, 4, 21)) + llm = LLMHeaderResult(nomor_sprint="HALLUCINATED", tanggal=date(1999, 1, 1), perihal="ok") + out = _merge(regex, llm) + assert out.nomor_sprint == "Sprin/123/IV/2025/Reskrim" + assert out.tanggal == date(2025, 4, 21) + # Gaps get filled. + assert out.perihal == "ok" + + +def test_merge_fills_gaps() -> None: + regex = HeaderFields() # all None + llm = LLMHeaderResult( + nomor_sprint="Sprin/9/IX/2024", + tanggal=date(2024, 9, 1), + satuan_penerbit="Polres Bandung", + perihal="Penyelidikan", + dasar=["UU 2/2002", "Perkap 6/2017"], + ) + out = _merge(regex, llm) + assert out.nomor_sprint == "Sprin/9/IX/2024" + assert out.tanggal == date(2024, 9, 1) + assert out.satuan_penerbit == "Polres Bandung" + assert out.perihal == "Penyelidikan" + assert out.dasar == ["UU 2/2002", "Perkap 6/2017"] + + +def test_llm_fill_header_returns_merged_when_client_succeeds() -> None: + regex = HeaderFields(nomor_sprint="Sprin/1/I/2025") # has nomor, missing rest + stub = _StubClient( + LLMHeaderResult( + satuan_penerbit="Polres Bandung", + perihal="Penyelidikan", + dasar=["UU 2/2002"], + ) + ) + out = llm_fill_header(raw_text="...", regex_header=regex, client=stub) + assert out is not None + assert out.nomor_sprint == "Sprin/1/I/2025" + assert out.satuan_penerbit == "Polres Bandung" + assert out.perihal == "Penyelidikan" + assert out.dasar == ["UU 2/2002"] + + +def test_llm_fill_header_returns_none_when_unavailable() -> None: + stub = _StubClient(LLMUnavailableError("server down")) + out = llm_fill_header(raw_text="...", regex_header=HeaderFields(), client=stub) + assert out is None + + +def test_merge_does_not_overwrite_dasar_when_regex_has_it() -> None: + regex = HeaderFields(dasar=["UU 2/2002"]) + llm = LLMHeaderResult(dasar=["something else", "more"]) + out = _merge(regex, llm) + assert out.dasar == ["UU 2/2002"] + + +def test_llm_extractor_unused_argument_kept_silent() -> None: + # A trivial sanity check that the public function signature accepts + # keyword-only `client` — this matches how the orchestrator calls it. + pytest.importorskip("ocr_sprint.llm.extractor") diff --git a/tests/unit/test_orchestrator_llm.py b/tests/unit/test_orchestrator_llm.py new file mode 100644 index 0000000..d56af3c --- /dev/null +++ b/tests/unit/test_orchestrator_llm.py @@ -0,0 +1,171 @@ +"""Orchestrator-level tests for the Phase 5 hybrid LLM wiring. + +These tests stub out the heavy stages (ingest / preprocess / OCR / table) +so we can verify the *branching* behaviour around the LLM step without +booting Paddle. +""" + +from __future__ import annotations + +from datetime import date + +import pytest + +from ocr_sprint.pipeline import orchestrator as orch_module +from ocr_sprint.pipeline.orchestrator import _header_has_gaps, run_pipeline +from ocr_sprint.schemas.document import SourceKind +from ocr_sprint.schemas.extraction import HeaderFields, ReviewFlag, Signatory + + +def test_header_has_gaps_detects_missing_fields() -> None: + full = HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="Polres X", + perihal="ok", + dasar=["UU 2/2002"], + ) + assert _header_has_gaps(full) is False + + assert _header_has_gaps(HeaderFields()) is True + assert _header_has_gaps(full.model_copy(update={"perihal": None})) is True + assert _header_has_gaps(full.model_copy(update={"dasar": []})) is True + + +def _stub_pipeline_stages( + monkeypatch: pytest.MonkeyPatch, + *, + raw_text: str, + regex_header: HeaderFields, +) -> None: + """Replace ingest -> ocr -> tables with cheap fakes so the orchestrator + runs without Paddle / PyMuPDF. + """ + import numpy as np + + from ocr_sprint.pipeline import ingest as ingest_module + from ocr_sprint.pipeline import ocr as ocr_module + from ocr_sprint.pipeline.ingest import IngestedPage + + img = np.full((100, 100, 3), 255, dtype=np.uint8) + fake_page = IngestedPage(image=img, page_index=0) + fake_ocr_page = ocr_module.OCRPage( + lines=[ + ocr_module.OCRLine(text=raw_text, confidence=0.95, box=((0, 0), (1, 0), (1, 1), (0, 1))) + ], + ) + + monkeypatch.setattr(orch_module, "detect_source_kind", lambda _: SourceKind.PDF) + monkeypatch.setattr(orch_module, "ingest", lambda *a, **k: [fake_page]) + monkeypatch.setattr(orch_module, "detect_and_correct", lambda image, _cfg: image) + monkeypatch.setattr(orch_module, "preprocess", lambda image, _cfg: image) + monkeypatch.setattr(orch_module, "run_ocr", lambda _image: fake_ocr_page) + # No tables in these tests. + monkeypatch.setattr(orch_module, "run_table_extraction", lambda _img: []) + monkeypatch.setattr(orch_module, "extract_personnel", lambda _tables: []) + # Header / signatory / validators come from the real implementation + # for `extract_header`, but we override to control gap state. + monkeypatch.setattr(orch_module, "extract_header", lambda _text: regex_header) + monkeypatch.setattr(orch_module, "find_signatory", lambda _text: Signatory()) + monkeypatch.setattr(orch_module, "validate_extraction", lambda _result: []) + # Keep ingest_module referenced so import isn't dropped. + assert ingest_module is not None + + +def test_orchestrator_skips_llm_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("LLM_ENABLED", "false") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + _stub_pipeline_stages( + monkeypatch, + raw_text="dummy", + regex_header=HeaderFields(), # all gaps + ) + + called = {"n": 0} + + def _trip(*_args: object, **_kwargs: object) -> None: + called["n"] += 1 + return None + + monkeypatch.setattr(orch_module, "llm_fill_header", _trip) + + result = run_pipeline(b"%PDF-1.4\n%fake") + assert called["n"] == 0 + assert ReviewFlag.LLM_FALLBACK not in result.result.review_flags + assert ReviewFlag.LLM_UNAVAILABLE not in result.result.review_flags + + +def test_orchestrator_skips_llm_when_header_complete(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("LLM_ENABLED", "true") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + _stub_pipeline_stages( + monkeypatch, + raw_text="dummy", + regex_header=HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="Polres X", + perihal="ok", + dasar=["UU 2/2002"], + ), + ) + + called = {"n": 0} + + def _trip(*_args: object, **_kwargs: object) -> None: + called["n"] += 1 + return None + + monkeypatch.setattr(orch_module, "llm_fill_header", _trip) + + run_pipeline(b"%PDF-1.4\n%fake") + assert called["n"] == 0 + + +def test_orchestrator_calls_llm_and_marks_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("LLM_ENABLED", "true") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + regex_partial = HeaderFields(nomor_sprint="Sprin/1/I/2025") # rest missing + _stub_pipeline_stages(monkeypatch, raw_text="dummy text", regex_header=regex_partial) + + def _llm(_raw: str, header: HeaderFields, **_: object) -> HeaderFields: + return header.model_copy( + update={ + "satuan_penerbit": "Polres Bandung", + "perihal": "Penyelidikan", + "dasar": ["UU 2/2002"], + } + ) + + monkeypatch.setattr(orch_module, "llm_fill_header", _llm) + + out = run_pipeline(b"%PDF-1.4\n%fake") + assert out.result.header.satuan_penerbit == "Polres Bandung" + assert out.result.header.perihal == "Penyelidikan" + assert ReviewFlag.LLM_FALLBACK in out.result.review_flags + assert ReviewFlag.LLM_UNAVAILABLE not in out.result.review_flags + + +def test_orchestrator_marks_unavailable_when_llm_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("LLM_ENABLED", "true") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + _stub_pipeline_stages(monkeypatch, raw_text="dummy", regex_header=HeaderFields()) + monkeypatch.setattr(orch_module, "llm_fill_header", lambda *_a, **_k: None) + + out = run_pipeline(b"%PDF-1.4\n%fake") + assert ReviewFlag.LLM_UNAVAILABLE in out.result.review_flags + assert ReviewFlag.LLM_FALLBACK not in out.result.review_flags