Merge pull request #4 from Adriankf59/devin/1777135879-phase-5-llm-hybrid
Phase 5: hybrid LLM extraction (Ollama) for header gaps
This commit is contained in:
60
alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py
Normal file
60
alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
"""phase6 hitl: job_corrections + approval columns
|
||||||
|
|
||||||
|
Revision ID: 3b1f2c9a4d56
|
||||||
|
Revises: ff8c14fbf8a0
|
||||||
|
Create Date: 2026-04-25 14:30:00.000000
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "3b1f2c9a4d56"
|
||||||
|
down_revision: str | None = "ff8c14fbf8a0"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
with op.batch_alter_table("jobs") as batch:
|
||||||
|
batch.add_column(
|
||||||
|
sa.Column(
|
||||||
|
"approved",
|
||||||
|
sa.Boolean(),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.false(),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
batch.add_column(sa.Column("reviewed_by", sa.String(length=128), nullable=True))
|
||||||
|
batch.add_column(sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True))
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"job_corrections",
|
||||||
|
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||||
|
sa.Column("job_id", sa.Uuid(), nullable=False),
|
||||||
|
sa.Column("field_path", sa.String(length=256), nullable=False),
|
||||||
|
sa.Column("old_value", sa.JSON(), nullable=True),
|
||||||
|
sa.Column("new_value", sa.JSON(), nullable=True),
|
||||||
|
sa.Column("corrected_by", sa.String(length=128), nullable=True),
|
||||||
|
sa.Column("reason", sa.String(length=512), nullable=True),
|
||||||
|
sa.Column("corrected_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(["job_id"], ["jobs.job_id"], ondelete="CASCADE"),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
op.create_index(
|
||||||
|
op.f("ix_job_corrections_job_id"),
|
||||||
|
"job_corrections",
|
||||||
|
["job_id"],
|
||||||
|
unique=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index(op.f("ix_job_corrections_job_id"), table_name="job_corrections")
|
||||||
|
op.drop_table("job_corrections")
|
||||||
|
with op.batch_alter_table("jobs") as batch:
|
||||||
|
batch.drop_column("reviewed_at")
|
||||||
|
batch.drop_column("reviewed_by")
|
||||||
|
batch.drop_column("approved")
|
||||||
@@ -1,17 +1,17 @@
|
|||||||
"""phase4 jobs table
|
"""phase4 jobs table
|
||||||
|
|
||||||
Revision ID: ff8c14fbf8a0
|
Revision ID: ff8c14fbf8a0
|
||||||
Revises:
|
Revises:
|
||||||
Create Date: 2026-04-25 15:54:18.579147
|
Create Date: 2026-04-25 15:54:18.579147
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision: str = 'ff8c14fbf8a0'
|
revision: str = "ff8c14fbf8a0"
|
||||||
down_revision: str | None = None
|
down_revision: str | None = None
|
||||||
branch_labels: str | Sequence[str] | None = None
|
branch_labels: str | Sequence[str] | None = None
|
||||||
depends_on: str | Sequence[str] | None = None
|
depends_on: str | Sequence[str] | None = None
|
||||||
@@ -19,24 +19,25 @@ depends_on: str | Sequence[str] | None = None
|
|||||||
|
|
||||||
def upgrade() -> None:
|
def upgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.create_table('jobs',
|
op.create_table(
|
||||||
sa.Column('job_id', sa.Uuid(), nullable=False),
|
"jobs",
|
||||||
sa.Column('status', sa.String(length=32), nullable=False),
|
sa.Column("job_id", sa.Uuid(), nullable=False),
|
||||||
sa.Column('source_kind', sa.String(length=16), nullable=False),
|
sa.Column("status", sa.String(length=32), nullable=False),
|
||||||
sa.Column('filename', sa.String(length=512), nullable=False),
|
sa.Column("source_kind", sa.String(length=16), nullable=False),
|
||||||
sa.Column('blob_key', sa.String(length=512), nullable=True),
|
sa.Column("filename", sa.String(length=512), nullable=False),
|
||||||
sa.Column('confidence', sa.Float(), nullable=True),
|
sa.Column("blob_key", sa.String(length=512), nullable=True),
|
||||||
sa.Column('review_flags', sa.JSON(), nullable=False),
|
sa.Column("confidence", sa.Float(), nullable=True),
|
||||||
sa.Column('result', sa.JSON(), nullable=True),
|
sa.Column("review_flags", sa.JSON(), nullable=False),
|
||||||
sa.Column('error', sa.String(length=2048), nullable=True),
|
sa.Column("result", sa.JSON(), nullable=True),
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("error", sa.String(length=2048), nullable=True),
|
||||||
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.PrimaryKeyConstraint('job_id')
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("job_id"),
|
||||||
)
|
)
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
op.drop_table('jobs')
|
op.drop_table("jobs")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|||||||
@@ -22,7 +22,17 @@ from __future__ import annotations
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
|
from fastapi import (
|
||||||
|
APIRouter,
|
||||||
|
Depends,
|
||||||
|
File,
|
||||||
|
Header,
|
||||||
|
HTTPException,
|
||||||
|
Query,
|
||||||
|
Response,
|
||||||
|
UploadFile,
|
||||||
|
status,
|
||||||
|
)
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ocr_sprint.api.deps.auth import require_api_key
|
from ocr_sprint.api.deps.auth import require_api_key
|
||||||
@@ -31,11 +41,22 @@ from ocr_sprint.api.errors import UnsupportedDocumentError
|
|||||||
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
|
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
|
||||||
from ocr_sprint.config import get_settings
|
from ocr_sprint.config import get_settings
|
||||||
from ocr_sprint.db.base import session_scope
|
from ocr_sprint.db.base import session_scope
|
||||||
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
|
from ocr_sprint.db.repositories import (
|
||||||
|
InvalidFieldPathError,
|
||||||
|
JobAlreadyApprovedError,
|
||||||
|
JobNotCompletedError,
|
||||||
|
JobNotFoundError,
|
||||||
|
JobRepository,
|
||||||
|
)
|
||||||
from ocr_sprint.pipeline.ingest import detect_source_kind
|
from ocr_sprint.pipeline.ingest import detect_source_kind
|
||||||
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
||||||
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
|
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
|
||||||
from ocr_sprint.schemas.extraction import ExtractionResult
|
from ocr_sprint.schemas.extraction import ExtractionResult
|
||||||
|
from ocr_sprint.schemas.review import (
|
||||||
|
ApprovalResponse,
|
||||||
|
CorrectionEventResponse,
|
||||||
|
CorrectionRequest,
|
||||||
|
)
|
||||||
from ocr_sprint.storage.blob import get_blob_storage
|
from ocr_sprint.storage.blob import get_blob_storage
|
||||||
from ocr_sprint.utils.logging import get_logger
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
|
||||||
@@ -75,6 +96,9 @@ def _row_to_response(row: object) -> DocumentResponse:
|
|||||||
data=result_obj,
|
data=result_obj,
|
||||||
review_flags=list(row.review_flags or []),
|
review_flags=list(row.review_flags or []),
|
||||||
error=row.error,
|
error=row.error,
|
||||||
|
approved=bool(row.approved),
|
||||||
|
reviewed_by=row.reviewed_by,
|
||||||
|
reviewed_at=row.reviewed_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -192,3 +216,116 @@ async def get_document(
|
|||||||
except JobNotFoundError as exc:
|
except JobNotFoundError as exc:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
return _row_to_response(row)
|
return _row_to_response(row)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- Phase 6 — HITL ----------
|
||||||
|
|
||||||
|
|
||||||
|
def _correction_row_to_response(row: object) -> CorrectionEventResponse:
|
||||||
|
# Local import to avoid a cyclic import at module load time.
|
||||||
|
from ocr_sprint.db.models import JobCorrectionRow
|
||||||
|
|
||||||
|
assert isinstance(row, JobCorrectionRow)
|
||||||
|
return CorrectionEventResponse(
|
||||||
|
id=row.id,
|
||||||
|
job_id=row.job_id,
|
||||||
|
field_path=row.field_path,
|
||||||
|
old_value=row.old_value,
|
||||||
|
new_value=row.new_value,
|
||||||
|
corrected_by=row.corrected_by,
|
||||||
|
reason=row.reason,
|
||||||
|
corrected_at=row.corrected_at,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch(
|
||||||
|
"/{job_id}",
|
||||||
|
response_model=DocumentResponse,
|
||||||
|
)
|
||||||
|
async def patch_document(
|
||||||
|
job_id: UUID,
|
||||||
|
body: CorrectionRequest,
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
x_user_id: Annotated[
|
||||||
|
str | None,
|
||||||
|
Header(description="Free-form reviewer identifier recorded on the audit row."),
|
||||||
|
] = None,
|
||||||
|
) -> DocumentResponse:
|
||||||
|
"""Apply one or more field-level corrections and record an audit trail.
|
||||||
|
|
||||||
|
The whole batch is applied atomically — if any path is invalid the
|
||||||
|
request fails with 400 and no side effects are written. Returns the
|
||||||
|
updated document so the client doesn't need a follow-up GET.
|
||||||
|
"""
|
||||||
|
repo = JobRepository(session)
|
||||||
|
try:
|
||||||
|
repo.apply_corrections(
|
||||||
|
job_id,
|
||||||
|
corrections=[(c.path, c.value, c.reason) for c in body.corrections],
|
||||||
|
corrected_by=x_user_id,
|
||||||
|
)
|
||||||
|
row = repo.get_or_raise(job_id)
|
||||||
|
except JobNotFoundError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
|
except InvalidFieldPathError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||||
|
except JobAlreadyApprovedError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||||
|
except JobNotCompletedError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
_logger.info(
|
||||||
|
"documents.patched",
|
||||||
|
job_id=str(job_id),
|
||||||
|
count=len(body.corrections),
|
||||||
|
corrected_by=x_user_id or "",
|
||||||
|
)
|
||||||
|
return _row_to_response(row)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{job_id}/history",
|
||||||
|
response_model=list[CorrectionEventResponse],
|
||||||
|
)
|
||||||
|
async def get_history(
|
||||||
|
job_id: UUID,
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
) -> list[CorrectionEventResponse]:
|
||||||
|
repo = JobRepository(session)
|
||||||
|
try:
|
||||||
|
rows = repo.list_corrections(job_id)
|
||||||
|
except JobNotFoundError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
|
return [_correction_row_to_response(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/{job_id}/approve",
|
||||||
|
response_model=ApprovalResponse,
|
||||||
|
)
|
||||||
|
async def approve_document(
|
||||||
|
job_id: UUID,
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
x_user_id: Annotated[
|
||||||
|
str | None,
|
||||||
|
Header(description="Free-form reviewer identifier recorded on the job."),
|
||||||
|
] = None,
|
||||||
|
) -> ApprovalResponse:
|
||||||
|
"""Lock a job's final version. Idempotent: re-approving returns the
|
||||||
|
existing row without overwriting ``reviewed_at``.
|
||||||
|
"""
|
||||||
|
repo = JobRepository(session)
|
||||||
|
try:
|
||||||
|
row = repo.approve(job_id, reviewed_by=x_user_id)
|
||||||
|
except JobNotFoundError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
|
except JobNotCompletedError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc
|
||||||
|
|
||||||
|
_logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "")
|
||||||
|
return ApprovalResponse(
|
||||||
|
job_id=row.job_id,
|
||||||
|
approved=bool(row.approved),
|
||||||
|
reviewed_by=row.reviewed_by,
|
||||||
|
reviewed_at=row.reviewed_at,
|
||||||
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from datetime import datetime, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from sqlalchemy import JSON, DateTime, Float, String, Uuid
|
from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, Integer, String, Uuid
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from ocr_sprint.db.base import Base
|
from ocr_sprint.db.base import Base
|
||||||
@@ -42,6 +42,15 @@ class JobRow(Base):
|
|||||||
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
|
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
|
||||||
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||||
|
|
||||||
|
# Phase 6 — HITL review state.
|
||||||
|
# Once ``approved=True`` the row is immutable except to admin users;
|
||||||
|
# corrections after that point are rejected by the route. ``reviewed_by``
|
||||||
|
# stores the free-form user identifier the reviewer sent via the
|
||||||
|
# ``X-User-Id`` header (best-effort attribution — no full RBAC yet).
|
||||||
|
approved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
|
reviewed_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, default=_utcnow
|
DateTime(timezone=True), nullable=False, default=_utcnow
|
||||||
)
|
)
|
||||||
@@ -51,3 +60,37 @@ class JobRow(Base):
|
|||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"
|
return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"
|
||||||
|
|
||||||
|
|
||||||
|
class JobCorrectionRow(Base):
|
||||||
|
"""One correction event on a job's ``result``.
|
||||||
|
|
||||||
|
Each PATCH call writes one row per changed field path so we have a
|
||||||
|
full audit trail. Rows are append-only — never updated, never
|
||||||
|
deleted — so the history is reproducible and usable as ground-truth
|
||||||
|
data for future fine-tuning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "job_corrections"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
job_id: Mapped[UUID] = mapped_column(
|
||||||
|
Uuid, ForeignKey("jobs.job_id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
# Dotted JSON path into ExtractionResult, e.g. "header.nomor_sprint" or
|
||||||
|
# "personel[3].nrp". Kept as a plain string for simplicity — we don't
|
||||||
|
# parse it server-side beyond the allow-list check in the repository.
|
||||||
|
field_path: Mapped[str] = mapped_column(String(256), nullable=False)
|
||||||
|
old_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
|
||||||
|
new_value: Mapped[Any | None] = mapped_column(JSON, nullable=True)
|
||||||
|
corrected_by: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||||
|
reason: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||||
|
corrected_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, default=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return (
|
||||||
|
f"JobCorrectionRow(job_id={self.job_id!s}, "
|
||||||
|
f"field={self.field_path!r}, by={self.corrected_by!r})"
|
||||||
|
)
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ know about sessions, transactions, or the row → schema mapping.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@@ -13,7 +16,7 @@ from uuid import UUID
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from ocr_sprint.db.models import JobRow
|
from ocr_sprint.db.models import JobCorrectionRow, JobRow
|
||||||
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||||
|
|
||||||
|
|
||||||
@@ -25,6 +28,110 @@ class JobNotFoundError(LookupError):
|
|||||||
"""Raised by API code when GET /documents/{id} hits a missing row."""
|
"""Raised by API code when GET /documents/{id} hits a missing row."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFieldPathError(ValueError):
|
||||||
|
"""Raised when a PATCH request references an unsupported field path."""
|
||||||
|
|
||||||
|
|
||||||
|
class JobAlreadyApprovedError(RuntimeError):
|
||||||
|
"""Raised when a PATCH is attempted against an already-approved job."""
|
||||||
|
|
||||||
|
|
||||||
|
class JobNotCompletedError(RuntimeError):
|
||||||
|
"""Raised when a PATCH/approve is attempted against a job that hasn't
|
||||||
|
produced a ``result`` payload yet (e.g. still pending or failed).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class AppliedCorrection:
|
||||||
|
"""Internal record of a correction that successfully applied to the
|
||||||
|
in-memory ``result`` dict. The repository turns this into a persisted
|
||||||
|
``JobCorrectionRow`` after the whole batch is validated.
|
||||||
|
"""
|
||||||
|
|
||||||
|
field_path: str
|
||||||
|
old_value: Any
|
||||||
|
new_value: Any
|
||||||
|
reason: str | None
|
||||||
|
|
||||||
|
|
||||||
|
# Allow-list of top-level keys we let reviewers edit. Keeps the attack
|
||||||
|
# surface small: they can't inject arbitrary fields into the JSON blob.
|
||||||
|
_ALLOWED_ROOTS: frozenset[str] = frozenset({"header", "ttd", "personel", "untuk"})
|
||||||
|
|
||||||
|
# Matches a single path segment like ``personel[3]`` — supports at most one
|
||||||
|
# index per segment, enough for the list fields we care about.
|
||||||
|
_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$")
|
||||||
|
|
||||||
|
|
||||||
|
def _split_path(path: str) -> list[tuple[str, int | None]]:
|
||||||
|
"""Parse ``header.nomor_sprint`` or ``personel[2].nrp`` into segments.
|
||||||
|
|
||||||
|
Returns list of ``(name, index_or_none)`` tuples. Raises
|
||||||
|
``InvalidFieldPathError`` on malformed input so the caller can surface
|
||||||
|
a 400 to the client.
|
||||||
|
"""
|
||||||
|
if not path or path.startswith(".") or path.endswith("."):
|
||||||
|
raise InvalidFieldPathError(f"Invalid field path: {path!r}")
|
||||||
|
|
||||||
|
parts = path.split(".")
|
||||||
|
out: list[tuple[str, int | None]] = []
|
||||||
|
for part in parts:
|
||||||
|
match = _SEGMENT_RE.match(part)
|
||||||
|
if match is None:
|
||||||
|
raise InvalidFieldPathError(f"Invalid segment in path: {part!r}")
|
||||||
|
name = match.group(1)
|
||||||
|
idx_raw = match.group(2)
|
||||||
|
idx = int(idx_raw) if idx_raw is not None else None
|
||||||
|
out.append((name, idx))
|
||||||
|
|
||||||
|
if out[0][0] not in _ALLOWED_ROOTS:
|
||||||
|
raise InvalidFieldPathError(
|
||||||
|
f"Field path root {out[0][0]!r} not in allowed roots {sorted(_ALLOWED_ROOTS)!r}"
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_path(data: dict[str, Any], path: str, new_value: Any) -> Any:
|
||||||
|
"""Apply a single correction to ``data`` in place. Returns the old
|
||||||
|
value so the caller can record it in the audit row.
|
||||||
|
|
||||||
|
Does NOT validate that the new value matches the field's expected
|
||||||
|
type — that's the reviewer's responsibility; the whole point of HITL
|
||||||
|
is to let humans override the model's typing.
|
||||||
|
"""
|
||||||
|
segments = _split_path(path)
|
||||||
|
cursor: Any = data
|
||||||
|
for name, idx in segments[:-1]:
|
||||||
|
if not isinstance(cursor, dict) or name not in cursor:
|
||||||
|
raise InvalidFieldPathError(f"Cannot traverse to {path!r}: missing {name!r}")
|
||||||
|
cursor = cursor[name]
|
||||||
|
if idx is not None:
|
||||||
|
if not isinstance(cursor, list) or idx >= len(cursor):
|
||||||
|
raise InvalidFieldPathError(
|
||||||
|
f"Cannot traverse to {path!r}: index [{idx}] out of range"
|
||||||
|
)
|
||||||
|
cursor = cursor[idx]
|
||||||
|
|
||||||
|
name, idx = segments[-1]
|
||||||
|
if idx is not None:
|
||||||
|
# Terminal segment is a list-element, e.g. ``untuk[2]``.
|
||||||
|
if not isinstance(cursor, dict) or name not in cursor:
|
||||||
|
raise InvalidFieldPathError(f"Cannot apply to {path!r}: missing container {name!r}")
|
||||||
|
container = cursor[name]
|
||||||
|
if not isinstance(container, list) or idx >= len(container):
|
||||||
|
raise InvalidFieldPathError(f"Cannot apply to {path!r}: index [{idx}] out of range")
|
||||||
|
old = container[idx]
|
||||||
|
container[idx] = new_value
|
||||||
|
return old
|
||||||
|
|
||||||
|
if not isinstance(cursor, dict):
|
||||||
|
raise InvalidFieldPathError(f"Cannot apply to {path!r}: parent is not an object")
|
||||||
|
old = cursor.get(name)
|
||||||
|
cursor[name] = new_value
|
||||||
|
return old
|
||||||
|
|
||||||
|
|
||||||
class JobRepository:
|
class JobRepository:
|
||||||
"""SQL-backed repository for `jobs` rows."""
|
"""SQL-backed repository for `jobs` rows."""
|
||||||
|
|
||||||
@@ -94,3 +201,139 @@ class JobRepository:
|
|||||||
if row is None:
|
if row is None:
|
||||||
raise JobNotFoundError(f"Job not found: {job_id}")
|
raise JobNotFoundError(f"Job not found: {job_id}")
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
# ---------- Phase 6 — HITL ----------
|
||||||
|
|
||||||
|
def apply_corrections(
|
||||||
|
self,
|
||||||
|
job_id: UUID,
|
||||||
|
*,
|
||||||
|
corrections: list[tuple[str, Any, str | None]],
|
||||||
|
corrected_by: str | None,
|
||||||
|
) -> list[JobCorrectionRow]:
|
||||||
|
"""Apply a batch of field corrections atomically.
|
||||||
|
|
||||||
|
``corrections`` is a list of ``(path, new_value, reason)`` tuples.
|
||||||
|
Returns the persisted audit rows so the caller can surface them in
|
||||||
|
the response.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
JobNotFoundError
|
||||||
|
If the row doesn't exist.
|
||||||
|
JobNotCompletedError
|
||||||
|
If the job hasn't produced a result yet (status pending /
|
||||||
|
processing / failed).
|
||||||
|
JobAlreadyApprovedError
|
||||||
|
If the job has been approved — edits are locked.
|
||||||
|
InvalidFieldPathError
|
||||||
|
If any path is malformed or references a disallowed root.
|
||||||
|
"""
|
||||||
|
row = self._get_or_raise(job_id)
|
||||||
|
if row.result is None:
|
||||||
|
raise JobNotCompletedError(
|
||||||
|
f"Job {job_id} has no result to correct (status={row.status})"
|
||||||
|
)
|
||||||
|
if row.approved:
|
||||||
|
raise JobAlreadyApprovedError(f"Job {job_id} is already approved; edits are locked")
|
||||||
|
|
||||||
|
# Deep-copy so we can roll back in memory if any correction fails.
|
||||||
|
# The underlying JSON column will only be re-assigned once every
|
||||||
|
# path applied cleanly.
|
||||||
|
working = copy.deepcopy(row.result)
|
||||||
|
applied: list[AppliedCorrection] = []
|
||||||
|
for path, new_value, reason in corrections:
|
||||||
|
old_value = _apply_path(working, path, new_value)
|
||||||
|
applied.append(
|
||||||
|
AppliedCorrection(
|
||||||
|
field_path=path, old_value=old_value, new_value=new_value, reason=reason
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Persist audit rows first; if they fail the session rollback also
|
||||||
|
# undoes the result-column update we're about to do.
|
||||||
|
persisted: list[JobCorrectionRow] = []
|
||||||
|
for event in applied:
|
||||||
|
row_event = JobCorrectionRow(
|
||||||
|
job_id=job_id,
|
||||||
|
field_path=event.field_path,
|
||||||
|
old_value=event.old_value,
|
||||||
|
new_value=event.new_value,
|
||||||
|
corrected_by=corrected_by,
|
||||||
|
reason=event.reason,
|
||||||
|
)
|
||||||
|
self.session.add(row_event)
|
||||||
|
persisted.append(row_event)
|
||||||
|
|
||||||
|
# Clear review flags that the correction has resolved. Right now we
|
||||||
|
# only auto-clear MISSING_FIELD when any corrected field previously
|
||||||
|
# held a null/empty value — the reviewer explicitly filled a gap.
|
||||||
|
row.result = working
|
||||||
|
row.review_flags = _recompute_flags(
|
||||||
|
original_flags=list(row.review_flags or []),
|
||||||
|
applied=applied,
|
||||||
|
working_result=working,
|
||||||
|
)
|
||||||
|
row.updated_at = _utcnow()
|
||||||
|
|
||||||
|
self.session.flush()
|
||||||
|
return persisted
|
||||||
|
|
||||||
|
def list_corrections(self, job_id: UUID) -> list[JobCorrectionRow]:
|
||||||
|
"""Return the full audit trail for ``job_id`` in chronological order."""
|
||||||
|
# ``get_or_raise`` so callers get a 404 instead of an empty list
|
||||||
|
# when the job itself doesn't exist.
|
||||||
|
self._get_or_raise(job_id)
|
||||||
|
stmt = (
|
||||||
|
select(JobCorrectionRow)
|
||||||
|
.where(JobCorrectionRow.job_id == job_id)
|
||||||
|
.order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id)
|
||||||
|
)
|
||||||
|
return list(self.session.scalars(stmt))
|
||||||
|
|
||||||
|
def approve(self, job_id: UUID, *, reviewed_by: str | None) -> JobRow:
|
||||||
|
"""Mark a job as approved. Idempotent — re-approving is a no-op
|
||||||
|
that keeps the original ``reviewed_at`` (so the audit trail stays
|
||||||
|
intact).
|
||||||
|
"""
|
||||||
|
row = self._get_or_raise(job_id)
|
||||||
|
if row.result is None:
|
||||||
|
raise JobNotCompletedError(
|
||||||
|
f"Job {job_id} has no result to approve (status={row.status})"
|
||||||
|
)
|
||||||
|
if row.approved:
|
||||||
|
return row
|
||||||
|
row.approved = True
|
||||||
|
row.reviewed_by = reviewed_by
|
||||||
|
row.reviewed_at = _utcnow()
|
||||||
|
row.updated_at = row.reviewed_at
|
||||||
|
return row
|
||||||
|
|
||||||
|
|
||||||
|
def _recompute_flags(
|
||||||
|
*,
|
||||||
|
original_flags: list[str],
|
||||||
|
applied: list[AppliedCorrection],
|
||||||
|
working_result: dict[str, Any],
|
||||||
|
) -> list[str]:
|
||||||
|
"""Update review flags in light of the corrections just applied.
|
||||||
|
|
||||||
|
Keeps the policy simple on purpose:
|
||||||
|
* ``missing_field`` is removed if after the edit every required
|
||||||
|
header field is non-empty.
|
||||||
|
* Other flags stay untouched — the reviewer should either correct the
|
||||||
|
underlying issue (which this helper can detect) or explicitly
|
||||||
|
approve the result as-is (which bypasses the flag list).
|
||||||
|
"""
|
||||||
|
flags = list(original_flags)
|
||||||
|
if "missing_field" in flags:
|
||||||
|
header = working_result.get("header") or {}
|
||||||
|
filled = all(bool(header.get(key)) for key in ("nomor_sprint", "satuan_penerbit"))
|
||||||
|
if filled:
|
||||||
|
flags = [f for f in flags if f != "missing_field"]
|
||||||
|
|
||||||
|
# ``applied`` isn't used directly in this MVP rule, but we keep the
|
||||||
|
# parameter so future policies can inspect exactly what changed
|
||||||
|
# without re-diffing the blob.
|
||||||
|
_ = applied
|
||||||
|
return flags
|
||||||
|
|||||||
18
src/ocr_sprint/llm/__init__.py
Normal file
18
src/ocr_sprint/llm/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""LLM-based extraction (Phase 5).
|
||||||
|
|
||||||
|
The hybrid extractor first runs the deterministic regex layer and then —
|
||||||
|
only for fields that came back missing or low-confidence — calls a local
|
||||||
|
Ollama model with a Pydantic-typed prompt. Everything is gated by
|
||||||
|
``LLM_ENABLED``; if the flag is off or the Ollama server is unreachable,
|
||||||
|
the pipeline degrades gracefully back to the regex result.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||||
|
from ocr_sprint.llm.extractor import LLMHeaderResult, llm_fill_header
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMHeaderResult",
|
||||||
|
"LLMUnavailableError",
|
||||||
|
"OllamaClient",
|
||||||
|
"llm_fill_header",
|
||||||
|
]
|
||||||
97
src/ocr_sprint/llm/client.py
Normal file
97
src/ocr_sprint/llm/client.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Ollama HTTP client.
|
||||||
|
|
||||||
|
We deliberately avoid the ``ollama`` Python package — the wire format is a
|
||||||
|
single ``POST /api/chat`` with ``format="json"`` and a system + user message,
|
||||||
|
so a small ``httpx`` wrapper is enough. This keeps the runtime dependency
|
||||||
|
footprint smaller and makes the mock-based unit tests trivial.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
|
||||||
|
_logger = get_logger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMUnavailableError(RuntimeError):
|
||||||
|
"""Raised when the Ollama server is unreachable, times out, or returns
|
||||||
|
a malformed payload. The pipeline catches this and falls back to the
|
||||||
|
regex-only result with a ``llm_fallback`` review flag.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class OllamaClient:
|
||||||
|
"""Tiny synchronous HTTP wrapper around the Ollama ``/api/chat`` endpoint.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
base_url:
|
||||||
|
Ollama server URL, e.g. ``http://localhost:11434``.
|
||||||
|
model:
|
||||||
|
Model tag to invoke (default ``qwen2.5:1.5b`` — chosen for CPU
|
||||||
|
latency at acceptable accuracy).
|
||||||
|
timeout_s:
|
||||||
|
Hard wall-clock timeout for a single request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
base_url: str | None = None,
|
||||||
|
model: str | None = None,
|
||||||
|
timeout_s: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
s = get_settings()
|
||||||
|
self.base_url = (base_url or s.llm_base_url).rstrip("/")
|
||||||
|
self.model = model or s.llm_model
|
||||||
|
self.timeout_s = timeout_s if timeout_s is not None else s.llm_timeout_s
|
||||||
|
|
||||||
|
# ---------- public API ----------
|
||||||
|
|
||||||
|
def chat_json(self, system: str, user: str, schema_cls: type[T]) -> T:
|
||||||
|
"""Run a single chat completion in JSON mode and validate the
|
||||||
|
response against ``schema_cls``. Raises ``LLMUnavailableError`` on
|
||||||
|
any transport / parse / validation failure so callers only have one
|
||||||
|
exception to handle.
|
||||||
|
"""
|
||||||
|
payload = {
|
||||||
|
"model": self.model,
|
||||||
|
"stream": False,
|
||||||
|
"format": "json",
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": system},
|
||||||
|
{"role": "user", "content": user},
|
||||||
|
],
|
||||||
|
# Keep determinism reasonable — we want extraction, not creativity.
|
||||||
|
"options": {"temperature": 0.0, "num_ctx": 4096},
|
||||||
|
}
|
||||||
|
url = f"{self.base_url}/api/chat"
|
||||||
|
|
||||||
|
try:
|
||||||
|
with httpx.Client(timeout=self.timeout_s) as client:
|
||||||
|
response = client.post(url, json=payload)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
except (httpx.HTTPError, ValueError) as exc:
|
||||||
|
_logger.warning("llm.transport_error", url=url, error=str(exc))
|
||||||
|
raise LLMUnavailableError(f"Ollama request failed: {exc}") from exc
|
||||||
|
|
||||||
|
# Ollama returns {"message": {"role": "assistant", "content": "<json>"}}.
|
||||||
|
try:
|
||||||
|
content = data["message"]["content"]
|
||||||
|
except (KeyError, TypeError) as exc:
|
||||||
|
_logger.warning("llm.bad_envelope", payload=data)
|
||||||
|
raise LLMUnavailableError(f"Ollama response missing message.content: {data!r}") from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
return schema_cls.model_validate_json(content)
|
||||||
|
except ValidationError as exc:
|
||||||
|
_logger.warning("llm.validation_error", error=str(exc), content=content[:400])
|
||||||
|
raise LLMUnavailableError(f"LLM JSON failed schema: {exc}") from exc
|
||||||
84
src/ocr_sprint/llm/extractor.py
Normal file
84
src/ocr_sprint/llm/extractor.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""High-level LLM extractor.
|
||||||
|
|
||||||
|
The job is *narrow*: take the raw OCR text plus the partial header that
|
||||||
|
came back from the regex layer, and return an LLM-derived header that the
|
||||||
|
caller can merge in. We never let the LLM populate the personnel table —
|
||||||
|
PP-Structure is more accurate and cheaper for that.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||||
|
from ocr_sprint.llm.prompts import SYSTEM_HEADER, build_user_prompt
|
||||||
|
from ocr_sprint.schemas.extraction import HeaderFields
|
||||||
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
|
||||||
|
_logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMHeaderResult(BaseModel):
|
||||||
|
"""Schema we ask the model to fill. Mirrors ``HeaderFields`` but is
|
||||||
|
intentionally separate so we control exactly what the prompt and
|
||||||
|
validation surface look like — the public ``HeaderFields`` may grow
|
||||||
|
fields later that we don't want the LLM touching.
|
||||||
|
"""
|
||||||
|
|
||||||
|
nomor_sprint: str | None = None
|
||||||
|
tanggal: date | None = None
|
||||||
|
satuan_penerbit: str | None = None
|
||||||
|
perihal: str | None = None
|
||||||
|
dasar: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def llm_fill_header(
|
||||||
|
raw_text: str,
|
||||||
|
regex_header: HeaderFields,
|
||||||
|
*,
|
||||||
|
client: OllamaClient | None = None,
|
||||||
|
) -> HeaderFields | None:
|
||||||
|
"""Run the LLM extractor and return a *merged* HeaderFields.
|
||||||
|
|
||||||
|
Returns ``None`` if the model is unavailable so the caller can decide
|
||||||
|
what to do (typically: keep the regex result and emit a fallback
|
||||||
|
review flag).
|
||||||
|
"""
|
||||||
|
client = client or OllamaClient()
|
||||||
|
|
||||||
|
user = build_user_prompt(
|
||||||
|
raw_text=raw_text,
|
||||||
|
regex_partial=regex_header.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
llm = client.chat_json(SYSTEM_HEADER, user, LLMHeaderResult)
|
||||||
|
except LLMUnavailableError as exc:
|
||||||
|
_logger.warning("llm.unavailable", error=str(exc))
|
||||||
|
return None
|
||||||
|
|
||||||
|
return _merge(regex_header, llm)
|
||||||
|
|
||||||
|
|
||||||
|
def _merge(regex: HeaderFields, llm: LLMHeaderResult) -> HeaderFields:
|
||||||
|
"""Merge LLM output into the regex result.
|
||||||
|
|
||||||
|
Policy: regex wins for any field it already filled. The LLM only fills
|
||||||
|
the *gaps*. This keeps deterministic / verifiable extractions for the
|
||||||
|
fields where regex is reliable and prevents the LLM from "correcting"
|
||||||
|
a value that happens to look unusual but is in fact correct.
|
||||||
|
"""
|
||||||
|
merged = regex.model_copy(deep=True)
|
||||||
|
if merged.nomor_sprint is None and llm.nomor_sprint:
|
||||||
|
merged.nomor_sprint = llm.nomor_sprint
|
||||||
|
if merged.tanggal is None and llm.tanggal is not None:
|
||||||
|
merged.tanggal = llm.tanggal
|
||||||
|
if not merged.satuan_penerbit and llm.satuan_penerbit:
|
||||||
|
merged.satuan_penerbit = llm.satuan_penerbit
|
||||||
|
if not merged.perihal and llm.perihal:
|
||||||
|
merged.perihal = llm.perihal
|
||||||
|
if not merged.dasar and llm.dasar:
|
||||||
|
merged.dasar = list(llm.dasar)
|
||||||
|
return merged
|
||||||
48
src/ocr_sprint/llm/prompts.py
Normal file
48
src/ocr_sprint/llm/prompts.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Prompt builders for the LLM extractor.
|
||||||
|
|
||||||
|
Kept in their own module so the prompts can be edited / version-tracked
|
||||||
|
without touching the orchestration logic. We build prompts in Indonesian
|
||||||
|
because the source documents are too — the model performs better when the
|
||||||
|
field labels in the prompt match the OCR text it's being asked about.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
SYSTEM_HEADER = (
|
||||||
|
"Anda adalah asisten ekstraksi data untuk dokumen Surat Perintah (Sprint) "
|
||||||
|
"Kepolisian Republik Indonesia (POLRI). Pengguna akan memberikan teks hasil "
|
||||||
|
"OCR sebuah surat sprint, dan Anda harus mengembalikan JSON yang sesuai "
|
||||||
|
"dengan skema yang diberikan.\n\n"
|
||||||
|
"Aturan keras:\n"
|
||||||
|
"1. Jangan mengarang. Jika sebuah field tidak terlihat di teks, kembalikan null.\n"
|
||||||
|
"2. Jangan menerjemahkan field. Output harus identik ejaannya dengan teks "
|
||||||
|
"sumber (kecuali normalisasi spasi/kapitalisasi yang jelas hasil OCR error).\n"
|
||||||
|
"3. Tanggal: kembalikan format ISO YYYY-MM-DD jika tanggal terlihat, "
|
||||||
|
"selain itu null.\n"
|
||||||
|
"4. Dasar hukum: array string berisi tiap butir, urut sesuai teks.\n"
|
||||||
|
"5. Jangan menambahkan field apa pun di luar skema. Output WAJIB JSON valid."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_user_prompt(raw_text: str, regex_partial: dict[str, object]) -> str:
|
||||||
|
"""Construct the user message: OCR text + a hint about which fields the
|
||||||
|
deterministic regex layer already filled. Telling the LLM what we
|
||||||
|
*already have* keeps it from "creatively" overwriting good values.
|
||||||
|
"""
|
||||||
|
known_fields = "\n".join(f" - {k}: {v!r}" for k, v in sorted(regex_partial.items()) if v)
|
||||||
|
known_block = (
|
||||||
|
f"\nField yang sudah berhasil diekstrak dengan regex:\n{known_fields}\n"
|
||||||
|
if known_fields
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
"Teks OCR:\n"
|
||||||
|
"----------\n"
|
||||||
|
f"{raw_text}\n"
|
||||||
|
"----------\n"
|
||||||
|
f"{known_block}"
|
||||||
|
"Tugas: kembalikan JSON dengan field nomor_sprint, tanggal (ISO date | null), "
|
||||||
|
"satuan_penerbit, perihal, dasar (array string). Hanya field yang terlihat — "
|
||||||
|
"yang tidak ada di teks isi null (atau array kosong untuk dasar)."
|
||||||
|
)
|
||||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from ocr_sprint.config import get_settings
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.llm.extractor import llm_fill_header
|
||||||
from ocr_sprint.pipeline.confidence import compute_confidence, route
|
from ocr_sprint.pipeline.confidence import compute_confidence, route
|
||||||
from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct
|
from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct
|
||||||
from ocr_sprint.pipeline.extract.personnel import extract_personnel
|
from ocr_sprint.pipeline.extract.personnel import extract_personnel
|
||||||
@@ -35,6 +36,18 @@ _logger = get_logger(__name__)
|
|||||||
_OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80
|
_OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80
|
||||||
|
|
||||||
|
|
||||||
|
def _header_has_gaps(header: object) -> bool:
|
||||||
|
"""True if any header field worth asking the LLM about is missing.
|
||||||
|
|
||||||
|
Using ``getattr`` so this stays decoupled from the exact attribute
|
||||||
|
names; the schema change cost was too large last time we hard-coded.
|
||||||
|
"""
|
||||||
|
for field in ("nomor_sprint", "tanggal", "satuan_penerbit", "perihal"):
|
||||||
|
if not getattr(header, field, None):
|
||||||
|
return True
|
||||||
|
return not getattr(header, "dasar", None)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineOutput:
|
class PipelineOutput:
|
||||||
"""Bundle returned by the orchestrator."""
|
"""Bundle returned by the orchestrator."""
|
||||||
@@ -84,6 +97,20 @@ def run_pipeline(content: bytes) -> PipelineOutput:
|
|||||||
header = extract_header(full_text)
|
header = extract_header(full_text)
|
||||||
ttd = find_signatory(full_text)
|
ttd = find_signatory(full_text)
|
||||||
|
|
||||||
|
# Phase 5 — hybrid extraction. The regex layer is deterministic but
|
||||||
|
# brittle to layout variants between satuan; if any header field is
|
||||||
|
# still missing we ask the local LLM to fill the gaps. The merger
|
||||||
|
# never lets the LLM overwrite a field that regex already captured.
|
||||||
|
llm_flags: list[ReviewFlag] = []
|
||||||
|
if s.llm_enabled and _header_has_gaps(header):
|
||||||
|
merged = llm_fill_header(full_text, header)
|
||||||
|
if merged is None:
|
||||||
|
llm_flags.append(ReviewFlag.LLM_UNAVAILABLE)
|
||||||
|
else:
|
||||||
|
if merged.model_dump() != header.model_dump():
|
||||||
|
llm_flags.append(ReviewFlag.LLM_FALLBACK)
|
||||||
|
header = merged
|
||||||
|
|
||||||
personel: list[PersonnelEntry] = []
|
personel: list[PersonnelEntry] = []
|
||||||
if s.tables_enabled and cleaned_pages:
|
if s.tables_enabled and cleaned_pages:
|
||||||
all_tables: list[DetectedTable] = []
|
all_tables: list[DetectedTable] = []
|
||||||
@@ -99,7 +126,7 @@ def run_pipeline(content: bytes) -> PipelineOutput:
|
|||||||
personel_rows=len(personel),
|
personel_rows=len(personel),
|
||||||
)
|
)
|
||||||
|
|
||||||
initial_flags: list[ReviewFlag] = []
|
initial_flags: list[ReviewFlag] = list(llm_flags)
|
||||||
if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD:
|
if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD:
|
||||||
initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE)
|
initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE)
|
||||||
|
|
||||||
|
|||||||
@@ -55,3 +55,7 @@ class DocumentResponse(BaseModel):
|
|||||||
data: ExtractionResult | None = None
|
data: ExtractionResult | None = None
|
||||||
review_flags: list[str] = Field(default_factory=list)
|
review_flags: list[str] = Field(default_factory=list)
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
|
# Phase 6 — HITL review state.
|
||||||
|
approved: bool = False
|
||||||
|
reviewed_by: str | None = None
|
||||||
|
reviewed_at: datetime | None = None
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ class ReviewFlag(str, Enum):
|
|||||||
UNKNOWN_PANGKAT = "unknown_pangkat"
|
UNKNOWN_PANGKAT = "unknown_pangkat"
|
||||||
PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch"
|
PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch"
|
||||||
DATE_PARSE_FAILED = "date_parse_failed"
|
DATE_PARSE_FAILED = "date_parse_failed"
|
||||||
|
LLM_FALLBACK = "llm_fallback"
|
||||||
|
LLM_UNAVAILABLE = "llm_unavailable"
|
||||||
|
|
||||||
|
|
||||||
class Signatory(BaseModel):
|
class Signatory(BaseModel):
|
||||||
|
|||||||
62
src/ocr_sprint/schemas/review.py
Normal file
62
src/ocr_sprint/schemas/review.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Request / response schemas for the HITL review endpoints (Phase 6).
|
||||||
|
|
||||||
|
The API surface is deliberately small:
|
||||||
|
|
||||||
|
* ``CorrectionRequest`` — body of ``PATCH /documents/{id}``. A list of
|
||||||
|
``FieldCorrection`` entries; each one is applied atomically (all-or-
|
||||||
|
nothing) and recorded in the audit trail.
|
||||||
|
* ``CorrectionEventResponse`` — single row in ``GET /documents/{id}/history``.
|
||||||
|
* ``ApprovalResponse`` — echo back after ``POST /documents/{id}/approve``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class FieldCorrection(BaseModel):
|
||||||
|
"""One field-level correction.
|
||||||
|
|
||||||
|
``path`` is a dotted JSON path into ``ExtractionResult``. Supported
|
||||||
|
roots: ``header``, ``ttd``, ``personel[n]`` (n is a 0-based index),
|
||||||
|
``untuk``. The path is validated by the repository before being
|
||||||
|
applied; unknown roots return 400.
|
||||||
|
"""
|
||||||
|
|
||||||
|
path: str = Field(..., description="Dotted JSON path, e.g. 'header.nomor_sprint'.")
|
||||||
|
value: Any = Field(..., description="New value (any JSON-serialisable payload).")
|
||||||
|
reason: str | None = Field(
|
||||||
|
None, max_length=512, description="Optional free-form reason for the correction."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionRequest(BaseModel):
|
||||||
|
"""PATCH body — one or more field corrections, applied atomically."""
|
||||||
|
|
||||||
|
corrections: list[FieldCorrection] = Field(..., min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
class CorrectionEventResponse(BaseModel):
|
||||||
|
"""One row of the audit log surfaced by GET /history."""
|
||||||
|
|
||||||
|
id: int
|
||||||
|
job_id: UUID
|
||||||
|
field_path: str
|
||||||
|
old_value: Any | None = None
|
||||||
|
new_value: Any | None = None
|
||||||
|
corrected_by: str | None = None
|
||||||
|
reason: str | None = None
|
||||||
|
corrected_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class ApprovalResponse(BaseModel):
|
||||||
|
"""Echo returned after a job is approved."""
|
||||||
|
|
||||||
|
job_id: UUID
|
||||||
|
approved: bool
|
||||||
|
reviewed_by: str | None = None
|
||||||
|
reviewed_at: datetime | None = None
|
||||||
248
tests/unit/test_api_hitl.py
Normal file
248
tests/unit/test_api_hitl.py
Normal file
@@ -0,0 +1,248 @@
|
|||||||
|
"""End-to-end HTTP tests for the HITL endpoints.
|
||||||
|
|
||||||
|
We re-use the ``fake_pipeline`` style from ``test_api.py`` so we don't pay
|
||||||
|
the PaddleOCR init cost; the orchestrator is monkey-patched to return a
|
||||||
|
synthetic ``ExtractionResult``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from ocr_sprint.main import create_app
|
||||||
|
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||||
|
from ocr_sprint.pipeline.orchestrator import PipelineOutput
|
||||||
|
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||||
|
from ocr_sprint.schemas.extraction import (
|
||||||
|
ExtractionResult,
|
||||||
|
HeaderFields,
|
||||||
|
PersonnelEntry,
|
||||||
|
ReviewFlag,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client() -> TestClient:
|
||||||
|
return TestClient(create_app())
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
|
||||||
|
result = ExtractionResult(
|
||||||
|
header=HeaderFields(
|
||||||
|
nomor_sprint="Sprin/1/I/2025",
|
||||||
|
tanggal=date(2025, 1, 1),
|
||||||
|
satuan_penerbit="POLRES TEST",
|
||||||
|
perihal=None, # intentional gap so a PATCH can fill it
|
||||||
|
),
|
||||||
|
personel=[
|
||||||
|
PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"),
|
||||||
|
],
|
||||||
|
review_flags=[ReviewFlag.MISSING_FIELD],
|
||||||
|
confidence=0.7,
|
||||||
|
)
|
||||||
|
output = PipelineOutput(
|
||||||
|
source_kind=SourceKind.PDF,
|
||||||
|
status=DocumentStatus.NEEDS_REVIEW,
|
||||||
|
confidence=0.7,
|
||||||
|
result=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _fake_run(_content: bytes) -> PipelineOutput:
|
||||||
|
return output
|
||||||
|
|
||||||
|
monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
|
||||||
|
from ocr_sprint.api.routes import documents as docs_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
|
||||||
|
from ocr_sprint.worker import tasks as tasks_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
def _create_job(client: TestClient) -> str:
|
||||||
|
post = client.post(
|
||||||
|
"/api/v1/documents?sync=true",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert post.status_code == 200, post.text
|
||||||
|
body = post.json()
|
||||||
|
assert body["status"] == "needs_review"
|
||||||
|
return str(body["job_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_applies_correction_and_clears_missing_field(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
patched = client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={
|
||||||
|
"corrections": [
|
||||||
|
{
|
||||||
|
"path": "header.perihal",
|
||||||
|
"value": "Penyelidikan kasus X",
|
||||||
|
"reason": "LLM missed it",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
headers={"X-User-Id": "reviewer-a"},
|
||||||
|
)
|
||||||
|
assert patched.status_code == 200, patched.text
|
||||||
|
body = patched.json()
|
||||||
|
assert body["data"]["header"]["perihal"] == "Penyelidikan kasus X"
|
||||||
|
# The fake pipeline has both required header fields filled, so the
|
||||||
|
# ``missing_field`` flag is auto-cleared as soon as any correction
|
||||||
|
# lands (the policy re-evaluates required-field coverage on every
|
||||||
|
# edit).
|
||||||
|
assert "missing_field" not in body["review_flags"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_returns_400_for_unknown_path(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={"corrections": [{"path": "bogus.field", "value": "x"}]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_is_atomic_on_partial_failure(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={
|
||||||
|
"corrections": [
|
||||||
|
{"path": "header.perihal", "value": "OK"},
|
||||||
|
{"path": "bogus.root", "value": "X"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
# The first correction must not have persisted.
|
||||||
|
got = client.get(f"/api/v1/documents/{job_id}")
|
||||||
|
assert got.json()["data"]["header"]["perihal"] is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_history_returns_corrections_in_order(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={"corrections": [{"path": "header.perihal", "value": "first"}]},
|
||||||
|
headers={"X-User-Id": "reviewer-a"},
|
||||||
|
)
|
||||||
|
client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={"corrections": [{"path": "header.perihal", "value": "second"}]},
|
||||||
|
headers={"X-User-Id": "reviewer-b"},
|
||||||
|
)
|
||||||
|
|
||||||
|
history = client.get(f"/api/v1/documents/{job_id}/history")
|
||||||
|
assert history.status_code == 200
|
||||||
|
events = history.json()
|
||||||
|
assert [e["new_value"] for e in events] == ["first", "second"]
|
||||||
|
assert [e["corrected_by"] for e in events] == ["reviewer-a", "reviewer-b"]
|
||||||
|
# old_value of the second event should reflect the first edit.
|
||||||
|
assert events[1]["old_value"] == "first"
|
||||||
|
|
||||||
|
|
||||||
|
def test_history_returns_empty_list_for_untouched_job(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
history = client.get(f"/api/v1/documents/{job_id}/history")
|
||||||
|
assert history.status_code == 200
|
||||||
|
assert history.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_history_returns_404_for_unknown_job(client: TestClient) -> None:
|
||||||
|
resp = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000/history")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_locks_subsequent_patches(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
approved = client.post(
|
||||||
|
f"/api/v1/documents/{job_id}/approve",
|
||||||
|
headers={"X-User-Id": "reviewer-a"},
|
||||||
|
)
|
||||||
|
assert approved.status_code == 200, approved.text
|
||||||
|
body = approved.json()
|
||||||
|
assert body["approved"] is True
|
||||||
|
assert body["reviewed_by"] == "reviewer-a"
|
||||||
|
assert body["reviewed_at"] # non-empty timestamp
|
||||||
|
|
||||||
|
# GET reflects the approval state.
|
||||||
|
got = client.get(f"/api/v1/documents/{job_id}").json()
|
||||||
|
assert got["approved"] is True
|
||||||
|
|
||||||
|
# PATCH after approve must be rejected with 409.
|
||||||
|
patched = client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={"corrections": [{"path": "header.perihal", "value": "X"}]},
|
||||||
|
)
|
||||||
|
assert patched.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_is_idempotent(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
first = client.post(
|
||||||
|
f"/api/v1/documents/{job_id}/approve",
|
||||||
|
headers={"X-User-Id": "reviewer-a"},
|
||||||
|
)
|
||||||
|
second = client.post(
|
||||||
|
f"/api/v1/documents/{job_id}/approve",
|
||||||
|
headers={"X-User-Id": "reviewer-b"},
|
||||||
|
)
|
||||||
|
assert first.status_code == 200
|
||||||
|
assert second.status_code == 200
|
||||||
|
# Second approve must NOT change the attribution. (SQLite drops tzinfo
|
||||||
|
# on roundtrip, which changes Pydantic's serialization between the two
|
||||||
|
# calls; compare the naive components.)
|
||||||
|
assert second.json()["reviewed_by"] == "reviewer-a"
|
||||||
|
assert (
|
||||||
|
second.json()["reviewed_at"].rstrip("Z").split("+")[0]
|
||||||
|
== (first.json()["reviewed_at"].rstrip("Z").split("+")[0])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_requires_at_least_one_correction(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
job_id = _create_job(client)
|
||||||
|
resp = client.patch(
|
||||||
|
f"/api/v1/documents/{job_id}",
|
||||||
|
json={"corrections": []},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422 # Pydantic min_length=1 violation
|
||||||
|
|
||||||
|
|
||||||
|
def test_patch_missing_job_returns_404(client: TestClient) -> None:
|
||||||
|
resp = client.patch(
|
||||||
|
"/api/v1/documents/00000000-0000-0000-0000-000000000000",
|
||||||
|
json={"corrections": [{"path": "header.perihal", "value": "X"}]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
238
tests/unit/test_db_hitl.py
Normal file
238
tests/unit/test_db_hitl.py
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
"""Repository tests for Phase 6 HITL helpers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||||
|
from ocr_sprint.db.repositories import (
|
||||||
|
InvalidFieldPathError,
|
||||||
|
JobAlreadyApprovedError,
|
||||||
|
JobNotCompletedError,
|
||||||
|
JobNotFoundError,
|
||||||
|
JobRepository,
|
||||||
|
)
|
||||||
|
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_ready() -> None:
|
||||||
|
Base.metadata.create_all(bind=get_engine())
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_completed_job(
|
||||||
|
*,
|
||||||
|
result: dict[str, object] | None = None,
|
||||||
|
flags: list[str] | None = None,
|
||||||
|
) -> uuid4: # type: ignore[type-arg]
|
||||||
|
jid = uuid4()
|
||||||
|
with session_scope() as session:
|
||||||
|
repo = JobRepository(session)
|
||||||
|
repo.create(
|
||||||
|
job_id=jid,
|
||||||
|
filename="x.pdf",
|
||||||
|
source_kind=SourceKind.PDF,
|
||||||
|
blob_key="k",
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_completed(
|
||||||
|
jid,
|
||||||
|
status=DocumentStatus.NEEDS_REVIEW,
|
||||||
|
confidence=0.7,
|
||||||
|
result=result
|
||||||
|
or {
|
||||||
|
"header": {
|
||||||
|
"nomor_sprint": "Sprin/1/I/2025",
|
||||||
|
"satuan_penerbit": "POLRES X",
|
||||||
|
"perihal": None,
|
||||||
|
},
|
||||||
|
"personel": [
|
||||||
|
{"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"},
|
||||||
|
],
|
||||||
|
"untuk": ["Melaksanakan tugas"],
|
||||||
|
},
|
||||||
|
review_flags=flags or [],
|
||||||
|
)
|
||||||
|
return jid
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_updates_nested_header_field(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session:
|
||||||
|
repo = JobRepository(session)
|
||||||
|
repo.apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("header.perihal", "Penyelidikan kasus X", "regex miss")],
|
||||||
|
corrected_by="reviewer-a",
|
||||||
|
)
|
||||||
|
row = repo.get_or_raise(jid)
|
||||||
|
assert row.result is not None
|
||||||
|
assert row.result["header"]["perihal"] == "Penyelidikan kasus X"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_writes_audit_row(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("header.perihal", "Penyelidikan", None)],
|
||||||
|
corrected_by="reviewer-a",
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
events = JobRepository(session).list_corrections(jid)
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].field_path == "header.perihal"
|
||||||
|
assert events[0].old_value is None
|
||||||
|
assert events[0].new_value == "Penyelidikan"
|
||||||
|
assert events[0].corrected_by == "reviewer-a"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_supports_list_index(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("personel[0].nrp", "77060001", None)],
|
||||||
|
corrected_by=None,
|
||||||
|
)
|
||||||
|
row = JobRepository(session).get_or_raise(jid)
|
||||||
|
assert row.result is not None
|
||||||
|
assert row.result["personel"][0]["nrp"] == "77060001"
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_is_atomic_on_invalid_path(db_ready: None) -> None:
|
||||||
|
"""A second-correction failure must roll back the first one."""
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session, pytest.raises(InvalidFieldPathError):
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[
|
||||||
|
("header.perihal", "OK", None),
|
||||||
|
("bogus.root", "X", None),
|
||||||
|
],
|
||||||
|
corrected_by=None,
|
||||||
|
)
|
||||||
|
# The first correction must not have persisted.
|
||||||
|
with session_scope() as session:
|
||||||
|
row = JobRepository(session).get_or_raise(jid)
|
||||||
|
assert row.result is not None
|
||||||
|
assert row.result["header"].get("perihal") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_rejects_out_of_range_index(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session, pytest.raises(InvalidFieldPathError):
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("personel[99].nrp", "77060001", None)],
|
||||||
|
corrected_by=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_rejects_after_approve(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).approve(jid, reviewed_by="reviewer-a")
|
||||||
|
with session_scope() as session, pytest.raises(JobAlreadyApprovedError):
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("header.perihal", "X", None)],
|
||||||
|
corrected_by="reviewer-a",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_rejects_missing_job(db_ready: None) -> None:
|
||||||
|
with session_scope() as session, pytest.raises(JobNotFoundError):
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
uuid4(),
|
||||||
|
corrections=[("header.perihal", "X", None)],
|
||||||
|
corrected_by=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_apply_corrections_rejects_pending_job(db_ready: None) -> None:
|
||||||
|
jid = uuid4()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).create(
|
||||||
|
job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k"
|
||||||
|
)
|
||||||
|
with session_scope() as session, pytest.raises(JobNotCompletedError):
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("header.perihal", "X", None)],
|
||||||
|
corrected_by=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_field_flag_cleared_when_header_gap_filled(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job(
|
||||||
|
result={
|
||||||
|
"header": {
|
||||||
|
"nomor_sprint": None,
|
||||||
|
"satuan_penerbit": "POLRES X",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
flags=["missing_field", "low_ocr_confidence"],
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("header.nomor_sprint", "Sprin/2/I/2025", None)],
|
||||||
|
corrected_by="reviewer-a",
|
||||||
|
)
|
||||||
|
row = JobRepository(session).get_or_raise(jid)
|
||||||
|
# ``low_ocr_confidence`` stays (correction doesn't resolve that signal),
|
||||||
|
# but ``missing_field`` is gone because every required header key is
|
||||||
|
# now non-empty.
|
||||||
|
assert list(row.review_flags) == ["low_ocr_confidence"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_sets_timestamps_and_is_idempotent(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session:
|
||||||
|
row = JobRepository(session).approve(jid, reviewed_by="reviewer-a")
|
||||||
|
first_at = row.reviewed_at
|
||||||
|
assert first_at is not None
|
||||||
|
with session_scope() as session:
|
||||||
|
row = JobRepository(session).approve(jid, reviewed_by="reviewer-b")
|
||||||
|
# Second call must NOT overwrite reviewed_by or reviewed_at.
|
||||||
|
# SQLite drops tzinfo on roundtrip, so compare the naive components.
|
||||||
|
assert row.approved is True
|
||||||
|
assert row.reviewed_by == "reviewer-a"
|
||||||
|
assert row.reviewed_at is not None
|
||||||
|
assert row.reviewed_at.replace(tzinfo=None) == first_at.replace(tzinfo=None)
|
||||||
|
|
||||||
|
|
||||||
|
def test_approve_rejects_pending_job(db_ready: None) -> None:
|
||||||
|
jid = uuid4()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).create(
|
||||||
|
job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k"
|
||||||
|
)
|
||||||
|
with session_scope() as session, pytest.raises(JobNotCompletedError):
|
||||||
|
JobRepository(session).approve(jid, reviewed_by="rev")
|
||||||
|
|
||||||
|
|
||||||
|
def test_history_returns_events_in_order(db_ready: None) -> None:
|
||||||
|
jid = _seed_completed_job()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[("header.perihal", "one", None)],
|
||||||
|
corrected_by="r1",
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).apply_corrections(
|
||||||
|
jid,
|
||||||
|
corrections=[
|
||||||
|
("header.perihal", "two", None),
|
||||||
|
("personel[0].nama", "ANDI", None),
|
||||||
|
],
|
||||||
|
corrected_by="r2",
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
events = JobRepository(session).list_corrections(jid)
|
||||||
|
assert [e.new_value for e in events] == ["one", "two", "ANDI"]
|
||||||
|
assert [e.corrected_by for e in events] == ["r1", "r2", "r2"]
|
||||||
108
tests/unit/test_llm_client.py
Normal file
108
tests/unit/test_llm_client.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Unit tests for the Ollama HTTP client wrapper.
|
||||||
|
|
||||||
|
We swap ``httpx.Client`` inside ``ocr_sprint.llm.client`` for a builder that
|
||||||
|
returns a real ``httpx.Client`` wrapping a ``MockTransport``. Capturing the
|
||||||
|
original constructor *before* patching avoids infinite recursion in the
|
||||||
|
patched callable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import ocr_sprint.llm.client as llm_client_module
|
||||||
|
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||||
|
|
||||||
|
|
||||||
|
class _Schema(BaseModel):
|
||||||
|
foo: str
|
||||||
|
bar: int
|
||||||
|
|
||||||
|
|
||||||
|
def _ollama_envelope(content: str) -> dict[str, object]:
|
||||||
|
"""Mimic the shape Ollama's /api/chat returns."""
|
||||||
|
return {"message": {"role": "assistant", "content": content}, "done": True}
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_transport(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
handler: Any,
|
||||||
|
) -> None:
|
||||||
|
transport = httpx.MockTransport(handler)
|
||||||
|
real_client = httpx.Client # capture before patching
|
||||||
|
|
||||||
|
def _factory(*_args: object, **kwargs: object) -> httpx.Client:
|
||||||
|
# Strip any caller-provided transport kwarg; we always inject ours.
|
||||||
|
kwargs.pop("transport", None)
|
||||||
|
return real_client(transport=transport, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(llm_client_module.httpx, "Client", _factory)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_json_returns_validated_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
|
def _handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
captured["url"] = str(request.url)
|
||||||
|
captured["body"] = request.read()
|
||||||
|
return httpx.Response(200, json=_ollama_envelope('{"foo": "x", "bar": 7}'))
|
||||||
|
|
||||||
|
_patch_transport(monkeypatch, _handler)
|
||||||
|
|
||||||
|
client = OllamaClient(base_url="http://ollama:11434", model="m", timeout_s=5)
|
||||||
|
out = client.chat_json("system msg", "user msg", _Schema)
|
||||||
|
|
||||||
|
assert out == _Schema(foo="x", bar=7)
|
||||||
|
assert captured["url"] == "http://ollama:11434/api/chat"
|
||||||
|
body = captured["body"]
|
||||||
|
assert isinstance(body, bytes)
|
||||||
|
assert b'"format":"json"' in body
|
||||||
|
assert b'"system msg"' in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_json_raises_on_http_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(500, text="boom")
|
||||||
|
|
||||||
|
_patch_transport(monkeypatch, _handler)
|
||||||
|
|
||||||
|
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||||
|
with pytest.raises(LLMUnavailableError, match="Ollama request failed"):
|
||||||
|
client.chat_json("s", "u", _Schema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_json_raises_on_invalid_json(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200, json=_ollama_envelope("this is not json"))
|
||||||
|
|
||||||
|
_patch_transport(monkeypatch, _handler)
|
||||||
|
|
||||||
|
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||||
|
with pytest.raises(LLMUnavailableError, match="schema"):
|
||||||
|
client.chat_json("s", "u", _Schema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_json_raises_on_missing_envelope(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||||
|
return httpx.Response(200, json={"oops": True})
|
||||||
|
|
||||||
|
_patch_transport(monkeypatch, _handler)
|
||||||
|
|
||||||
|
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||||
|
with pytest.raises(LLMUnavailableError, match=r"message\.content"):
|
||||||
|
client.chat_json("s", "u", _Schema)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_json_raises_on_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
def _handler(request: httpx.Request) -> httpx.Response:
|
||||||
|
raise httpx.ConnectError("nobody home", request=request)
|
||||||
|
|
||||||
|
_patch_transport(monkeypatch, _handler)
|
||||||
|
|
||||||
|
client = OllamaClient(base_url="http://x", model="m", timeout_s=1)
|
||||||
|
with pytest.raises(LLMUnavailableError):
|
||||||
|
client.chat_json("s", "u", _Schema)
|
||||||
90
tests/unit/test_llm_extractor.py
Normal file
90
tests/unit/test_llm_extractor.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Unit tests for the hybrid LLM header extractor / merger."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||||
|
from ocr_sprint.llm.extractor import LLMHeaderResult, _merge, llm_fill_header
|
||||||
|
from ocr_sprint.schemas.extraction import HeaderFields
|
||||||
|
|
||||||
|
|
||||||
|
class _StubClient(OllamaClient):
|
||||||
|
"""Test double that bypasses HTTP entirely."""
|
||||||
|
|
||||||
|
def __init__(self, payload: LLMHeaderResult | Exception) -> None:
|
||||||
|
# Skip the real __init__ — we don't need any real config.
|
||||||
|
self._payload = payload
|
||||||
|
|
||||||
|
def chat_json( # type: ignore[override]
|
||||||
|
self, system: str, user: str, schema_cls: type[BaseModel]
|
||||||
|
) -> BaseModel:
|
||||||
|
if isinstance(self._payload, Exception):
|
||||||
|
raise self._payload
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_keeps_regex_when_present() -> None:
|
||||||
|
regex = HeaderFields(nomor_sprint="Sprin/123/IV/2025/Reskrim", tanggal=date(2025, 4, 21))
|
||||||
|
llm = LLMHeaderResult(nomor_sprint="HALLUCINATED", tanggal=date(1999, 1, 1), perihal="ok")
|
||||||
|
out = _merge(regex, llm)
|
||||||
|
assert out.nomor_sprint == "Sprin/123/IV/2025/Reskrim"
|
||||||
|
assert out.tanggal == date(2025, 4, 21)
|
||||||
|
# Gaps get filled.
|
||||||
|
assert out.perihal == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_fills_gaps() -> None:
|
||||||
|
regex = HeaderFields() # all None
|
||||||
|
llm = LLMHeaderResult(
|
||||||
|
nomor_sprint="Sprin/9/IX/2024",
|
||||||
|
tanggal=date(2024, 9, 1),
|
||||||
|
satuan_penerbit="Polres Bandung",
|
||||||
|
perihal="Penyelidikan",
|
||||||
|
dasar=["UU 2/2002", "Perkap 6/2017"],
|
||||||
|
)
|
||||||
|
out = _merge(regex, llm)
|
||||||
|
assert out.nomor_sprint == "Sprin/9/IX/2024"
|
||||||
|
assert out.tanggal == date(2024, 9, 1)
|
||||||
|
assert out.satuan_penerbit == "Polres Bandung"
|
||||||
|
assert out.perihal == "Penyelidikan"
|
||||||
|
assert out.dasar == ["UU 2/2002", "Perkap 6/2017"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_fill_header_returns_merged_when_client_succeeds() -> None:
|
||||||
|
regex = HeaderFields(nomor_sprint="Sprin/1/I/2025") # has nomor, missing rest
|
||||||
|
stub = _StubClient(
|
||||||
|
LLMHeaderResult(
|
||||||
|
satuan_penerbit="Polres Bandung",
|
||||||
|
perihal="Penyelidikan",
|
||||||
|
dasar=["UU 2/2002"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
out = llm_fill_header(raw_text="...", regex_header=regex, client=stub)
|
||||||
|
assert out is not None
|
||||||
|
assert out.nomor_sprint == "Sprin/1/I/2025"
|
||||||
|
assert out.satuan_penerbit == "Polres Bandung"
|
||||||
|
assert out.perihal == "Penyelidikan"
|
||||||
|
assert out.dasar == ["UU 2/2002"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_fill_header_returns_none_when_unavailable() -> None:
|
||||||
|
stub = _StubClient(LLMUnavailableError("server down"))
|
||||||
|
out = llm_fill_header(raw_text="...", regex_header=HeaderFields(), client=stub)
|
||||||
|
assert out is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_does_not_overwrite_dasar_when_regex_has_it() -> None:
|
||||||
|
regex = HeaderFields(dasar=["UU 2/2002"])
|
||||||
|
llm = LLMHeaderResult(dasar=["something else", "more"])
|
||||||
|
out = _merge(regex, llm)
|
||||||
|
assert out.dasar == ["UU 2/2002"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_llm_extractor_unused_argument_kept_silent() -> None:
|
||||||
|
# A trivial sanity check that the public function signature accepts
|
||||||
|
# keyword-only `client` — this matches how the orchestrator calls it.
|
||||||
|
pytest.importorskip("ocr_sprint.llm.extractor")
|
||||||
171
tests/unit/test_orchestrator_llm.py
Normal file
171
tests/unit/test_orchestrator_llm.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Orchestrator-level tests for the Phase 5 hybrid LLM wiring.
|
||||||
|
|
||||||
|
These tests stub out the heavy stages (ingest / preprocess / OCR / table)
|
||||||
|
so we can verify the *branching* behaviour around the LLM step without
|
||||||
|
booting Paddle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||||
|
from ocr_sprint.pipeline.orchestrator import _header_has_gaps, run_pipeline
|
||||||
|
from ocr_sprint.schemas.document import SourceKind
|
||||||
|
from ocr_sprint.schemas.extraction import HeaderFields, ReviewFlag, Signatory
|
||||||
|
|
||||||
|
|
||||||
|
def test_header_has_gaps_detects_missing_fields() -> None:
|
||||||
|
full = HeaderFields(
|
||||||
|
nomor_sprint="Sprin/1/I/2025",
|
||||||
|
tanggal=date(2025, 1, 1),
|
||||||
|
satuan_penerbit="Polres X",
|
||||||
|
perihal="ok",
|
||||||
|
dasar=["UU 2/2002"],
|
||||||
|
)
|
||||||
|
assert _header_has_gaps(full) is False
|
||||||
|
|
||||||
|
assert _header_has_gaps(HeaderFields()) is True
|
||||||
|
assert _header_has_gaps(full.model_copy(update={"perihal": None})) is True
|
||||||
|
assert _header_has_gaps(full.model_copy(update={"dasar": []})) is True
|
||||||
|
|
||||||
|
|
||||||
|
def _stub_pipeline_stages(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
*,
|
||||||
|
raw_text: str,
|
||||||
|
regex_header: HeaderFields,
|
||||||
|
) -> None:
|
||||||
|
"""Replace ingest -> ocr -> tables with cheap fakes so the orchestrator
|
||||||
|
runs without Paddle / PyMuPDF.
|
||||||
|
"""
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ocr_sprint.pipeline import ingest as ingest_module
|
||||||
|
from ocr_sprint.pipeline import ocr as ocr_module
|
||||||
|
from ocr_sprint.pipeline.ingest import IngestedPage
|
||||||
|
|
||||||
|
img = np.full((100, 100, 3), 255, dtype=np.uint8)
|
||||||
|
fake_page = IngestedPage(image=img, page_index=0)
|
||||||
|
fake_ocr_page = ocr_module.OCRPage(
|
||||||
|
lines=[
|
||||||
|
ocr_module.OCRLine(text=raw_text, confidence=0.95, box=((0, 0), (1, 0), (1, 1), (0, 1)))
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(orch_module, "detect_source_kind", lambda _: SourceKind.PDF)
|
||||||
|
monkeypatch.setattr(orch_module, "ingest", lambda *a, **k: [fake_page])
|
||||||
|
monkeypatch.setattr(orch_module, "detect_and_correct", lambda image, _cfg: image)
|
||||||
|
monkeypatch.setattr(orch_module, "preprocess", lambda image, _cfg: image)
|
||||||
|
monkeypatch.setattr(orch_module, "run_ocr", lambda _image: fake_ocr_page)
|
||||||
|
# No tables in these tests.
|
||||||
|
monkeypatch.setattr(orch_module, "run_table_extraction", lambda _img: [])
|
||||||
|
monkeypatch.setattr(orch_module, "extract_personnel", lambda _tables: [])
|
||||||
|
# Header / signatory / validators come from the real implementation
|
||||||
|
# for `extract_header`, but we override to control gap state.
|
||||||
|
monkeypatch.setattr(orch_module, "extract_header", lambda _text: regex_header)
|
||||||
|
monkeypatch.setattr(orch_module, "find_signatory", lambda _text: Signatory())
|
||||||
|
monkeypatch.setattr(orch_module, "validate_extraction", lambda _result: [])
|
||||||
|
# Keep ingest_module referenced so import isn't dropped.
|
||||||
|
assert ingest_module is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_skips_llm_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setenv("LLM_ENABLED", "false")
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
_stub_pipeline_stages(
|
||||||
|
monkeypatch,
|
||||||
|
raw_text="dummy",
|
||||||
|
regex_header=HeaderFields(), # all gaps
|
||||||
|
)
|
||||||
|
|
||||||
|
called = {"n": 0}
|
||||||
|
|
||||||
|
def _trip(*_args: object, **_kwargs: object) -> None:
|
||||||
|
called["n"] += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
|
||||||
|
|
||||||
|
result = run_pipeline(b"%PDF-1.4\n%fake")
|
||||||
|
assert called["n"] == 0
|
||||||
|
assert ReviewFlag.LLM_FALLBACK not in result.result.review_flags
|
||||||
|
assert ReviewFlag.LLM_UNAVAILABLE not in result.result.review_flags
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_skips_llm_when_header_complete(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
_stub_pipeline_stages(
|
||||||
|
monkeypatch,
|
||||||
|
raw_text="dummy",
|
||||||
|
regex_header=HeaderFields(
|
||||||
|
nomor_sprint="Sprin/1/I/2025",
|
||||||
|
tanggal=date(2025, 1, 1),
|
||||||
|
satuan_penerbit="Polres X",
|
||||||
|
perihal="ok",
|
||||||
|
dasar=["UU 2/2002"],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
called = {"n": 0}
|
||||||
|
|
||||||
|
def _trip(*_args: object, **_kwargs: object) -> None:
|
||||||
|
called["n"] += 1
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
|
||||||
|
|
||||||
|
run_pipeline(b"%PDF-1.4\n%fake")
|
||||||
|
assert called["n"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_calls_llm_and_marks_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
regex_partial = HeaderFields(nomor_sprint="Sprin/1/I/2025") # rest missing
|
||||||
|
_stub_pipeline_stages(monkeypatch, raw_text="dummy text", regex_header=regex_partial)
|
||||||
|
|
||||||
|
def _llm(_raw: str, header: HeaderFields, **_: object) -> HeaderFields:
|
||||||
|
return header.model_copy(
|
||||||
|
update={
|
||||||
|
"satuan_penerbit": "Polres Bandung",
|
||||||
|
"perihal": "Penyelidikan",
|
||||||
|
"dasar": ["UU 2/2002"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setattr(orch_module, "llm_fill_header", _llm)
|
||||||
|
|
||||||
|
out = run_pipeline(b"%PDF-1.4\n%fake")
|
||||||
|
assert out.result.header.satuan_penerbit == "Polres Bandung"
|
||||||
|
assert out.result.header.perihal == "Penyelidikan"
|
||||||
|
assert ReviewFlag.LLM_FALLBACK in out.result.review_flags
|
||||||
|
assert ReviewFlag.LLM_UNAVAILABLE not in out.result.review_flags
|
||||||
|
|
||||||
|
|
||||||
|
def test_orchestrator_marks_unavailable_when_llm_returns_none(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
_stub_pipeline_stages(monkeypatch, raw_text="dummy", regex_header=HeaderFields())
|
||||||
|
monkeypatch.setattr(orch_module, "llm_fill_header", lambda *_a, **_k: None)
|
||||||
|
|
||||||
|
out = run_pipeline(b"%PDF-1.4\n%fake")
|
||||||
|
assert ReviewFlag.LLM_UNAVAILABLE in out.result.review_flags
|
||||||
|
assert ReviewFlag.LLM_FALLBACK not in out.result.review_flags
|
||||||
Reference in New Issue
Block a user