Phase 5: hybrid LLM extraction (Ollama) for header gaps
Adds a small Ollama HTTP client (httpx-based, no extra runtime deps),
prompt builders, and a hybrid header extractor that runs *after* the
deterministic regex layer. The merger never overwrites a regex-filled
field — the LLM only fills gaps. If LLM_ENABLED=false (the default), or
the Ollama server is unreachable, the pipeline degrades gracefully:
- LLM_ENABLED=false -> no LLM call at all, no flag.
- LLM_ENABLED=true,
header complete -> no LLM call.
- LLM_ENABLED=true,
header has gaps,
LLM responded ok -> merge + LLM_FALLBACK flag (review hint).
- LLM_ENABLED=true,
header has gaps,
LLM unavailable -> keep regex result + LLM_UNAVAILABLE flag.
Default model qwen2.5:1.5b on http://localhost:11434 — chosen for CPU
throughput (~5-15s per call) at acceptable accuracy. The LLM only fills
the *header* (nomor, tanggal, satuan, perihal, dasar). Personnel rows
stay with PP-Structure since that's more accurate and doesn't need LLM.
Tests:
- test_llm_client.py: httpx MockTransport-driven tests for the wire
format, error paths (HTTP 5xx, malformed JSON, missing envelope,
ConnectError), and request shape.
- test_llm_extractor.py: merge policy + None-on-unavailable behaviour.
- test_orchestrator_llm.py: end-to-end orchestrator wiring with stubs
for ingest/preprocess/OCR/table — verifies LLM is skipped when
disabled, skipped when header is complete, called and flagged when
gaps exist, and marked unavailable when the client returns None.
162 unit tests pass total (was 146).
Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
18
src/ocr_sprint/llm/__init__.py
Normal file
18
src/ocr_sprint/llm/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""LLM-based extraction (Phase 5).
|
||||
|
||||
The hybrid extractor first runs the deterministic regex layer and then —
|
||||
only for fields that came back missing or low-confidence — calls a local
|
||||
Ollama model with a Pydantic-typed prompt. Everything is gated by
|
||||
``LLM_ENABLED``; if the flag is off or the Ollama server is unreachable,
|
||||
the pipeline degrades gracefully back to the regex result.
|
||||
"""
|
||||
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
from ocr_sprint.llm.extractor import LLMHeaderResult, llm_fill_header
|
||||
|
||||
__all__ = [
|
||||
"LLMHeaderResult",
|
||||
"LLMUnavailableError",
|
||||
"OllamaClient",
|
||||
"llm_fill_header",
|
||||
]
|
||||
97
src/ocr_sprint/llm/client.py
Normal file
97
src/ocr_sprint/llm/client.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Ollama HTTP client.
|
||||
|
||||
We deliberately avoid the ``ollama`` Python package — the wire format is a
|
||||
single ``POST /api/chat`` with ``format="json"`` and a system + user message,
|
||||
so a small ``httpx`` wrapper is enough. This keeps the runtime dependency
|
||||
footprint smaller and makes the mock-based unit tests trivial.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TypeVar
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class LLMUnavailableError(RuntimeError):
|
||||
"""Raised when the Ollama server is unreachable, times out, or returns
|
||||
a malformed payload. The pipeline catches this and falls back to the
|
||||
regex-only result with a ``llm_fallback`` review flag.
|
||||
"""
|
||||
|
||||
|
||||
class OllamaClient:
|
||||
"""Tiny synchronous HTTP wrapper around the Ollama ``/api/chat`` endpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
base_url:
|
||||
Ollama server URL, e.g. ``http://localhost:11434``.
|
||||
model:
|
||||
Model tag to invoke (default ``qwen2.5:1.5b`` — chosen for CPU
|
||||
latency at acceptable accuracy).
|
||||
timeout_s:
|
||||
Hard wall-clock timeout for a single request.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str | None = None,
|
||||
model: str | None = None,
|
||||
timeout_s: int | None = None,
|
||||
) -> None:
|
||||
s = get_settings()
|
||||
self.base_url = (base_url or s.llm_base_url).rstrip("/")
|
||||
self.model = model or s.llm_model
|
||||
self.timeout_s = timeout_s if timeout_s is not None else s.llm_timeout_s
|
||||
|
||||
# ---------- public API ----------
|
||||
|
||||
def chat_json(self, system: str, user: str, schema_cls: type[T]) -> T:
|
||||
"""Run a single chat completion in JSON mode and validate the
|
||||
response against ``schema_cls``. Raises ``LLMUnavailableError`` on
|
||||
any transport / parse / validation failure so callers only have one
|
||||
exception to handle.
|
||||
"""
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"stream": False,
|
||||
"format": "json",
|
||||
"messages": [
|
||||
{"role": "system", "content": system},
|
||||
{"role": "user", "content": user},
|
||||
],
|
||||
# Keep determinism reasonable — we want extraction, not creativity.
|
||||
"options": {"temperature": 0.0, "num_ctx": 4096},
|
||||
}
|
||||
url = f"{self.base_url}/api/chat"
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=self.timeout_s) as client:
|
||||
response = client.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
except (httpx.HTTPError, ValueError) as exc:
|
||||
_logger.warning("llm.transport_error", url=url, error=str(exc))
|
||||
raise LLMUnavailableError(f"Ollama request failed: {exc}") from exc
|
||||
|
||||
# Ollama returns {"message": {"role": "assistant", "content": "<json>"}}.
|
||||
try:
|
||||
content = data["message"]["content"]
|
||||
except (KeyError, TypeError) as exc:
|
||||
_logger.warning("llm.bad_envelope", payload=data)
|
||||
raise LLMUnavailableError(f"Ollama response missing message.content: {data!r}") from exc
|
||||
|
||||
try:
|
||||
return schema_cls.model_validate_json(content)
|
||||
except ValidationError as exc:
|
||||
_logger.warning("llm.validation_error", error=str(exc), content=content[:400])
|
||||
raise LLMUnavailableError(f"LLM JSON failed schema: {exc}") from exc
|
||||
84
src/ocr_sprint/llm/extractor.py
Normal file
84
src/ocr_sprint/llm/extractor.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""High-level LLM extractor.
|
||||
|
||||
The job is *narrow*: take the raw OCR text plus the partial header that
|
||||
came back from the regex layer, and return an LLM-derived header that the
|
||||
caller can merge in. We never let the LLM populate the personnel table —
|
||||
PP-Structure is more accurate and cheaper for that.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
from ocr_sprint.llm.prompts import SYSTEM_HEADER, build_user_prompt
|
||||
from ocr_sprint.schemas.extraction import HeaderFields
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
|
||||
class LLMHeaderResult(BaseModel):
|
||||
"""Schema we ask the model to fill. Mirrors ``HeaderFields`` but is
|
||||
intentionally separate so we control exactly what the prompt and
|
||||
validation surface look like — the public ``HeaderFields`` may grow
|
||||
fields later that we don't want the LLM touching.
|
||||
"""
|
||||
|
||||
nomor_sprint: str | None = None
|
||||
tanggal: date | None = None
|
||||
satuan_penerbit: str | None = None
|
||||
perihal: str | None = None
|
||||
dasar: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
def llm_fill_header(
|
||||
raw_text: str,
|
||||
regex_header: HeaderFields,
|
||||
*,
|
||||
client: OllamaClient | None = None,
|
||||
) -> HeaderFields | None:
|
||||
"""Run the LLM extractor and return a *merged* HeaderFields.
|
||||
|
||||
Returns ``None`` if the model is unavailable so the caller can decide
|
||||
what to do (typically: keep the regex result and emit a fallback
|
||||
review flag).
|
||||
"""
|
||||
client = client or OllamaClient()
|
||||
|
||||
user = build_user_prompt(
|
||||
raw_text=raw_text,
|
||||
regex_partial=regex_header.model_dump(mode="json"),
|
||||
)
|
||||
|
||||
try:
|
||||
llm = client.chat_json(SYSTEM_HEADER, user, LLMHeaderResult)
|
||||
except LLMUnavailableError as exc:
|
||||
_logger.warning("llm.unavailable", error=str(exc))
|
||||
return None
|
||||
|
||||
return _merge(regex_header, llm)
|
||||
|
||||
|
||||
def _merge(regex: HeaderFields, llm: LLMHeaderResult) -> HeaderFields:
|
||||
"""Merge LLM output into the regex result.
|
||||
|
||||
Policy: regex wins for any field it already filled. The LLM only fills
|
||||
the *gaps*. This keeps deterministic / verifiable extractions for the
|
||||
fields where regex is reliable and prevents the LLM from "correcting"
|
||||
a value that happens to look unusual but is in fact correct.
|
||||
"""
|
||||
merged = regex.model_copy(deep=True)
|
||||
if merged.nomor_sprint is None and llm.nomor_sprint:
|
||||
merged.nomor_sprint = llm.nomor_sprint
|
||||
if merged.tanggal is None and llm.tanggal is not None:
|
||||
merged.tanggal = llm.tanggal
|
||||
if not merged.satuan_penerbit and llm.satuan_penerbit:
|
||||
merged.satuan_penerbit = llm.satuan_penerbit
|
||||
if not merged.perihal and llm.perihal:
|
||||
merged.perihal = llm.perihal
|
||||
if not merged.dasar and llm.dasar:
|
||||
merged.dasar = list(llm.dasar)
|
||||
return merged
|
||||
48
src/ocr_sprint/llm/prompts.py
Normal file
48
src/ocr_sprint/llm/prompts.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Prompt builders for the LLM extractor.
|
||||
|
||||
Kept in their own module so the prompts can be edited / version-tracked
|
||||
without touching the orchestration logic. We build prompts in Indonesian
|
||||
because the source documents are too — the model performs better when the
|
||||
field labels in the prompt match the OCR text it's being asked about.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
SYSTEM_HEADER = (
|
||||
"Anda adalah asisten ekstraksi data untuk dokumen Surat Perintah (Sprint) "
|
||||
"Kepolisian Republik Indonesia (POLRI). Pengguna akan memberikan teks hasil "
|
||||
"OCR sebuah surat sprint, dan Anda harus mengembalikan JSON yang sesuai "
|
||||
"dengan skema yang diberikan.\n\n"
|
||||
"Aturan keras:\n"
|
||||
"1. Jangan mengarang. Jika sebuah field tidak terlihat di teks, kembalikan null.\n"
|
||||
"2. Jangan menerjemahkan field. Output harus identik ejaannya dengan teks "
|
||||
"sumber (kecuali normalisasi spasi/kapitalisasi yang jelas hasil OCR error).\n"
|
||||
"3. Tanggal: kembalikan format ISO YYYY-MM-DD jika tanggal terlihat, "
|
||||
"selain itu null.\n"
|
||||
"4. Dasar hukum: array string berisi tiap butir, urut sesuai teks.\n"
|
||||
"5. Jangan menambahkan field apa pun di luar skema. Output WAJIB JSON valid."
|
||||
)
|
||||
|
||||
|
||||
def build_user_prompt(raw_text: str, regex_partial: dict[str, object]) -> str:
|
||||
"""Construct the user message: OCR text + a hint about which fields the
|
||||
deterministic regex layer already filled. Telling the LLM what we
|
||||
*already have* keeps it from "creatively" overwriting good values.
|
||||
"""
|
||||
known_fields = "\n".join(f" - {k}: {v!r}" for k, v in sorted(regex_partial.items()) if v)
|
||||
known_block = (
|
||||
f"\nField yang sudah berhasil diekstrak dengan regex:\n{known_fields}\n"
|
||||
if known_fields
|
||||
else ""
|
||||
)
|
||||
|
||||
return (
|
||||
"Teks OCR:\n"
|
||||
"----------\n"
|
||||
f"{raw_text}\n"
|
||||
"----------\n"
|
||||
f"{known_block}"
|
||||
"Tugas: kembalikan JSON dengan field nomor_sprint, tanggal (ISO date | null), "
|
||||
"satuan_penerbit, perihal, dasar (array string). Hanya field yang terlihat — "
|
||||
"yang tidak ada di teks isi null (atau array kosong untuk dasar)."
|
||||
)
|
||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.llm.extractor import llm_fill_header
|
||||
from ocr_sprint.pipeline.confidence import compute_confidence, route
|
||||
from ocr_sprint.pipeline.document_detect import DocumentDetectConfig, detect_and_correct
|
||||
from ocr_sprint.pipeline.extract.personnel import extract_personnel
|
||||
@@ -35,6 +36,18 @@ _logger = get_logger(__name__)
|
||||
_OCR_CONFIDENCE_FLAG_THRESHOLD = 0.80
|
||||
|
||||
|
||||
def _header_has_gaps(header: object) -> bool:
|
||||
"""True if any header field worth asking the LLM about is missing.
|
||||
|
||||
Using ``getattr`` so this stays decoupled from the exact attribute
|
||||
names; the schema change cost was too large last time we hard-coded.
|
||||
"""
|
||||
for field in ("nomor_sprint", "tanggal", "satuan_penerbit", "perihal"):
|
||||
if not getattr(header, field, None):
|
||||
return True
|
||||
return not getattr(header, "dasar", None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineOutput:
|
||||
"""Bundle returned by the orchestrator."""
|
||||
@@ -84,6 +97,20 @@ def run_pipeline(content: bytes) -> PipelineOutput:
|
||||
header = extract_header(full_text)
|
||||
ttd = find_signatory(full_text)
|
||||
|
||||
# Phase 5 — hybrid extraction. The regex layer is deterministic but
|
||||
# brittle to layout variants between satuan; if any header field is
|
||||
# still missing we ask the local LLM to fill the gaps. The merger
|
||||
# never lets the LLM overwrite a field that regex already captured.
|
||||
llm_flags: list[ReviewFlag] = []
|
||||
if s.llm_enabled and _header_has_gaps(header):
|
||||
merged = llm_fill_header(full_text, header)
|
||||
if merged is None:
|
||||
llm_flags.append(ReviewFlag.LLM_UNAVAILABLE)
|
||||
else:
|
||||
if merged.model_dump() != header.model_dump():
|
||||
llm_flags.append(ReviewFlag.LLM_FALLBACK)
|
||||
header = merged
|
||||
|
||||
personel: list[PersonnelEntry] = []
|
||||
if s.tables_enabled and cleaned_pages:
|
||||
all_tables: list[DetectedTable] = []
|
||||
@@ -99,7 +126,7 @@ def run_pipeline(content: bytes) -> PipelineOutput:
|
||||
personel_rows=len(personel),
|
||||
)
|
||||
|
||||
initial_flags: list[ReviewFlag] = []
|
||||
initial_flags: list[ReviewFlag] = list(llm_flags)
|
||||
if mean_ocr_conf < _OCR_CONFIDENCE_FLAG_THRESHOLD:
|
||||
initial_flags.append(ReviewFlag.LOW_OCR_CONFIDENCE)
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ class ReviewFlag(str, Enum):
|
||||
UNKNOWN_PANGKAT = "unknown_pangkat"
|
||||
PERSONNEL_COUNT_MISMATCH = "personnel_count_mismatch"
|
||||
DATE_PARSE_FAILED = "date_parse_failed"
|
||||
LLM_FALLBACK = "llm_fallback"
|
||||
LLM_UNAVAILABLE = "llm_unavailable"
|
||||
|
||||
|
||||
class Signatory(BaseModel):
|
||||
|
||||
108
tests/unit/test_llm_client.py
Normal file
108
tests/unit/test_llm_client.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Unit tests for the Ollama HTTP client wrapper.
|
||||
|
||||
We swap ``httpx.Client`` inside ``ocr_sprint.llm.client`` for a builder that
|
||||
returns a real ``httpx.Client`` wrapping a ``MockTransport``. Capturing the
|
||||
original constructor *before* patching avoids infinite recursion in the
|
||||
patched callable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import ocr_sprint.llm.client as llm_client_module
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
|
||||
|
||||
class _Schema(BaseModel):
|
||||
foo: str
|
||||
bar: int
|
||||
|
||||
|
||||
def _ollama_envelope(content: str) -> dict[str, object]:
|
||||
"""Mimic the shape Ollama's /api/chat returns."""
|
||||
return {"message": {"role": "assistant", "content": content}, "done": True}
|
||||
|
||||
|
||||
def _patch_transport(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
handler: Any,
|
||||
) -> None:
|
||||
transport = httpx.MockTransport(handler)
|
||||
real_client = httpx.Client # capture before patching
|
||||
|
||||
def _factory(*_args: object, **kwargs: object) -> httpx.Client:
|
||||
# Strip any caller-provided transport kwarg; we always inject ours.
|
||||
kwargs.pop("transport", None)
|
||||
return real_client(transport=transport, **kwargs)
|
||||
|
||||
monkeypatch.setattr(llm_client_module.httpx, "Client", _factory)
|
||||
|
||||
|
||||
def test_chat_json_returns_validated_model(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
def _handler(request: httpx.Request) -> httpx.Response:
|
||||
captured["url"] = str(request.url)
|
||||
captured["body"] = request.read()
|
||||
return httpx.Response(200, json=_ollama_envelope('{"foo": "x", "bar": 7}'))
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://ollama:11434", model="m", timeout_s=5)
|
||||
out = client.chat_json("system msg", "user msg", _Schema)
|
||||
|
||||
assert out == _Schema(foo="x", bar=7)
|
||||
assert captured["url"] == "http://ollama:11434/api/chat"
|
||||
body = captured["body"]
|
||||
assert isinstance(body, bytes)
|
||||
assert b'"format":"json"' in body
|
||||
assert b'"system msg"' in body
|
||||
|
||||
|
||||
def test_chat_json_raises_on_http_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(500, text="boom")
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||
with pytest.raises(LLMUnavailableError, match="Ollama request failed"):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
|
||||
|
||||
def test_chat_json_raises_on_invalid_json(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json=_ollama_envelope("this is not json"))
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||
with pytest.raises(LLMUnavailableError, match="schema"):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
|
||||
|
||||
def test_chat_json_raises_on_missing_envelope(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(_request: httpx.Request) -> httpx.Response:
|
||||
return httpx.Response(200, json={"oops": True})
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=5)
|
||||
with pytest.raises(LLMUnavailableError, match=r"message\.content"):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
|
||||
|
||||
def test_chat_json_raises_on_connection_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def _handler(request: httpx.Request) -> httpx.Response:
|
||||
raise httpx.ConnectError("nobody home", request=request)
|
||||
|
||||
_patch_transport(monkeypatch, _handler)
|
||||
|
||||
client = OllamaClient(base_url="http://x", model="m", timeout_s=1)
|
||||
with pytest.raises(LLMUnavailableError):
|
||||
client.chat_json("s", "u", _Schema)
|
||||
90
tests/unit/test_llm_extractor.py
Normal file
90
tests/unit/test_llm_extractor.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""Unit tests for the hybrid LLM header extractor / merger."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ocr_sprint.llm.client import LLMUnavailableError, OllamaClient
|
||||
from ocr_sprint.llm.extractor import LLMHeaderResult, _merge, llm_fill_header
|
||||
from ocr_sprint.schemas.extraction import HeaderFields
|
||||
|
||||
|
||||
class _StubClient(OllamaClient):
|
||||
"""Test double that bypasses HTTP entirely."""
|
||||
|
||||
def __init__(self, payload: LLMHeaderResult | Exception) -> None:
|
||||
# Skip the real __init__ — we don't need any real config.
|
||||
self._payload = payload
|
||||
|
||||
def chat_json( # type: ignore[override]
|
||||
self, system: str, user: str, schema_cls: type[BaseModel]
|
||||
) -> BaseModel:
|
||||
if isinstance(self._payload, Exception):
|
||||
raise self._payload
|
||||
return self._payload
|
||||
|
||||
|
||||
def test_merge_keeps_regex_when_present() -> None:
|
||||
regex = HeaderFields(nomor_sprint="Sprin/123/IV/2025/Reskrim", tanggal=date(2025, 4, 21))
|
||||
llm = LLMHeaderResult(nomor_sprint="HALLUCINATED", tanggal=date(1999, 1, 1), perihal="ok")
|
||||
out = _merge(regex, llm)
|
||||
assert out.nomor_sprint == "Sprin/123/IV/2025/Reskrim"
|
||||
assert out.tanggal == date(2025, 4, 21)
|
||||
# Gaps get filled.
|
||||
assert out.perihal == "ok"
|
||||
|
||||
|
||||
def test_merge_fills_gaps() -> None:
|
||||
regex = HeaderFields() # all None
|
||||
llm = LLMHeaderResult(
|
||||
nomor_sprint="Sprin/9/IX/2024",
|
||||
tanggal=date(2024, 9, 1),
|
||||
satuan_penerbit="Polres Bandung",
|
||||
perihal="Penyelidikan",
|
||||
dasar=["UU 2/2002", "Perkap 6/2017"],
|
||||
)
|
||||
out = _merge(regex, llm)
|
||||
assert out.nomor_sprint == "Sprin/9/IX/2024"
|
||||
assert out.tanggal == date(2024, 9, 1)
|
||||
assert out.satuan_penerbit == "Polres Bandung"
|
||||
assert out.perihal == "Penyelidikan"
|
||||
assert out.dasar == ["UU 2/2002", "Perkap 6/2017"]
|
||||
|
||||
|
||||
def test_llm_fill_header_returns_merged_when_client_succeeds() -> None:
|
||||
regex = HeaderFields(nomor_sprint="Sprin/1/I/2025") # has nomor, missing rest
|
||||
stub = _StubClient(
|
||||
LLMHeaderResult(
|
||||
satuan_penerbit="Polres Bandung",
|
||||
perihal="Penyelidikan",
|
||||
dasar=["UU 2/2002"],
|
||||
)
|
||||
)
|
||||
out = llm_fill_header(raw_text="...", regex_header=regex, client=stub)
|
||||
assert out is not None
|
||||
assert out.nomor_sprint == "Sprin/1/I/2025"
|
||||
assert out.satuan_penerbit == "Polres Bandung"
|
||||
assert out.perihal == "Penyelidikan"
|
||||
assert out.dasar == ["UU 2/2002"]
|
||||
|
||||
|
||||
def test_llm_fill_header_returns_none_when_unavailable() -> None:
|
||||
stub = _StubClient(LLMUnavailableError("server down"))
|
||||
out = llm_fill_header(raw_text="...", regex_header=HeaderFields(), client=stub)
|
||||
assert out is None
|
||||
|
||||
|
||||
def test_merge_does_not_overwrite_dasar_when_regex_has_it() -> None:
|
||||
regex = HeaderFields(dasar=["UU 2/2002"])
|
||||
llm = LLMHeaderResult(dasar=["something else", "more"])
|
||||
out = _merge(regex, llm)
|
||||
assert out.dasar == ["UU 2/2002"]
|
||||
|
||||
|
||||
def test_llm_extractor_unused_argument_kept_silent() -> None:
|
||||
# A trivial sanity check that the public function signature accepts
|
||||
# keyword-only `client` — this matches how the orchestrator calls it.
|
||||
pytest.importorskip("ocr_sprint.llm.extractor")
|
||||
171
tests/unit/test_orchestrator_llm.py
Normal file
171
tests/unit/test_orchestrator_llm.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Orchestrator-level tests for the Phase 5 hybrid LLM wiring.
|
||||
|
||||
These tests stub out the heavy stages (ingest / preprocess / OCR / table)
|
||||
so we can verify the *branching* behaviour around the LLM step without
|
||||
booting Paddle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import date
|
||||
|
||||
import pytest
|
||||
|
||||
from ocr_sprint.pipeline import orchestrator as orch_module
|
||||
from ocr_sprint.pipeline.orchestrator import _header_has_gaps, run_pipeline
|
||||
from ocr_sprint.schemas.document import SourceKind
|
||||
from ocr_sprint.schemas.extraction import HeaderFields, ReviewFlag, Signatory
|
||||
|
||||
|
||||
def test_header_has_gaps_detects_missing_fields() -> None:
|
||||
full = HeaderFields(
|
||||
nomor_sprint="Sprin/1/I/2025",
|
||||
tanggal=date(2025, 1, 1),
|
||||
satuan_penerbit="Polres X",
|
||||
perihal="ok",
|
||||
dasar=["UU 2/2002"],
|
||||
)
|
||||
assert _header_has_gaps(full) is False
|
||||
|
||||
assert _header_has_gaps(HeaderFields()) is True
|
||||
assert _header_has_gaps(full.model_copy(update={"perihal": None})) is True
|
||||
assert _header_has_gaps(full.model_copy(update={"dasar": []})) is True
|
||||
|
||||
|
||||
def _stub_pipeline_stages(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
*,
|
||||
raw_text: str,
|
||||
regex_header: HeaderFields,
|
||||
) -> None:
|
||||
"""Replace ingest -> ocr -> tables with cheap fakes so the orchestrator
|
||||
runs without Paddle / PyMuPDF.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
from ocr_sprint.pipeline import ingest as ingest_module
|
||||
from ocr_sprint.pipeline import ocr as ocr_module
|
||||
from ocr_sprint.pipeline.ingest import IngestedPage
|
||||
|
||||
img = np.full((100, 100, 3), 255, dtype=np.uint8)
|
||||
fake_page = IngestedPage(image=img, page_index=0)
|
||||
fake_ocr_page = ocr_module.OCRPage(
|
||||
lines=[
|
||||
ocr_module.OCRLine(text=raw_text, confidence=0.95, box=((0, 0), (1, 0), (1, 1), (0, 1)))
|
||||
],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(orch_module, "detect_source_kind", lambda _: SourceKind.PDF)
|
||||
monkeypatch.setattr(orch_module, "ingest", lambda *a, **k: [fake_page])
|
||||
monkeypatch.setattr(orch_module, "detect_and_correct", lambda image, _cfg: image)
|
||||
monkeypatch.setattr(orch_module, "preprocess", lambda image, _cfg: image)
|
||||
monkeypatch.setattr(orch_module, "run_ocr", lambda _image: fake_ocr_page)
|
||||
# No tables in these tests.
|
||||
monkeypatch.setattr(orch_module, "run_table_extraction", lambda _img: [])
|
||||
monkeypatch.setattr(orch_module, "extract_personnel", lambda _tables: [])
|
||||
# Header / signatory / validators come from the real implementation
|
||||
# for `extract_header`, but we override to control gap state.
|
||||
monkeypatch.setattr(orch_module, "extract_header", lambda _text: regex_header)
|
||||
monkeypatch.setattr(orch_module, "find_signatory", lambda _text: Signatory())
|
||||
monkeypatch.setattr(orch_module, "validate_extraction", lambda _result: [])
|
||||
# Keep ingest_module referenced so import isn't dropped.
|
||||
assert ingest_module is not None
|
||||
|
||||
|
||||
def test_orchestrator_skips_llm_when_disabled(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "false")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
_stub_pipeline_stages(
|
||||
monkeypatch,
|
||||
raw_text="dummy",
|
||||
regex_header=HeaderFields(), # all gaps
|
||||
)
|
||||
|
||||
called = {"n": 0}
|
||||
|
||||
def _trip(*_args: object, **_kwargs: object) -> None:
|
||||
called["n"] += 1
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
|
||||
|
||||
result = run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert called["n"] == 0
|
||||
assert ReviewFlag.LLM_FALLBACK not in result.result.review_flags
|
||||
assert ReviewFlag.LLM_UNAVAILABLE not in result.result.review_flags
|
||||
|
||||
|
||||
def test_orchestrator_skips_llm_when_header_complete(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
_stub_pipeline_stages(
|
||||
monkeypatch,
|
||||
raw_text="dummy",
|
||||
regex_header=HeaderFields(
|
||||
nomor_sprint="Sprin/1/I/2025",
|
||||
tanggal=date(2025, 1, 1),
|
||||
satuan_penerbit="Polres X",
|
||||
perihal="ok",
|
||||
dasar=["UU 2/2002"],
|
||||
),
|
||||
)
|
||||
|
||||
called = {"n": 0}
|
||||
|
||||
def _trip(*_args: object, **_kwargs: object) -> None:
|
||||
called["n"] += 1
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", _trip)
|
||||
|
||||
run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert called["n"] == 0
|
||||
|
||||
|
||||
def test_orchestrator_calls_llm_and_marks_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
regex_partial = HeaderFields(nomor_sprint="Sprin/1/I/2025") # rest missing
|
||||
_stub_pipeline_stages(monkeypatch, raw_text="dummy text", regex_header=regex_partial)
|
||||
|
||||
def _llm(_raw: str, header: HeaderFields, **_: object) -> HeaderFields:
|
||||
return header.model_copy(
|
||||
update={
|
||||
"satuan_penerbit": "Polres Bandung",
|
||||
"perihal": "Penyelidikan",
|
||||
"dasar": ["UU 2/2002"],
|
||||
}
|
||||
)
|
||||
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", _llm)
|
||||
|
||||
out = run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert out.result.header.satuan_penerbit == "Polres Bandung"
|
||||
assert out.result.header.perihal == "Penyelidikan"
|
||||
assert ReviewFlag.LLM_FALLBACK in out.result.review_flags
|
||||
assert ReviewFlag.LLM_UNAVAILABLE not in out.result.review_flags
|
||||
|
||||
|
||||
def test_orchestrator_marks_unavailable_when_llm_returns_none(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setenv("LLM_ENABLED", "true")
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
get_settings.cache_clear()
|
||||
|
||||
_stub_pipeline_stages(monkeypatch, raw_text="dummy", regex_header=HeaderFields())
|
||||
monkeypatch.setattr(orch_module, "llm_fill_header", lambda *_a, **_k: None)
|
||||
|
||||
out = run_pipeline(b"%PDF-1.4\n%fake")
|
||||
assert ReviewFlag.LLM_UNAVAILABLE in out.result.review_flags
|
||||
assert ReviewFlag.LLM_FALLBACK not in out.result.review_flags
|
||||
Reference in New Issue
Block a user