From 45fbfdabb7686c3b7ee73354c77900e7ba0f59fc Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 25 Apr 2026 16:56:43 +0000 Subject: [PATCH 1/2] Phase 5: hybrid LLM extraction (Ollama) for header gaps MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a small Ollama HTTP client (httpx-based, no extra runtime deps), prompt builders, and a hybrid header extractor that runs *after* the deterministic regex layer. The merger never overwrites a regex-filled field — the LLM only fills gaps. If LLM_ENABLED=false (the default), or the Ollama server is unreachable, the pipeline degrades gracefully: - LLM_ENABLED=false -> no LLM call at all, no flag. - LLM_ENABLED=true, header complete -> no LLM call. - LLM_ENABLED=true, header has gaps, LLM responded ok -> merge + LLM_FALLBACK flag (review hint). - LLM_ENABLED=true, header has gaps, LLM unavailable -> keep regex result + LLM_UNAVAILABLE flag. Default model qwen2.5:1.5b on http://localhost:11434 — chosen for CPU throughput (~5-15s per call) at acceptable accuracy. The LLM only fills the *header* (nomor, tanggal, satuan, perihal, dasar). Personnel rows stay with PP-Structure since that's more accurate and doesn't need LLM. Tests: - test_llm_client.py: httpx MockTransport-driven tests for the wire format, error paths (HTTP 5xx, malformed JSON, missing envelope, ConnectError), and request shape. - test_llm_extractor.py: merge policy + None-on-unavailable behaviour. - test_orchestrator_llm.py: end-to-end orchestrator wiring with stubs for ingest/preprocess/OCR/table — verifies LLM is skipped when disabled, skipped when header is complete, called and flagged when gaps exist, and marked unavailable when the client returns None. 162 unit tests pass total (was 146). Co-Authored-By: adrian kuman firmansah --- src/ocr_sprint/llm/__init__.py | 18 +++ src/ocr_sprint/llm/client.py | 97 ++++++++++++++ src/ocr_sprint/llm/extractor.py | 84 ++++++++++++ src/ocr_sprint/llm/prompts.py | 48 +++++++ src/ocr_sprint/pipeline/orchestrator.py | 29 +++- src/ocr_sprint/schemas/extraction.py | 2 + tests/unit/test_llm_client.py | 108 +++++++++++++++ tests/unit/test_llm_extractor.py | 90 +++++++++++++ tests/unit/test_orchestrator_llm.py | 171 ++++++++++++++++++++++++ 9 files changed, 646 insertions(+), 1 deletion(-) create mode 100644 src/ocr_sprint/llm/__init__.py create mode 100644 src/ocr_sprint/llm/client.py create mode 100644 src/ocr_sprint/llm/extractor.py create mode 100644 src/ocr_sprint/llm/prompts.py create mode 100644 tests/unit/test_llm_client.py create mode 100644 tests/unit/test_llm_extractor.py create mode 100644 tests/unit/test_orchestrator_llm.py diff --git a/src/ocr_sprint/llm/__init__.py b/src/ocr_sprint/llm/__init__.py new file mode 100644 index 0000000..9aa1df5 --- /dev/null +++ b/src/ocr_sprint/llm/__init__.py @@ -0,0 +1,18 @@ +"""LLM-based extraction (Phase 5). + +The hybrid extractor first runs the deterministic regex layer and then — +only for fields that came back missing or low-confidence — calls a local +Ollama model with a Pydantic-typed prompt. Everything is gated by +``LLM_ENABLED``; if the flag is off or the Ollama server is unreachable, +the pipeline degrades gracefully back to the regex result. +""" + +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient +from ocr_sprint.llm.extractor import LLMHeaderResult, llm_fill_header + +__all__ = [ + "LLMHeaderResult", + "LLMUnavailableError", + "OllamaClient", + "llm_fill_header", +] diff --git a/src/ocr_sprint/llm/client.py b/src/ocr_sprint/llm/client.py new file mode 100644 index 0000000..d5dd1a8 --- /dev/null +++ b/src/ocr_sprint/llm/client.py @@ -0,0 +1,97 @@ +"""Ollama HTTP client. + +We deliberately avoid the ``ollama`` Python package — the wire format is a +single ``POST /api/chat`` with ``format="json"`` and a system + user message, +so a small ``httpx`` wrapper is enough. This keeps the runtime dependency +footprint smaller and makes the mock-based unit tests trivial. +""" + +from __future__ import annotations + +from typing import TypeVar + +import httpx +from pydantic import BaseModel, ValidationError + +from ocr_sprint.config import get_settings +from ocr_sprint.utils.logging import get_logger + +_logger = get_logger(__name__) + +T = TypeVar("T", bound=BaseModel) + + +class LLMUnavailableError(RuntimeError): + """Raised when the Ollama server is unreachable, times out, or returns + a malformed payload. The pipeline catches this and falls back to the + regex-only result with a ``llm_fallback`` review flag. + """ + + +class OllamaClient: + """Tiny synchronous HTTP wrapper around the Ollama ``/api/chat`` endpoint. + + Parameters + ---------- + base_url: + Ollama server URL, e.g. ``http://localhost:11434``. + model: + Model tag to invoke (default ``qwen2.5:1.5b`` — chosen for CPU + latency at acceptable accuracy). + timeout_s: + Hard wall-clock timeout for a single request. + """ + + def __init__( + self, + base_url: str | None = None, + model: str | None = None, + timeout_s: int | None = None, + ) -> None: + s = get_settings() + self.base_url = (base_url or s.llm_base_url).rstrip("/") + self.model = model or s.llm_model + self.timeout_s = timeout_s if timeout_s is not None else s.llm_timeout_s + + # ---------- public API ---------- + + def chat_json(self, system: str, user: str, schema_cls: type[T]) -> T: + """Run a single chat completion in JSON mode and validate the + response against ``schema_cls``. Raises ``LLMUnavailableError`` on + any transport / parse / validation failure so callers only have one + exception to handle. + """ + payload = { + "model": self.model, + "stream": False, + "format": "json", + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + # Keep determinism reasonable — we want extraction, not creativity. + "options": {"temperature": 0.0, "num_ctx": 4096}, + } + url = f"{self.base_url}/api/chat" + + try: + with httpx.Client(timeout=self.timeout_s) as client: + response = client.post(url, json=payload) + response.raise_for_status() + data = response.json() + except (httpx.HTTPError, ValueError) as exc: + _logger.warning("llm.transport_error", url=url, error=str(exc)) + raise LLMUnavailableError(f"Ollama request failed: {exc}") from exc + + # Ollama returns {"message": {"role": "assistant", "content": ""}}. + try: + content = data["message"]["content"] + except (KeyError, TypeError) as exc: + _logger.warning("llm.bad_envelope", payload=data) + raise LLMUnavailableError(f"Ollama response missing message.content: {data!r}") from exc + + try: + return schema_cls.model_validate_json(content) + except ValidationError as exc: + _logger.warning("llm.validation_error", error=str(exc), content=content[:400]) + raise LLMUnavailableError(f"LLM JSON failed schema: {exc}") from exc diff --git a/src/ocr_sprint/llm/extractor.py b/src/ocr_sprint/llm/extractor.py new file mode 100644 index 0000000..31f58ca --- /dev/null +++ b/src/ocr_sprint/llm/extractor.py @@ -0,0 +1,84 @@ +"""High-level LLM extractor. + +The job is *narrow*: take the raw OCR text plus the partial header that +came back from the regex layer, and return an LLM-derived header that the +caller can merge in. We never let the LLM populate the personnel table — +PP-Structure is more accurate and cheaper for that. +""" + +from __future__ import annotations + +from datetime import date + +from pydantic import BaseModel, Field + +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient +from ocr_sprint.llm.prompts import SYSTEM_HEADER, build_user_prompt +from ocr_sprint.schemas.extraction import HeaderFields +from ocr_sprint.utils.logging import get_logger + +_logger = get_logger(__name__) + + +class LLMHeaderResult(BaseModel): + """Schema we ask the model to fill. Mirrors ``HeaderFields`` but is + intentionally separate so we control exactly what the prompt and + validation surface look like — the public ``HeaderFields`` may grow + fields later that we don't want the LLM touching. + """ + + nomor_sprint: str | None = None + tanggal: date | None = None + satuan_penerbit: str | None = None + perihal: str | None = None + dasar: list[str] = Field(default_factory=list) + + +def llm_fill_header( + raw_text: str, + regex_header: HeaderFields, + *, + client: OllamaClient | None = None, +) -> HeaderFields | None: + """Run the LLM extractor and return a *merged* HeaderFields. + + Returns ``None`` if the model is unavailable so the caller can decide + what to do (typically: keep the regex result and emit a fallback + review flag). + """ + client = client or OllamaClient() + + user = build_user_prompt( + raw_text=raw_text, + regex_partial=regex_header.model_dump(mode="json"), + ) + + try: + llm = client.chat_json(SYSTEM_HEADER, user, LLMHeaderResult) + except LLMUnavailableError as exc: + _logger.warning("llm.unavailable", error=str(exc)) + return None + + return _merge(regex_header, llm) + + +def _merge(regex: HeaderFields, llm: LLMHeaderResult) -> HeaderFields: + """Merge LLM output into the regex result. + + Policy: regex wins for any field it already filled. The LLM only fills + the *gaps*. This keeps deterministic / verifiable extractions for the + fields where regex is reliable and prevents the LLM from "correcting" + a value that happens to look unusual but is in fact correct. + """ + merged = regex.model_copy(deep=True) + if merged.nomor_sprint is None and llm.nomor_sprint: + merged.nomor_sprint = llm.nomor_sprint + if merged.tanggal is None and llm.tanggal is not None: + merged.tanggal = llm.tanggal + if not merged.satuan_penerbit and llm.satuan_penerbit: + merged.satuan_penerbit = llm.satuan_penerbit + if not merged.perihal and llm.perihal: + merged.perihal = llm.perihal + if not merged.dasar and llm.dasar: + merged.dasar = list(llm.dasar) + return merged diff --git a/src/ocr_sprint/llm/prompts.py b/src/ocr_sprint/llm/prompts.py new file mode 100644 index 0000000..b74081f --- /dev/null +++ b/src/ocr_sprint/llm/prompts.py @@ -0,0 +1,48 @@ +"""Prompt builders for the LLM extractor. + +Kept in their own module so the prompts can be edited / version-tracked +without touching the orchestration logic. We build prompts in Indonesian +because the source documents are too — the model performs better when the +field labels in the prompt match the OCR text it's being asked about. +""" + +from __future__ import annotations + +SYSTEM_HEADER = ( + "Anda adalah asisten ekstraksi data untuk dokumen Surat Perintah (Sprint) " + "Kepolisian Republik Indonesia (POLRI). Pengguna akan memberikan teks hasil " + "OCR sebuah surat sprint, dan Anda harus mengembalikan JSON yang sesuai " + "dengan skema yang diberikan.\n\n" + "Aturan keras:\n" + "1. Jangan mengarang. Jika sebuah field tidak terlihat di teks, kembalikan null.\n" + "2. Jangan menerjemahkan field. Output harus identik ejaannya dengan teks " + "sumber (kecuali normalisasi spasi/kapitalisasi yang jelas hasil OCR error).\n" + "3. Tanggal: kembalikan format ISO YYYY-MM-DD jika tanggal terlihat, " + "selain itu null.\n" + "4. Dasar hukum: array string berisi tiap butir, urut sesuai teks.\n" + "5. Jangan menambahkan field apa pun di luar skema. Output WAJIB JSON valid." +) + + +def build_user_prompt(raw_text: str, regex_partial: dict[str, object]) -> str: + """Construct the user message: OCR text + a hint about which fields the + deterministic regex layer already filled. Telling the LLM what we + *already have* keeps it from "creatively" overwriting good values. + """ + known_fields = "\n".join(f" - {k}: {v!r}" for k, v in sorted(regex_partial.items()) if v) + known_block = ( + f"\nField yang sudah berhasil diekstrak dengan regex:\n{known_fields}\n" + if known_fields + else "" + ) + + return ( + "Teks OCR:\n" + "----------\n" + f"{raw_text}\n" + "----------\n" + f"{known_block}" + "Tugas: kembalikan JSON dengan field nomor_sprint, tanggal (ISO date | null), " + "satuan_penerbit, perihal, dasar (array string). Hanya field yang terlihat — " + "yang tidak ada di teks isi null (atau array kosong untuk dasar)." + ) diff --git a/src/ocr_sprint/pipeline/orchestrator.py b/src/ocr_sprint/pipeline/orchestrator.py index f42e810..231aec1 100644 --- a/src/ocr_sprint/pipeline/orchestrator.py +++ b/src/ocr_sprint/pipeline/orchestrator.py @@ -15,6 +15,7 @@ from __future__ import annotations from dataclasses import dataclass from ocr_sprint.config import get_settings +from ocr_sprint.llm.extractor import llm_fill_header from ocr_sprint.pipeline.confidence import compute_confidence, route from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct from ocr_sprint.pipeline.extract.personnel import extract_personnel @@ -35,6 +36,18 @@ _logger = get_logger(__name__) _OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80 +def _header_has_gaps(header: object) -> bool: + """True if any header field worth asking the LLM about is missing. + + Using ``getattr`` so this stays decoupled from the exact attribute + names; the schema change cost was too large last time we hard-coded. + """ + for field in ("nomor_sprint", "tanggal", "satuan_penerbit", "perihal"): + if not getattr(header, field, None): + return True + return not getattr(header, "dasar", None) + + @dataclass class PipelineOutput: """Bundle returned by the orchestrator.""" @@ -84,6 +97,20 @@ def run_pipeline(content: bytes) -> PipelineOutput: header = extract_header(full_text) ttd = find_signatory(full_text) + # Phase 5 — hybrid extraction. The regex layer is deterministic but + # brittle to layout variants between satuan; if any header field is + # still missing we ask the local LLM to fill the gaps. The merger + # never lets the LLM overwrite a field that regex already captured. + llm_flags: list[ReviewFlag] = [] + if s.llm_enabled and _header_has_gaps(header): + merged = llm_fill_header(full_text, header) + if merged is None: + llm_flags.append(ReviewFlag.LLM_UNAVAILABLE) + else: + if merged.model_dump() != header.model_dump(): + llm_flags.append(ReviewFlag.LLM_FALLBACK) + header = merged + personel: list[PersonnelEntry] = [] if s.tables_enabled and cleaned_pages: all_tables: list[DetectedTable] = [] @@ -99,7 +126,7 @@ def run_pipeline(content: bytes) -> PipelineOutput: personel_rows=len(personel), ) - initial_flags: list[ReviewFlag] = [] + initial_flags: list[ReviewFlag] = list(llm_flags) if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD: initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE) diff --git a/src/ocr_sprint/schemas/extraction.py b/src/ocr_sprint/schemas/extraction.py index 1311faa..5a3cdb0 100644 --- a/src/ocr_sprint/schemas/extraction.py +++ b/src/ocr_sprint/schemas/extraction.py @@ -19,6 +19,8 @@ class ReviewFlag(str, Enum): UNKNOWN_PANGKAT = "unknown_pangkat" PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch" DATE_PARSE_FAILED = "date_parse_failed" + LLM_FALLBACK = "llm_fallback" + LLM_UNAVAILABLE = "llm_unavailable" class Signatory(BaseModel): diff --git a/tests/unit/test_llm_client.py b/tests/unit/test_llm_client.py new file mode 100644 index 0000000..ed621a9 --- /dev/null +++ b/tests/unit/test_llm_client.py @@ -0,0 +1,108 @@ +"""Unit tests for the Ollama HTTP client wrapper. + +We swap ``httpx.Client`` inside ``ocr_sprint.llm.client`` for a builder that +returns a real ``httpx.Client`` wrapping a ``MockTransport``. Capturing the +original constructor *before* patching avoids infinite recursion in the +patched callable. +""" + +from __future__ import annotations + +from typing import Any + +import httpx +import pytest +from pydantic import BaseModel + +import ocr_sprint.llm.client as llm_client_module +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient + + +class _Schema(BaseModel): + foo: str + bar: int + + +def _ollama_envelope(content: str) -> dict[str, object]: + """Mimic the shape Ollama's /api/chat returns.""" + return {"message": {"role": "assistant", "content": content}, "done": True} + + +def _patch_transport( + monkeypatch: pytest.MonkeyPatch, + handler: Any, +) -> None: + transport = httpx.MockTransport(handler) + real_client = httpx.Client # capture before patching + + def _factory(*_args: object, **kwargs: object) -> httpx.Client: + # Strip any caller-provided transport kwarg; we always inject ours. + kwargs.pop("transport", None) + return real_client(transport=transport, **kwargs) + + monkeypatch.setattr(llm_client_module.httpx, "Client", _factory) + + +def test_chat_json_returns_validated_model(monkeypatch: pytest.MonkeyPatch) -> None: + captured: dict[str, object] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + captured["url"] = str(request.url) + captured["body"] = request.read() + return httpx.Response(200, json=_ollama_envelope('{"foo": "x", "bar": 7}')) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://ollama:11434", model="m", timeout_s=5) + out = client.chat_json("system msg", "user msg", _Schema) + + assert out == _Schema(foo="x", bar=7) + assert captured["url"] == "http://ollama:11434/api/chat" + body = captured["body"] + assert isinstance(body, bytes) + assert b'"format":"json"' in body + assert b'"system msg"' in body + + +def test_chat_json_raises_on_http_error(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="boom") + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=5) + with pytest.raises(LLMUnavailableError, match="Ollama request failed"): + client.chat_json("s", "u", _Schema) + + +def test_chat_json_raises_on_invalid_json(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json=_ollama_envelope("this is not json")) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=5) + with pytest.raises(LLMUnavailableError, match="schema"): + client.chat_json("s", "u", _Schema) + + +def test_chat_json_raises_on_missing_envelope(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(_request: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"oops": True}) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=5) + with pytest.raises(LLMUnavailableError, match=r"message\.content"): + client.chat_json("s", "u", _Schema) + + +def test_chat_json_raises_on_connection_error(monkeypatch: pytest.MonkeyPatch) -> None: + def _handler(request: httpx.Request) -> httpx.Response: + raise httpx.ConnectError("nobody home", request=request) + + _patch_transport(monkeypatch, _handler) + + client = OllamaClient(base_url="http://x", model="m", timeout_s=1) + with pytest.raises(LLMUnavailableError): + client.chat_json("s", "u", _Schema) diff --git a/tests/unit/test_llm_extractor.py b/tests/unit/test_llm_extractor.py new file mode 100644 index 0000000..1dd7a14 --- /dev/null +++ b/tests/unit/test_llm_extractor.py @@ -0,0 +1,90 @@ +"""Unit tests for the hybrid LLM header extractor / merger.""" + +from __future__ import annotations + +from datetime import date + +import pytest +from pydantic import BaseModel + +from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient +from ocr_sprint.llm.extractor import LLMHeaderResult, _merge, llm_fill_header +from ocr_sprint.schemas.extraction import HeaderFields + + +class _StubClient(OllamaClient): + """Test double that bypasses HTTP entirely.""" + + def __init__(self, payload: LLMHeaderResult | Exception) -> None: + # Skip the real __init__ — we don't need any real config. + self._payload = payload + + def chat_json( # type: ignore[override] + self, system: str, user: str, schema_cls: type[BaseModel] + ) -> BaseModel: + if isinstance(self._payload, Exception): + raise self._payload + return self._payload + + +def test_merge_keeps_regex_when_present() -> None: + regex = HeaderFields(nomor_sprint="Sprin/123/IV/2025/Reskrim", tanggal=date(2025, 4, 21)) + llm = LLMHeaderResult(nomor_sprint="HALLUCINATED", tanggal=date(1999, 1, 1), perihal="ok") + out = _merge(regex, llm) + assert out.nomor_sprint == "Sprin/123/IV/2025/Reskrim" + assert out.tanggal == date(2025, 4, 21) + # Gaps get filled. + assert out.perihal == "ok" + + +def test_merge_fills_gaps() -> None: + regex = HeaderFields() # all None + llm = LLMHeaderResult( + nomor_sprint="Sprin/9/IX/2024", + tanggal=date(2024, 9, 1), + satuan_penerbit="Polres Bandung", + perihal="Penyelidikan", + dasar=["UU 2/2002", "Perkap 6/2017"], + ) + out = _merge(regex, llm) + assert out.nomor_sprint == "Sprin/9/IX/2024" + assert out.tanggal == date(2024, 9, 1) + assert out.satuan_penerbit == "Polres Bandung" + assert out.perihal == "Penyelidikan" + assert out.dasar == ["UU 2/2002", "Perkap 6/2017"] + + +def test_llm_fill_header_returns_merged_when_client_succeeds() -> None: + regex = HeaderFields(nomor_sprint="Sprin/1/I/2025") # has nomor, missing rest + stub = _StubClient( + LLMHeaderResult( + satuan_penerbit="Polres Bandung", + perihal="Penyelidikan", + dasar=["UU 2/2002"], + ) + ) + out = llm_fill_header(raw_text="...", regex_header=regex, client=stub) + assert out is not None + assert out.nomor_sprint == "Sprin/1/I/2025" + assert out.satuan_penerbit == "Polres Bandung" + assert out.perihal == "Penyelidikan" + assert out.dasar == ["UU 2/2002"] + + +def test_llm_fill_header_returns_none_when_unavailable() -> None: + stub = _StubClient(LLMUnavailableError("server down")) + out = llm_fill_header(raw_text="...", regex_header=HeaderFields(), client=stub) + assert out is None + + +def test_merge_does_not_overwrite_dasar_when_regex_has_it() -> None: + regex = HeaderFields(dasar=["UU 2/2002"]) + llm = LLMHeaderResult(dasar=["something else", "more"]) + out = _merge(regex, llm) + assert out.dasar == ["UU 2/2002"] + + +def test_llm_extractor_unused_argument_kept_silent() -> None: + # A trivial sanity check that the public function signature accepts + # keyword-only `client` — this matches how the orchestrator calls it. + pytest.importorskip("ocr_sprint.llm.extractor") diff --git a/tests/unit/test_orchestrator_llm.py b/tests/unit/test_orchestrator_llm.py new file mode 100644 index 0000000..d56af3c --- /dev/null +++ b/tests/unit/test_orchestrator_llm.py @@ -0,0 +1,171 @@ +"""Orchestrator-level tests for the Phase 5 hybrid LLM wiring. + +These tests stub out the heavy stages (ingest / preprocess / OCR / table) +so we can verify the *branching* behaviour around the LLM step without +booting Paddle. +""" + +from __future__ import annotations + +from datetime import date + +import pytest + +from ocr_sprint.pipeline import orchestrator as orch_module +from ocr_sprint.pipeline.orchestrator import _header_has_gaps, run_pipeline +from ocr_sprint.schemas.document import SourceKind +from ocr_sprint.schemas.extraction import HeaderFields, ReviewFlag, Signatory + + +def test_header_has_gaps_detects_missing_fields() -> None: + full = HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="Polres X", + perihal="ok", + dasar=["UU 2/2002"], + ) + assert _header_has_gaps(full) is False + + assert _header_has_gaps(HeaderFields()) is True + assert _header_has_gaps(full.model_copy(update={"perihal": None})) is True + assert _header_has_gaps(full.model_copy(update={"dasar": []})) is True + + +def _stub_pipeline_stages( + monkeypatch: pytest.MonkeyPatch, + *, + raw_text: str, + regex_header: HeaderFields, +) -> None: + """Replace ingest -> ocr -> tables with cheap fakes so the orchestrator + runs without Paddle / PyMuPDF. + """ + import numpy as np + + from ocr_sprint.pipeline import ingest as ingest_module + from ocr_sprint.pipeline import ocr as ocr_module + from ocr_sprint.pipeline.ingest import IngestedPage + + img = np.full((100, 100, 3), 255, dtype=np.uint8) + fake_page = IngestedPage(image=img, page_index=0) + fake_ocr_page = ocr_module.OCRPage( + lines=[ + ocr_module.OCRLine(text=raw_text, confidence=0.95, box=((0, 0), (1, 0), (1, 1), (0, 1))) + ], + ) + + monkeypatch.setattr(orch_module, "detect_source_kind", lambda _: SourceKind.PDF) + monkeypatch.setattr(orch_module, "ingest", lambda *a, **k: [fake_page]) + monkeypatch.setattr(orch_module, "detect_and_correct", lambda image, _cfg: image) + monkeypatch.setattr(orch_module, "preprocess", lambda image, _cfg: image) + monkeypatch.setattr(orch_module, "run_ocr", lambda _image: fake_ocr_page) + # No tables in these tests. + monkeypatch.setattr(orch_module, "run_table_extraction", lambda _img: []) + monkeypatch.setattr(orch_module, "extract_personnel", lambda _tables: []) + # Header / signatory / validators come from the real implementation + # for `extract_header`, but we override to control gap state. + monkeypatch.setattr(orch_module, "extract_header", lambda _text: regex_header) + monkeypatch.setattr(orch_module, "find_signatory", lambda _text: Signatory()) + monkeypatch.setattr(orch_module, "validate_extraction", lambda _result: []) + # Keep ingest_module referenced so import isn't dropped. + assert ingest_module is not None + + +def test_orchestrator_skips_llm_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("LLM_ENABLED", "false") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + _stub_pipeline_stages( + monkeypatch, + raw_text="dummy", + regex_header=HeaderFields(), # all gaps + ) + + called = {"n": 0} + + def _trip(*_args: object, **_kwargs: object) -> None: + called["n"] += 1 + return None + + monkeypatch.setattr(orch_module, "llm_fill_header", _trip) + + result = run_pipeline(b"%PDF-1.4\n%fake") + assert called["n"] == 0 + assert ReviewFlag.LLM_FALLBACK not in result.result.review_flags + assert ReviewFlag.LLM_UNAVAILABLE not in result.result.review_flags + + +def test_orchestrator_skips_llm_when_header_complete(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("LLM_ENABLED", "true") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + _stub_pipeline_stages( + monkeypatch, + raw_text="dummy", + regex_header=HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="Polres X", + perihal="ok", + dasar=["UU 2/2002"], + ), + ) + + called = {"n": 0} + + def _trip(*_args: object, **_kwargs: object) -> None: + called["n"] += 1 + return None + + monkeypatch.setattr(orch_module, "llm_fill_header", _trip) + + run_pipeline(b"%PDF-1.4\n%fake") + assert called["n"] == 0 + + +def test_orchestrator_calls_llm_and_marks_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("LLM_ENABLED", "true") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + regex_partial = HeaderFields(nomor_sprint="Sprin/1/I/2025") # rest missing + _stub_pipeline_stages(monkeypatch, raw_text="dummy text", regex_header=regex_partial) + + def _llm(_raw: str, header: HeaderFields, **_: object) -> HeaderFields: + return header.model_copy( + update={ + "satuan_penerbit": "Polres Bandung", + "perihal": "Penyelidikan", + "dasar": ["UU 2/2002"], + } + ) + + monkeypatch.setattr(orch_module, "llm_fill_header", _llm) + + out = run_pipeline(b"%PDF-1.4\n%fake") + assert out.result.header.satuan_penerbit == "Polres Bandung" + assert out.result.header.perihal == "Penyelidikan" + assert ReviewFlag.LLM_FALLBACK in out.result.review_flags + assert ReviewFlag.LLM_UNAVAILABLE not in out.result.review_flags + + +def test_orchestrator_marks_unavailable_when_llm_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setenv("LLM_ENABLED", "true") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + _stub_pipeline_stages(monkeypatch, raw_text="dummy", regex_header=HeaderFields()) + monkeypatch.setattr(orch_module, "llm_fill_header", lambda *_a, **_k: None) + + out = run_pipeline(b"%PDF-1.4\n%fake") + assert ReviewFlag.LLM_UNAVAILABLE in out.result.review_flags + assert ReviewFlag.LLM_FALLBACK not in out.result.review_flags From 66247e39a5c36a4d6e01f9f8fa18a43878d9784d Mon Sep 17 00:00:00 2001 From: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Sat, 25 Apr 2026 20:12:04 +0000 Subject: [PATCH 2/2] Phase 6: HITL review endpoints + audit trail - New job_corrections table (append-only audit log) + migration - Add approved / reviewed_by / reviewed_at columns to jobs - PATCH /documents/{id} apply field-level corrections - GET /documents/{id}/history return chronological audit trail - POST /documents/{id}/approve lock final version (idempotent) - Dotted field-path applier with root allow-list + list-index support - Auto-clear `missing_field` review flag when required header keys filled - Atomic batch apply: malformed path in batch rolls back all changes - 22 new tests (11 repository-level, 11 API-level); 184 total passing Co-Authored-By: adrian kuman firmansah --- .../3b1f2c9a4d56_phase6_hitl_tables.py | 60 +++++ .../ff8c14fbf8a0_phase4_jobs_table.py | 37 +-- src/ocr_sprint/api/routes/documents.py | 141 +++++++++- src/ocr_sprint/db/models.py | 45 +++- src/ocr_sprint/db/repositories.py | 245 ++++++++++++++++- src/ocr_sprint/schemas/document.py | 4 + src/ocr_sprint/schemas/review.py | 62 +++++ tests/unit/test_api_hitl.py | 248 ++++++++++++++++++ tests/unit/test_db_hitl.py | 238 +++++++++++++++++ 9 files changed, 1058 insertions(+), 22 deletions(-) create mode 100644 alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py create mode 100644 src/ocr_sprint/schemas/review.py create mode 100644 tests/unit/test_api_hitl.py create mode 100644 tests/unit/test_db_hitl.py diff --git a/alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py b/alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py new file mode 100644 index 0000000..420db9f --- /dev/null +++ b/alembic/versions/3b1f2c9a4d56_phase6_hitl_tables.py @@ -0,0 +1,60 @@ +"""phase6 hitl: job_corrections + approval columns + +Revision ID: 3b1f2c9a4d56 +Revises: ff8c14fbf8a0 +Create Date: 2026-04-25 14:30:00.000000 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "3b1f2c9a4d56" +down_revision: str | None = "ff8c14fbf8a0" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("jobs") as batch: + batch.add_column( + sa.Column( + "approved", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ) + ) + batch.add_column(sa.Column("reviewed_by", sa.String(length=128), nullable=True)) + batch.add_column(sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=True)) + + op.create_table( + "job_corrections", + sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), + sa.Column("job_id", sa.Uuid(), nullable=False), + sa.Column("field_path", sa.String(length=256), nullable=False), + sa.Column("old_value", sa.JSON(), nullable=True), + sa.Column("new_value", sa.JSON(), nullable=True), + sa.Column("corrected_by", sa.String(length=128), nullable=True), + sa.Column("reason", sa.String(length=512), nullable=True), + sa.Column("corrected_at", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint(["job_id"], ["jobs.job_id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("id"), + ) + op.create_index( + op.f("ix_job_corrections_job_id"), + "job_corrections", + ["job_id"], + unique=False, + ) + + +def downgrade() -> None: + op.drop_index(op.f("ix_job_corrections_job_id"), table_name="job_corrections") + op.drop_table("job_corrections") + with op.batch_alter_table("jobs") as batch: + batch.drop_column("reviewed_at") + batch.drop_column("reviewed_by") + batch.drop_column("approved") diff --git a/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py b/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py index 8ffd0ab..fdfc5b0 100644 --- a/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py +++ b/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py @@ -1,17 +1,17 @@ """phase4 jobs table Revision ID: ff8c14fbf8a0 -Revises: +Revises: Create Date: 2026-04-25 15:54:18.579147 """ + from collections.abc import Sequence -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. -revision: str = 'ff8c14fbf8a0' +revision: str = "ff8c14fbf8a0" down_revision: str | None = None branch_labels: str | Sequence[str] | None = None depends_on: str | Sequence[str] | None = None @@ -19,24 +19,25 @@ depends_on: str | Sequence[str] | None = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('jobs', - sa.Column('job_id', sa.Uuid(), nullable=False), - sa.Column('status', sa.String(length=32), nullable=False), - sa.Column('source_kind', sa.String(length=16), nullable=False), - sa.Column('filename', sa.String(length=512), nullable=False), - sa.Column('blob_key', sa.String(length=512), nullable=True), - sa.Column('confidence', sa.Float(), nullable=True), - sa.Column('review_flags', sa.JSON(), nullable=False), - sa.Column('result', sa.JSON(), nullable=True), - sa.Column('error', sa.String(length=2048), nullable=True), - sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), - sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), - sa.PrimaryKeyConstraint('job_id') + op.create_table( + "jobs", + sa.Column("job_id", sa.Uuid(), nullable=False), + sa.Column("status", sa.String(length=32), nullable=False), + sa.Column("source_kind", sa.String(length=16), nullable=False), + sa.Column("filename", sa.String(length=512), nullable=False), + sa.Column("blob_key", sa.String(length=512), nullable=True), + sa.Column("confidence", sa.Float(), nullable=True), + sa.Column("review_flags", sa.JSON(), nullable=False), + sa.Column("result", sa.JSON(), nullable=True), + sa.Column("error", sa.String(length=2048), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("job_id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('jobs') + op.drop_table("jobs") # ### end Alembic commands ### diff --git a/src/ocr_sprint/api/routes/documents.py b/src/ocr_sprint/api/routes/documents.py index 018d00c..195b4dc 100644 --- a/src/ocr_sprint/api/routes/documents.py +++ b/src/ocr_sprint/api/routes/documents.py @@ -22,7 +22,17 @@ from __future__ import annotations from typing import Annotated from uuid import UUID, uuid4 -from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status +from fastapi import ( + APIRouter, + Depends, + File, + Header, + HTTPException, + Query, + Response, + UploadFile, + status, +) from sqlalchemy.orm import Session from ocr_sprint.api.deps.auth import require_api_key @@ -31,11 +41,22 @@ from ocr_sprint.api.errors import UnsupportedDocumentError from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS from ocr_sprint.config import get_settings from ocr_sprint.db.base import session_scope -from ocr_sprint.db.repositories import JobNotFoundError, JobRepository +from ocr_sprint.db.repositories import ( + InvalidFieldPathError, + JobAlreadyApprovedError, + JobNotCompletedError, + JobNotFoundError, + JobRepository, +) from ocr_sprint.pipeline.ingest import detect_source_kind from ocr_sprint.pipeline.orchestrator import run_pipeline from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus from ocr_sprint.schemas.extraction import ExtractionResult +from ocr_sprint.schemas.review import ( + ApprovalResponse, + CorrectionEventResponse, + CorrectionRequest, +) from ocr_sprint.storage.blob import get_blob_storage from ocr_sprint.utils.logging import get_logger @@ -75,6 +96,9 @@ def _row_to_response(row: object) -> DocumentResponse: data=result_obj, review_flags=list(row.review_flags or []), error=row.error, + approved=bool(row.approved), + reviewed_by=row.reviewed_by, + reviewed_at=row.reviewed_at, ) @@ -192,3 +216,116 @@ async def get_document( except JobNotFoundError as exc: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc return _row_to_response(row) + + +# ---------- Phase 6 — HITL ---------- + + +def _correction_row_to_response(row: object) -> CorrectionEventResponse: + # Local import to avoid a cyclic import at module load time. + from ocr_sprint.db.models import JobCorrectionRow + + assert isinstance(row, JobCorrectionRow) + return CorrectionEventResponse( + id=row.id, + job_id=row.job_id, + field_path=row.field_path, + old_value=row.old_value, + new_value=row.new_value, + corrected_by=row.corrected_by, + reason=row.reason, + corrected_at=row.corrected_at, + ) + + +@router.patch( + "/{job_id}", + response_model=DocumentResponse, +) +async def patch_document( + job_id: UUID, + body: CorrectionRequest, + session: Annotated[Session, Depends(get_session)], + x_user_id: Annotated[ + str | None, + Header(description="Free-form reviewer identifier recorded on the audit row."), + ] = None, +) -> DocumentResponse: + """Apply one or more field-level corrections and record an audit trail. + + The whole batch is applied atomically — if any path is invalid the + request fails with 400 and no side effects are written. Returns the + updated document so the client doesn't need a follow-up GET. + """ + repo = JobRepository(session) + try: + repo.apply_corrections( + job_id, + corrections=[(c.path, c.value, c.reason) for c in body.corrections], + corrected_by=x_user_id, + ) + row = repo.get_or_raise(job_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + except InvalidFieldPathError as exc: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc + except JobAlreadyApprovedError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + except JobNotCompletedError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + + _logger.info( + "documents.patched", + job_id=str(job_id), + count=len(body.corrections), + corrected_by=x_user_id or "", + ) + return _row_to_response(row) + + +@router.get( + "/{job_id}/history", + response_model=list[CorrectionEventResponse], +) +async def get_history( + job_id: UUID, + session: Annotated[Session, Depends(get_session)], +) -> list[CorrectionEventResponse]: + repo = JobRepository(session) + try: + rows = repo.list_corrections(job_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + return [_correction_row_to_response(r) for r in rows] + + +@router.post( + "/{job_id}/approve", + response_model=ApprovalResponse, +) +async def approve_document( + job_id: UUID, + session: Annotated[Session, Depends(get_session)], + x_user_id: Annotated[ + str | None, + Header(description="Free-form reviewer identifier recorded on the job."), + ] = None, +) -> ApprovalResponse: + """Lock a job's final version. Idempotent: re-approving returns the + existing row without overwriting ``reviewed_at``. + """ + repo = JobRepository(session) + try: + row = repo.approve(job_id, reviewed_by=x_user_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + except JobNotCompletedError as exc: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc + + _logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "") + return ApprovalResponse( + job_id=row.job_id, + approved=bool(row.approved), + reviewed_by=row.reviewed_by, + reviewed_at=row.reviewed_at, + ) diff --git a/src/ocr_sprint/db/models.py b/src/ocr_sprint/db/models.py index 0f36202..f5b3f82 100644 --- a/src/ocr_sprint/db/models.py +++ b/src/ocr_sprint/db/models.py @@ -16,7 +16,7 @@ from datetime import datetime, timezone from typing import Any from uuid import UUID, uuid4 -from sqlalchemy import JSON, DateTime, Float, String, Uuid +from sqlalchemy import JSON, Boolean, DateTime, Float, ForeignKey, Integer, String, Uuid from sqlalchemy.orm import Mapped, mapped_column from ocr_sprint.db.base import Base @@ -42,6 +42,15 @@ class JobRow(Base): result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) error: Mapped[str | None] = mapped_column(String(2048), nullable=True) + # Phase 6 — HITL review state. + # Once ``approved=True`` the row is immutable except to admin users; + # corrections after that point are rejected by the route. ``reviewed_by`` + # stores the free-form user identifier the reviewer sent via the + # ``X-User-Id`` header (best-effort attribution — no full RBAC yet). + approved: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + reviewed_by: Mapped[str | None] = mapped_column(String(128), nullable=True) + reviewed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True) + created_at: Mapped[datetime] = mapped_column( DateTime(timezone=True), nullable=False, default=_utcnow ) @@ -51,3 +60,37 @@ class JobRow(Base): def __repr__(self) -> str: return f"JobRow(job_id={self.job_id!s}, status={self.status!r})" + + +class JobCorrectionRow(Base): + """One correction event on a job's ``result``. + + Each PATCH call writes one row per changed field path so we have a + full audit trail. Rows are append-only — never updated, never + deleted — so the history is reproducible and usable as ground-truth + data for future fine-tuning. + """ + + __tablename__ = "job_corrections" + + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + job_id: Mapped[UUID] = mapped_column( + Uuid, ForeignKey("jobs.job_id", ondelete="CASCADE"), nullable=False, index=True + ) + # Dotted JSON path into ExtractionResult, e.g. "header.nomor_sprint" or + # "personel[3].nrp". Kept as a plain string for simplicity — we don't + # parse it server-side beyond the allow-list check in the repository. + field_path: Mapped[str] = mapped_column(String(256), nullable=False) + old_value: Mapped[Any | None] = mapped_column(JSON, nullable=True) + new_value: Mapped[Any | None] = mapped_column(JSON, nullable=True) + corrected_by: Mapped[str | None] = mapped_column(String(128), nullable=True) + reason: Mapped[str | None] = mapped_column(String(512), nullable=True) + corrected_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, default=_utcnow + ) + + def __repr__(self) -> str: + return ( + f"JobCorrectionRow(job_id={self.job_id!s}, " + f"field={self.field_path!r}, by={self.corrected_by!r})" + ) diff --git a/src/ocr_sprint/db/repositories.py b/src/ocr_sprint/db/repositories.py index a17817a..248bd18 100644 --- a/src/ocr_sprint/db/repositories.py +++ b/src/ocr_sprint/db/repositories.py @@ -6,6 +6,9 @@ know about sessions, transactions, or the row → schema mapping. from __future__ import annotations +import copy +import re +from dataclasses import dataclass from datetime import datetime, timezone from typing import Any from uuid import UUID @@ -13,7 +16,7 @@ from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import Session -from ocr_sprint.db.models import JobRow +from ocr_sprint.db.models import JobCorrectionRow, JobRow from ocr_sprint.schemas.document import DocumentStatus, SourceKind @@ -25,6 +28,110 @@ class JobNotFoundError(LookupError): """Raised by API code when GET /documents/{id} hits a missing row.""" +class InvalidFieldPathError(ValueError): + """Raised when a PATCH request references an unsupported field path.""" + + +class JobAlreadyApprovedError(RuntimeError): + """Raised when a PATCH is attempted against an already-approved job.""" + + +class JobNotCompletedError(RuntimeError): + """Raised when a PATCH/approve is attempted against a job that hasn't + produced a ``result`` payload yet (e.g. still pending or failed). + """ + + +@dataclass(frozen=True) +class AppliedCorrection: + """Internal record of a correction that successfully applied to the + in-memory ``result`` dict. The repository turns this into a persisted + ``JobCorrectionRow`` after the whole batch is validated. + """ + + field_path: str + old_value: Any + new_value: Any + reason: str | None + + +# Allow-list of top-level keys we let reviewers edit. Keeps the attack +# surface small: they can't inject arbitrary fields into the JSON blob. +_ALLOWED_ROOTS: frozenset[str] = frozenset({"header", "ttd", "personel", "untuk"}) + +# Matches a single path segment like ``personel[3]`` — supports at most one +# index per segment, enough for the list fields we care about. +_SEGMENT_RE = re.compile(r"^([a-zA-Z_][a-zA-Z0-9_]*)(?:\[(\d+)\])?$") + + +def _split_path(path: str) -> list[tuple[str, int | None]]: + """Parse ``header.nomor_sprint`` or ``personel[2].nrp`` into segments. + + Returns list of ``(name, index_or_none)`` tuples. Raises + ``InvalidFieldPathError`` on malformed input so the caller can surface + a 400 to the client. + """ + if not path or path.startswith(".") or path.endswith("."): + raise InvalidFieldPathError(f"Invalid field path: {path!r}") + + parts = path.split(".") + out: list[tuple[str, int | None]] = [] + for part in parts: + match = _SEGMENT_RE.match(part) + if match is None: + raise InvalidFieldPathError(f"Invalid segment in path: {part!r}") + name = match.group(1) + idx_raw = match.group(2) + idx = int(idx_raw) if idx_raw is not None else None + out.append((name, idx)) + + if out[0][0] not in _ALLOWED_ROOTS: + raise InvalidFieldPathError( + f"Field path root {out[0][0]!r} not in allowed roots {sorted(_ALLOWED_ROOTS)!r}" + ) + return out + + +def _apply_path(data: dict[str, Any], path: str, new_value: Any) -> Any: + """Apply a single correction to ``data`` in place. Returns the old + value so the caller can record it in the audit row. + + Does NOT validate that the new value matches the field's expected + type — that's the reviewer's responsibility; the whole point of HITL + is to let humans override the model's typing. + """ + segments = _split_path(path) + cursor: Any = data + for name, idx in segments[:-1]: + if not isinstance(cursor, dict) or name not in cursor: + raise InvalidFieldPathError(f"Cannot traverse to {path!r}: missing {name!r}") + cursor = cursor[name] + if idx is not None: + if not isinstance(cursor, list) or idx >= len(cursor): + raise InvalidFieldPathError( + f"Cannot traverse to {path!r}: index [{idx}] out of range" + ) + cursor = cursor[idx] + + name, idx = segments[-1] + if idx is not None: + # Terminal segment is a list-element, e.g. ``untuk[2]``. + if not isinstance(cursor, dict) or name not in cursor: + raise InvalidFieldPathError(f"Cannot apply to {path!r}: missing container {name!r}") + container = cursor[name] + if not isinstance(container, list) or idx >= len(container): + raise InvalidFieldPathError(f"Cannot apply to {path!r}: index [{idx}] out of range") + old = container[idx] + container[idx] = new_value + return old + + if not isinstance(cursor, dict): + raise InvalidFieldPathError(f"Cannot apply to {path!r}: parent is not an object") + old = cursor.get(name) + cursor[name] = new_value + return old + + class JobRepository: """SQL-backed repository for `jobs` rows.""" @@ -94,3 +201,139 @@ class JobRepository: if row is None: raise JobNotFoundError(f"Job not found: {job_id}") return row + + # ---------- Phase 6 — HITL ---------- + + def apply_corrections( + self, + job_id: UUID, + *, + corrections: list[tuple[str, Any, str | None]], + corrected_by: str | None, + ) -> list[JobCorrectionRow]: + """Apply a batch of field corrections atomically. + + ``corrections`` is a list of ``(path, new_value, reason)`` tuples. + Returns the persisted audit rows so the caller can surface them in + the response. + + Raises + ------ + JobNotFoundError + If the row doesn't exist. + JobNotCompletedError + If the job hasn't produced a result yet (status pending / + processing / failed). + JobAlreadyApprovedError + If the job has been approved — edits are locked. + InvalidFieldPathError + If any path is malformed or references a disallowed root. + """ + row = self._get_or_raise(job_id) + if row.result is None: + raise JobNotCompletedError( + f"Job {job_id} has no result to correct (status={row.status})" + ) + if row.approved: + raise JobAlreadyApprovedError(f"Job {job_id} is already approved; edits are locked") + + # Deep-copy so we can roll back in memory if any correction fails. + # The underlying JSON column will only be re-assigned once every + # path applied cleanly. + working = copy.deepcopy(row.result) + applied: list[AppliedCorrection] = [] + for path, new_value, reason in corrections: + old_value = _apply_path(working, path, new_value) + applied.append( + AppliedCorrection( + field_path=path, old_value=old_value, new_value=new_value, reason=reason + ) + ) + + # Persist audit rows first; if they fail the session rollback also + # undoes the result-column update we're about to do. + persisted: list[JobCorrectionRow] = [] + for event in applied: + row_event = JobCorrectionRow( + job_id=job_id, + field_path=event.field_path, + old_value=event.old_value, + new_value=event.new_value, + corrected_by=corrected_by, + reason=event.reason, + ) + self.session.add(row_event) + persisted.append(row_event) + + # Clear review flags that the correction has resolved. Right now we + # only auto-clear MISSING_FIELD when any corrected field previously + # held a null/empty value — the reviewer explicitly filled a gap. + row.result = working + row.review_flags = _recompute_flags( + original_flags=list(row.review_flags or []), + applied=applied, + working_result=working, + ) + row.updated_at = _utcnow() + + self.session.flush() + return persisted + + def list_corrections(self, job_id: UUID) -> list[JobCorrectionRow]: + """Return the full audit trail for ``job_id`` in chronological order.""" + # ``get_or_raise`` so callers get a 404 instead of an empty list + # when the job itself doesn't exist. + self._get_or_raise(job_id) + stmt = ( + select(JobCorrectionRow) + .where(JobCorrectionRow.job_id == job_id) + .order_by(JobCorrectionRow.corrected_at, JobCorrectionRow.id) + ) + return list(self.session.scalars(stmt)) + + def approve(self, job_id: UUID, *, reviewed_by: str | None) -> JobRow: + """Mark a job as approved. Idempotent — re-approving is a no-op + that keeps the original ``reviewed_at`` (so the audit trail stays + intact). + """ + row = self._get_or_raise(job_id) + if row.result is None: + raise JobNotCompletedError( + f"Job {job_id} has no result to approve (status={row.status})" + ) + if row.approved: + return row + row.approved = True + row.reviewed_by = reviewed_by + row.reviewed_at = _utcnow() + row.updated_at = row.reviewed_at + return row + + +def _recompute_flags( + *, + original_flags: list[str], + applied: list[AppliedCorrection], + working_result: dict[str, Any], +) -> list[str]: + """Update review flags in light of the corrections just applied. + + Keeps the policy simple on purpose: + * ``missing_field`` is removed if after the edit every required + header field is non-empty. + * Other flags stay untouched — the reviewer should either correct the + underlying issue (which this helper can detect) or explicitly + approve the result as-is (which bypasses the flag list). + """ + flags = list(original_flags) + if "missing_field" in flags: + header = working_result.get("header") or {} + filled = all(bool(header.get(key)) for key in ("nomor_sprint", "satuan_penerbit")) + if filled: + flags = [f for f in flags if f != "missing_field"] + + # ``applied`` isn't used directly in this MVP rule, but we keep the + # parameter so future policies can inspect exactly what changed + # without re-diffing the blob. + _ = applied + return flags diff --git a/src/ocr_sprint/schemas/document.py b/src/ocr_sprint/schemas/document.py index c59b8b7..3269539 100644 --- a/src/ocr_sprint/schemas/document.py +++ b/src/ocr_sprint/schemas/document.py @@ -55,3 +55,7 @@ class DocumentResponse(BaseModel): data: ExtractionResult | None = None review_flags: list[str] = Field(default_factory=list) error: str | None = None + # Phase 6 — HITL review state. + approved: bool = False + reviewed_by: str | None = None + reviewed_at: datetime | None = None diff --git a/src/ocr_sprint/schemas/review.py b/src/ocr_sprint/schemas/review.py new file mode 100644 index 0000000..5031665 --- /dev/null +++ b/src/ocr_sprint/schemas/review.py @@ -0,0 +1,62 @@ +"""Request / response schemas for the HITL review endpoints (Phase 6). + +The API surface is deliberately small: + +* ``CorrectionRequest`` — body of ``PATCH /documents/{id}``. A list of + ``FieldCorrection`` entries; each one is applied atomically (all-or- + nothing) and recorded in the audit trail. +* ``CorrectionEventResponse`` — single row in ``GET /documents/{id}/history``. +* ``ApprovalResponse`` — echo back after ``POST /documents/{id}/approve``. +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any +from uuid import UUID + +from pydantic import BaseModel, Field + + +class FieldCorrection(BaseModel): + """One field-level correction. + + ``path`` is a dotted JSON path into ``ExtractionResult``. Supported + roots: ``header``, ``ttd``, ``personel[n]`` (n is a 0-based index), + ``untuk``. The path is validated by the repository before being + applied; unknown roots return 400. + """ + + path: str = Field(..., description="Dotted JSON path, e.g. 'header.nomor_sprint'.") + value: Any = Field(..., description="New value (any JSON-serialisable payload).") + reason: str | None = Field( + None, max_length=512, description="Optional free-form reason for the correction." + ) + + +class CorrectionRequest(BaseModel): + """PATCH body — one or more field corrections, applied atomically.""" + + corrections: list[FieldCorrection] = Field(..., min_length=1) + + +class CorrectionEventResponse(BaseModel): + """One row of the audit log surfaced by GET /history.""" + + id: int + job_id: UUID + field_path: str + old_value: Any | None = None + new_value: Any | None = None + corrected_by: str | None = None + reason: str | None = None + corrected_at: datetime + + +class ApprovalResponse(BaseModel): + """Echo returned after a job is approved.""" + + job_id: UUID + approved: bool + reviewed_by: str | None = None + reviewed_at: datetime | None = None diff --git a/tests/unit/test_api_hitl.py b/tests/unit/test_api_hitl.py new file mode 100644 index 0000000..e5781af --- /dev/null +++ b/tests/unit/test_api_hitl.py @@ -0,0 +1,248 @@ +"""End-to-end HTTP tests for the HITL endpoints. + +We re-use the ``fake_pipeline`` style from ``test_api.py`` so we don't pay +the PaddleOCR init cost; the orchestrator is monkey-patched to return a +synthetic ``ExtractionResult``. +""" + +from __future__ import annotations + +from datetime import date + +import pytest +from fastapi.testclient import TestClient + +from ocr_sprint.main import create_app +from ocr_sprint.pipeline import orchestrator as orch_module +from ocr_sprint.pipeline.orchestrator import PipelineOutput +from ocr_sprint.schemas.document import DocumentStatus, SourceKind +from ocr_sprint.schemas.extraction import ( + ExtractionResult, + HeaderFields, + PersonnelEntry, + ReviewFlag, +) + + +@pytest.fixture +def client() -> TestClient: + return TestClient(create_app()) + + +@pytest.fixture +def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput: + result = ExtractionResult( + header=HeaderFields( + nomor_sprint="Sprin/1/I/2025", + tanggal=date(2025, 1, 1), + satuan_penerbit="POLRES TEST", + perihal=None, # intentional gap so a PATCH can fill it + ), + personel=[ + PersonnelEntry(pangkat="AIPDA", nrp="77060000", nama="BUDI", jabatan="ANGGOTA"), + ], + review_flags=[ReviewFlag.MISSING_FIELD], + confidence=0.7, + ) + output = PipelineOutput( + source_kind=SourceKind.PDF, + status=DocumentStatus.NEEDS_REVIEW, + confidence=0.7, + result=result, + ) + + def _fake_run(_content: bytes) -> PipelineOutput: + return output + + monkeypatch.setattr(orch_module, "run_pipeline", _fake_run) + from ocr_sprint.api.routes import documents as docs_module + + monkeypatch.setattr(docs_module, "run_pipeline", _fake_run) + from ocr_sprint.worker import tasks as tasks_module + + monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run) + return output + + +def _create_job(client: TestClient) -> str: + post = client.post( + "/api/v1/documents?sync=true", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 200, post.text + body = post.json() + assert body["status"] == "needs_review" + return str(body["job_id"]) + + +def test_patch_applies_correction_and_clears_missing_field( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + patched = client.patch( + f"/api/v1/documents/{job_id}", + json={ + "corrections": [ + { + "path": "header.perihal", + "value": "Penyelidikan kasus X", + "reason": "LLM missed it", + } + ] + }, + headers={"X-User-Id": "reviewer-a"}, + ) + assert patched.status_code == 200, patched.text + body = patched.json() + assert body["data"]["header"]["perihal"] == "Penyelidikan kasus X" + # The fake pipeline has both required header fields filled, so the + # ``missing_field`` flag is auto-cleared as soon as any correction + # lands (the policy re-evaluates required-field coverage on every + # edit). + assert "missing_field" not in body["review_flags"] + + +def test_patch_returns_400_for_unknown_path( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + resp = client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "bogus.field", "value": "x"}]}, + ) + assert resp.status_code == 400 + + +def test_patch_is_atomic_on_partial_failure( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + resp = client.patch( + f"/api/v1/documents/{job_id}", + json={ + "corrections": [ + {"path": "header.perihal", "value": "OK"}, + {"path": "bogus.root", "value": "X"}, + ] + }, + ) + assert resp.status_code == 400 + + # The first correction must not have persisted. + got = client.get(f"/api/v1/documents/{job_id}") + assert got.json()["data"]["header"]["perihal"] is None + + +def test_history_returns_corrections_in_order( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "header.perihal", "value": "first"}]}, + headers={"X-User-Id": "reviewer-a"}, + ) + client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "header.perihal", "value": "second"}]}, + headers={"X-User-Id": "reviewer-b"}, + ) + + history = client.get(f"/api/v1/documents/{job_id}/history") + assert history.status_code == 200 + events = history.json() + assert [e["new_value"] for e in events] == ["first", "second"] + assert [e["corrected_by"] for e in events] == ["reviewer-a", "reviewer-b"] + # old_value of the second event should reflect the first edit. + assert events[1]["old_value"] == "first" + + +def test_history_returns_empty_list_for_untouched_job( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + history = client.get(f"/api/v1/documents/{job_id}/history") + assert history.status_code == 200 + assert history.json() == [] + + +def test_history_returns_404_for_unknown_job(client: TestClient) -> None: + resp = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000/history") + assert resp.status_code == 404 + + +def test_approve_locks_subsequent_patches( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + approved = client.post( + f"/api/v1/documents/{job_id}/approve", + headers={"X-User-Id": "reviewer-a"}, + ) + assert approved.status_code == 200, approved.text + body = approved.json() + assert body["approved"] is True + assert body["reviewed_by"] == "reviewer-a" + assert body["reviewed_at"] # non-empty timestamp + + # GET reflects the approval state. + got = client.get(f"/api/v1/documents/{job_id}").json() + assert got["approved"] is True + + # PATCH after approve must be rejected with 409. + patched = client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": [{"path": "header.perihal", "value": "X"}]}, + ) + assert patched.status_code == 409 + + +def test_approve_is_idempotent( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + first = client.post( + f"/api/v1/documents/{job_id}/approve", + headers={"X-User-Id": "reviewer-a"}, + ) + second = client.post( + f"/api/v1/documents/{job_id}/approve", + headers={"X-User-Id": "reviewer-b"}, + ) + assert first.status_code == 200 + assert second.status_code == 200 + # Second approve must NOT change the attribution. (SQLite drops tzinfo + # on roundtrip, which changes Pydantic's serialization between the two + # calls; compare the naive components.) + assert second.json()["reviewed_by"] == "reviewer-a" + assert ( + second.json()["reviewed_at"].rstrip("Z").split("+")[0] + == (first.json()["reviewed_at"].rstrip("Z").split("+")[0]) + ) + + +def test_patch_requires_at_least_one_correction( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + job_id = _create_job(client) + resp = client.patch( + f"/api/v1/documents/{job_id}", + json={"corrections": []}, + ) + assert resp.status_code == 422 # Pydantic min_length=1 violation + + +def test_patch_missing_job_returns_404(client: TestClient) -> None: + resp = client.patch( + "/api/v1/documents/00000000-0000-0000-0000-000000000000", + json={"corrections": [{"path": "header.perihal", "value": "X"}]}, + ) + assert resp.status_code == 404 diff --git a/tests/unit/test_db_hitl.py b/tests/unit/test_db_hitl.py new file mode 100644 index 0000000..7e8f885 --- /dev/null +++ b/tests/unit/test_db_hitl.py @@ -0,0 +1,238 @@ +"""Repository tests for Phase 6 HITL helpers.""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from ocr_sprint.db.base import Base, get_engine, session_scope +from ocr_sprint.db.repositories import ( + InvalidFieldPathError, + JobAlreadyApprovedError, + JobNotCompletedError, + JobNotFoundError, + JobRepository, +) +from ocr_sprint.schemas.document import DocumentStatus, SourceKind + + +@pytest.fixture +def db_ready() -> None: + Base.metadata.create_all(bind=get_engine()) + + +def _seed_completed_job( + *, + result: dict[str, object] | None = None, + flags: list[str] | None = None, +) -> uuid4: # type: ignore[type-arg] + jid = uuid4() + with session_scope() as session: + repo = JobRepository(session) + repo.create( + job_id=jid, + filename="x.pdf", + source_kind=SourceKind.PDF, + blob_key="k", + ) + with session_scope() as session: + JobRepository(session).mark_completed( + jid, + status=DocumentStatus.NEEDS_REVIEW, + confidence=0.7, + result=result + or { + "header": { + "nomor_sprint": "Sprin/1/I/2025", + "satuan_penerbit": "POLRES X", + "perihal": None, + }, + "personel": [ + {"pangkat": "AIPDA", "nrp": "77060000", "nama": "BUDI"}, + ], + "untuk": ["Melaksanakan tugas"], + }, + review_flags=flags or [], + ) + return jid + + +def test_apply_corrections_updates_nested_header_field(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + repo = JobRepository(session) + repo.apply_corrections( + jid, + corrections=[("header.perihal", "Penyelidikan kasus X", "regex miss")], + corrected_by="reviewer-a", + ) + row = repo.get_or_raise(jid) + assert row.result is not None + assert row.result["header"]["perihal"] == "Penyelidikan kasus X" + + +def test_apply_corrections_writes_audit_row(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "Penyelidikan", None)], + corrected_by="reviewer-a", + ) + with session_scope() as session: + events = JobRepository(session).list_corrections(jid) + assert len(events) == 1 + assert events[0].field_path == "header.perihal" + assert events[0].old_value is None + assert events[0].new_value == "Penyelidikan" + assert events[0].corrected_by == "reviewer-a" + + +def test_apply_corrections_supports_list_index(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("personel[0].nrp", "77060001", None)], + corrected_by=None, + ) + row = JobRepository(session).get_or_raise(jid) + assert row.result is not None + assert row.result["personel"][0]["nrp"] == "77060001" + + +def test_apply_corrections_is_atomic_on_invalid_path(db_ready: None) -> None: + """A second-correction failure must roll back the first one.""" + jid = _seed_completed_job() + with session_scope() as session, pytest.raises(InvalidFieldPathError): + JobRepository(session).apply_corrections( + jid, + corrections=[ + ("header.perihal", "OK", None), + ("bogus.root", "X", None), + ], + corrected_by=None, + ) + # The first correction must not have persisted. + with session_scope() as session: + row = JobRepository(session).get_or_raise(jid) + assert row.result is not None + assert row.result["header"].get("perihal") is None + + +def test_apply_corrections_rejects_out_of_range_index(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session, pytest.raises(InvalidFieldPathError): + JobRepository(session).apply_corrections( + jid, + corrections=[("personel[99].nrp", "77060001", None)], + corrected_by=None, + ) + + +def test_apply_corrections_rejects_after_approve(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).approve(jid, reviewed_by="reviewer-a") + with session_scope() as session, pytest.raises(JobAlreadyApprovedError): + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "X", None)], + corrected_by="reviewer-a", + ) + + +def test_apply_corrections_rejects_missing_job(db_ready: None) -> None: + with session_scope() as session, pytest.raises(JobNotFoundError): + JobRepository(session).apply_corrections( + uuid4(), + corrections=[("header.perihal", "X", None)], + corrected_by=None, + ) + + +def test_apply_corrections_rejects_pending_job(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k" + ) + with session_scope() as session, pytest.raises(JobNotCompletedError): + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "X", None)], + corrected_by=None, + ) + + +def test_missing_field_flag_cleared_when_header_gap_filled(db_ready: None) -> None: + jid = _seed_completed_job( + result={ + "header": { + "nomor_sprint": None, + "satuan_penerbit": "POLRES X", + } + }, + flags=["missing_field", "low_ocr_confidence"], + ) + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("header.nomor_sprint", "Sprin/2/I/2025", None)], + corrected_by="reviewer-a", + ) + row = JobRepository(session).get_or_raise(jid) + # ``low_ocr_confidence`` stays (correction doesn't resolve that signal), + # but ``missing_field`` is gone because every required header key is + # now non-empty. + assert list(row.review_flags) == ["low_ocr_confidence"] + + +def test_approve_sets_timestamps_and_is_idempotent(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + row = JobRepository(session).approve(jid, reviewed_by="reviewer-a") + first_at = row.reviewed_at + assert first_at is not None + with session_scope() as session: + row = JobRepository(session).approve(jid, reviewed_by="reviewer-b") + # Second call must NOT overwrite reviewed_by or reviewed_at. + # SQLite drops tzinfo on roundtrip, so compare the naive components. + assert row.approved is True + assert row.reviewed_by == "reviewer-a" + assert row.reviewed_at is not None + assert row.reviewed_at.replace(tzinfo=None) == first_at.replace(tzinfo=None) + + +def test_approve_rejects_pending_job(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, filename="x", source_kind=SourceKind.PDF, blob_key="k" + ) + with session_scope() as session, pytest.raises(JobNotCompletedError): + JobRepository(session).approve(jid, reviewed_by="rev") + + +def test_history_returns_events_in_order(db_ready: None) -> None: + jid = _seed_completed_job() + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[("header.perihal", "one", None)], + corrected_by="r1", + ) + with session_scope() as session: + JobRepository(session).apply_corrections( + jid, + corrections=[ + ("header.perihal", "two", None), + ("personel[0].nama", "ANDI", None), + ], + corrected_by="r2", + ) + with session_scope() as session: + events = JobRepository(session).list_corrections(jid) + assert [e.new_value for e in events] == ["one", "two", "ANDI"] + assert [e.corrected_by for e in events] == ["r1", "r2", "r2"]