Phase 4: async pipeline (Celery+Redis), Postgres job state, local-fs blob storage, API-key auth, Prometheus metrics (#3)
* Phase 4: async pipeline (Celery+Redis), Postgres job state, local-fs blob storage, API-key auth, Prometheus metrics Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> * Phase 4: fix sync-mode rollback orphaning blobs + use is_relative_to for path-escape check Devin Review on PR #3 found two real bugs: 1. Sync path mark_failed was rolled back by the request-scoped session. When the pipeline raised an exception in ?sync=true mode, _run_inline modified the FastAPI session and re-raised; get_session caught the exception, called session.rollback(), and wiped both the create() and the mark_failed() writes. The blob was already on disk, so it was permanently orphaned with no DB record. Fix: commit the pending row immediately after create(), and run all subsequent state transitions in independent session_scope blocks (matching the worker task pattern). 2. _resolve used str.startswith for path-escape detection, which lets a sibling directory whose name begins with the storage root pass (e.g. /app/blobs_evil vs /app/blobs). Switched to Path.is_relative_to. Added regression tests for both. Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> * Phase 4: honor queue_enabled setting + resolve base_dir for path comparisons Two more bugs found by Devin Review: 3. queue_enabled was declared in config and documented in .env.example but never read by the route. A fresh dev install with QUEUE_ENABLED=false (the default) would still enqueue, then fail with a Redis connection error. Fixed by making the ?sync= query param default to None and resolving to (not queue_enabled) inside the route. Tests now set QUEUE_ENABLED=true so the async flow stays exercised, and a new test verifies the inline fallback when the queue is disabled. 4. LocalFsBlobStorage stored base_dir as-is. _resolve resolved its candidate paths, so the empty-dir cleanup loop in delete() compared a resolved candidate against an unresolved base_dir and broke on the first iteration (no cleanup ever happened). Fixed by resolving base_dir once in __init__ so every path comparison is apples-to-apples. Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> * Phase 4: derive ocr_jobs_total from DB so worker writes are visible at /metrics Devin Review correctly noted the Counter-based JOBS_TOTAL would never increment in production because the worker runs in a separate process from the API and the registry is process-local. Replaced JOBS_TOTAL with a custom Collector that issues SELECT status, COUNT(*) FROM jobs GROUP BY status on every /metrics scrape. Result: the metric stays accurate regardless of which process wrote the row. Also corrected the metrics.py docstring (the old comment claimed the counter was 'incremented by the worker', which was the bug). Removed the JOBS_TOTAL.inc() calls from the sync route — the DB collector covers both paths now. JOB_PROCESSING_SECONDS stays as an API-process histogram with an updated docstring noting its scope; cross-process latency belongs to derived dashboards over jobs.created_at/updated_at. Added regression test test_metrics_jobs_total_reflects_worker_writes. Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
committed by
GitHub
parent
33b38aacc7
commit
2112023b6e
6
src/ocr_sprint/api/deps/__init__.py
Normal file
6
src/ocr_sprint/api/deps/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Reusable FastAPI dependencies (auth, db session)."""
|
||||
|
||||
from ocr_sprint.api.deps.auth import require_api_key
|
||||
from ocr_sprint.api.deps.db import get_session
|
||||
|
||||
__all__ = ["get_session", "require_api_key"]
|
||||
35
src/ocr_sprint/api/deps/auth.py
Normal file
35
src/ocr_sprint/api/deps/auth.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""API-key authentication.
|
||||
|
||||
The MVP uses a static list of keys loaded from `Settings.api_keys`. This is
|
||||
deliberate: the service is intended to run on-prem behind an internal
|
||||
reverse proxy, with a small set of trusted clients (the police HITL UI and
|
||||
internal automation). Anything more sophisticated (JWT / OAuth / mTLS) is
|
||||
deferred until there's a concrete need.
|
||||
|
||||
When `Settings.api_keys` is empty the dependency permits all requests.
|
||||
This makes the local dev experience friction-free; production deploys MUST
|
||||
set at least one key — tested by `test_auth_rejects_missing_key`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Header, HTTPException, status
|
||||
|
||||
from ocr_sprint.config import get_settings
|
||||
|
||||
|
||||
async def require_api_key(
|
||||
x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
|
||||
) -> None:
|
||||
settings = get_settings()
|
||||
keys = settings.api_keys_list
|
||||
if not keys:
|
||||
return # auth disabled for local dev
|
||||
if not x_api_key or x_api_key not in keys:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="invalid or missing API key",
|
||||
headers={"WWW-Authenticate": settings.api_key_header},
|
||||
)
|
||||
23
src/ocr_sprint/api/deps/db.py
Normal file
23
src/ocr_sprint/api/deps/db.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Per-request SQLAlchemy session dependency."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.db.base import get_sessionmaker
|
||||
|
||||
|
||||
def get_session() -> Iterator[Session]:
|
||||
"""Yield a session, committing on success and rolling back on errors."""
|
||||
factory = get_sessionmaker()
|
||||
session = factory()
|
||||
try:
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
114
src/ocr_sprint/api/metrics.py
Normal file
114
src/ocr_sprint/api/metrics.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""Prometheus metrics exposed at `/metrics`.
|
||||
|
||||
We keep the surface tiny on purpose:
|
||||
|
||||
* `http_requests_total{method,route,status}` — request count
|
||||
* `http_request_duration_seconds{method,route}` — latency histogram
|
||||
* `ocr_jobs_total{status}` — current count of jobs in each status, derived
|
||||
from the database at scrape time. Because this is computed from the
|
||||
``jobs`` table it is correct regardless of whether the writer was the API
|
||||
process (sync mode) or a Celery worker process. Note that this is
|
||||
technically a gauge of the cumulative-by-status count, not a strict
|
||||
monotonic counter — values can decrease if rows are deleted/reaped.
|
||||
* `ocr_job_processing_seconds` — pipeline wall-time histogram. Only the API
|
||||
process observes events on this histogram (sync path); the worker writes
|
||||
its timing into the DB row and is not exposed here.
|
||||
|
||||
A custom registry is used so tests can reset counters cleanly.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable, Iterable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from prometheus_client import (
|
||||
CONTENT_TYPE_LATEST,
|
||||
CollectorRegistry,
|
||||
Counter,
|
||||
Histogram,
|
||||
generate_latest,
|
||||
)
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.metrics_core import Metric
|
||||
from sqlalchemy import func, select
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from ocr_sprint.db.base import session_scope
|
||||
from ocr_sprint.db.models import JobRow
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
REGISTRY = CollectorRegistry()
|
||||
|
||||
REQUEST_COUNT = Counter(
|
||||
"http_requests_total",
|
||||
"HTTP requests handled by the API",
|
||||
("method", "route", "status"),
|
||||
registry=REGISTRY,
|
||||
)
|
||||
REQUEST_LATENCY = Histogram(
|
||||
"http_request_duration_seconds",
|
||||
"HTTP request latency",
|
||||
("method", "route"),
|
||||
registry=REGISTRY,
|
||||
)
|
||||
JOB_PROCESSING_SECONDS = Histogram(
|
||||
"ocr_job_processing_seconds",
|
||||
"OCR job pipeline wall-time as observed by the API process",
|
||||
registry=REGISTRY,
|
||||
)
|
||||
|
||||
|
||||
class _JobStatusCollector:
|
||||
"""Custom collector that queries the ``jobs`` table on every scrape.
|
||||
|
||||
Because the worker runs in a separate process from the API, an in-memory
|
||||
``Counter`` cannot accurately track terminal job counts — its writes
|
||||
would never reach the API's ``/metrics`` endpoint. Reading from the
|
||||
shared DB on each scrape keeps the metric correct across processes.
|
||||
"""
|
||||
|
||||
def collect(self) -> Iterable[Metric]:
|
||||
family = GaugeMetricFamily(
|
||||
"ocr_jobs_total",
|
||||
"Current count of jobs grouped by status (read from DB).",
|
||||
labels=["status"],
|
||||
)
|
||||
try:
|
||||
with session_scope() as session:
|
||||
stmt = select(JobRow.status, func.count()).group_by(JobRow.status)
|
||||
for status_value, count in session.execute(stmt).all():
|
||||
family.add_metric([status_value], float(count))
|
||||
except Exception as exc:
|
||||
_logger.warning("metrics.jobs_collect_failed", error=str(exc))
|
||||
return [family]
|
||||
|
||||
|
||||
REGISTRY.register(_JobStatusCollector())
|
||||
|
||||
|
||||
class MetricsMiddleware(BaseHTTPMiddleware):
|
||||
"""Record request count + latency. The `route` label is the path
|
||||
template, not the raw URL, so per-id endpoints don't blow up cardinality.
|
||||
"""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
start = time.perf_counter()
|
||||
response = await call_next(request)
|
||||
elapsed = time.perf_counter() - start
|
||||
route = request.scope.get("route")
|
||||
path = getattr(route, "path", request.url.path) if route else request.url.path
|
||||
REQUEST_COUNT.labels(request.method, path, str(response.status_code)).inc()
|
||||
REQUEST_LATENCY.labels(request.method, path).observe(elapsed)
|
||||
return response
|
||||
|
||||
|
||||
async def metrics_endpoint() -> Response:
|
||||
return Response(content=generate_latest(REGISTRY), media_type=CONTENT_TYPE_LATEST)
|
||||
@@ -1,58 +1,194 @@
|
||||
"""Documents API — Phase 1 synchronous endpoint.
|
||||
"""Documents API.
|
||||
|
||||
POST /documents accepts a single PDF or image upload, runs the synchronous
|
||||
pipeline inline, and returns the structured result. This is suitable for
|
||||
development and low-traffic production; Phase 4 will introduce an async
|
||||
queue and a polling-style API at the same path.
|
||||
Phase 1 shipped a single synchronous endpoint. Phase 4 adds an async
|
||||
flow on top:
|
||||
|
||||
* `POST /documents` — async by default. Saves the upload to blob
|
||||
storage, creates a `pending` job row, and
|
||||
enqueues a Celery task. Returns `202` with
|
||||
the job id.
|
||||
* `POST /documents?sync=true` — runs the pipeline inline (the original
|
||||
Phase 1 behaviour). Useful for tests and
|
||||
small-volume single-tenant deploys without
|
||||
a Celery worker.
|
||||
* `GET /documents/{job_id}` — returns the current job state. Async
|
||||
clients poll this until `status` is in a
|
||||
terminal state (completed / needs_review /
|
||||
failed).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import uuid4
|
||||
from typing import Annotated
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from fastapi import APIRouter, File, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ocr_sprint.api.deps.auth import require_api_key
|
||||
from ocr_sprint.api.deps.db import get_session
|
||||
from ocr_sprint.api.errors import UnsupportedDocumentError
|
||||
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
|
||||
from ocr_sprint.config import get_settings
|
||||
from ocr_sprint.db.base import session_scope
|
||||
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
|
||||
from ocr_sprint.pipeline.ingest import detect_source_kind
|
||||
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
||||
from ocr_sprint.schemas.document import DocumentResponse
|
||||
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
|
||||
from ocr_sprint.schemas.extraction import ExtractionResult
|
||||
from ocr_sprint.storage.blob import get_blob_storage
|
||||
from ocr_sprint.utils.logging import get_logger
|
||||
|
||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
||||
router = APIRouter(
|
||||
prefix="/documents",
|
||||
tags=["documents"],
|
||||
dependencies=[Depends(require_api_key)],
|
||||
)
|
||||
_logger = get_logger(__name__)
|
||||
|
||||
_MAX_UPLOAD_BYTES = 25 * 1024 * 1024 # 25 MB
|
||||
|
||||
# ---------- helpers ----------
|
||||
|
||||
|
||||
@router.post("", status_code=status.HTTP_200_OK, response_model=DocumentResponse)
|
||||
async def create_document(file: UploadFile = File(...)) -> DocumentResponse:
|
||||
"""Run OCR + extraction synchronously on a single upload."""
|
||||
def _enforce_size(content: bytes) -> None:
|
||||
s = get_settings()
|
||||
if not content:
|
||||
raise UnsupportedDocumentError("Uploaded file is empty.")
|
||||
max_bytes = s.blob_max_upload_mb * 1024 * 1024
|
||||
if len(content) > max_bytes:
|
||||
raise UnsupportedDocumentError(f"Uploaded file exceeds {s.blob_max_upload_mb} MB limit.")
|
||||
|
||||
|
||||
def _row_to_response(row: object) -> DocumentResponse:
|
||||
# Local import to avoid a circular import at module load time.
|
||||
from ocr_sprint.db.models import JobRow
|
||||
|
||||
assert isinstance(row, JobRow)
|
||||
status_enum = DocumentStatus(row.status)
|
||||
result_obj: ExtractionResult | None = None
|
||||
if row.result is not None:
|
||||
result_obj = ExtractionResult.model_validate(row.result)
|
||||
return DocumentResponse(
|
||||
job_id=row.job_id,
|
||||
status=status_enum,
|
||||
confidence=row.confidence,
|
||||
data=result_obj,
|
||||
review_flags=list(row.review_flags or []),
|
||||
error=row.error,
|
||||
)
|
||||
|
||||
|
||||
# ---------- POST ----------
|
||||
|
||||
|
||||
@router.post("", response_model=DocumentResponse)
|
||||
async def create_document(
|
||||
file: Annotated[UploadFile, File(...)],
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
response: Response,
|
||||
sync: Annotated[
|
||||
bool | None,
|
||||
Query(description="Run pipeline inline (skip queue). Defaults to !queue_enabled."),
|
||||
] = None,
|
||||
) -> DocumentResponse:
|
||||
# When the queue is disabled (default for local dev), running the async
|
||||
# path would try to dial Redis and fail with a 500. Auto-fall-back to the
|
||||
# inline pipeline unless the caller explicitly asked for async.
|
||||
if sync is None:
|
||||
sync = not get_settings().queue_enabled
|
||||
|
||||
job_id = uuid4()
|
||||
log = _logger.bind(job_id=str(job_id), filename=file.filename or "")
|
||||
|
||||
content = await file.read()
|
||||
if not content:
|
||||
raise UnsupportedDocumentError("Uploaded file is empty.")
|
||||
if len(content) > _MAX_UPLOAD_BYTES:
|
||||
raise UnsupportedDocumentError(
|
||||
f"Uploaded file exceeds {_MAX_UPLOAD_BYTES // (1024 * 1024)} MB limit."
|
||||
)
|
||||
_enforce_size(content)
|
||||
|
||||
log.info("documents.received", size=len(content))
|
||||
storage = get_blob_storage()
|
||||
blob_key = storage.put(content, original_filename=file.filename)
|
||||
source_kind = detect_source_kind(content)
|
||||
JobRepository(session).create(
|
||||
job_id=job_id,
|
||||
filename=file.filename or "",
|
||||
source_kind=source_kind,
|
||||
blob_key=blob_key,
|
||||
)
|
||||
# Commit the `pending` row immediately so it is observable regardless
|
||||
# of what happens next. Both code paths below open their own session
|
||||
# for state transitions; that way an exception in `_run_inline` cannot
|
||||
# roll back the create() (which would orphan the blob on disk).
|
||||
session.commit()
|
||||
log.info("documents.received", size=len(content), blob_key=blob_key, sync=sync)
|
||||
|
||||
if sync:
|
||||
# Status code stays at the default 200; the body's `status` field
|
||||
# tells the client whether the job needs review.
|
||||
return await _run_inline(job_id, content)
|
||||
|
||||
# Async path — enqueue and return 202. The Celery worker will pick up
|
||||
# the row using its own session.
|
||||
from ocr_sprint.worker.tasks import process_document_task
|
||||
|
||||
process_document_task.delay(str(job_id))
|
||||
with session_scope() as poll:
|
||||
row = JobRepository(poll).get_or_raise(job_id)
|
||||
body = _row_to_response(row)
|
||||
response.status_code = status.HTTP_202_ACCEPTED
|
||||
return body
|
||||
|
||||
|
||||
async def _run_inline(job_id: UUID, content: bytes) -> DocumentResponse:
|
||||
"""Synchronous pipeline execution.
|
||||
|
||||
Each state transition opens its own short session so the request-scoped
|
||||
session's rollback-on-exception behaviour cannot wipe out the
|
||||
``mark_failed`` write or strand the blob on disk.
|
||||
"""
|
||||
import time
|
||||
|
||||
with session_scope() as s:
|
||||
JobRepository(s).mark_processing(job_id)
|
||||
|
||||
started = time.perf_counter()
|
||||
try:
|
||||
output = run_pipeline(content)
|
||||
except ValueError as exc:
|
||||
with session_scope() as s:
|
||||
JobRepository(s).mark_failed(job_id, error=str(exc))
|
||||
raise UnsupportedDocumentError(str(exc)) from exc
|
||||
except Exception as exc:
|
||||
with session_scope() as s:
|
||||
JobRepository(s).mark_failed(job_id, error=str(exc))
|
||||
raise
|
||||
|
||||
log.info(
|
||||
"documents.completed",
|
||||
status=output.status.value,
|
||||
confidence=round(output.confidence, 3),
|
||||
flags=[f.value for f in output.result.review_flags],
|
||||
)
|
||||
return DocumentResponse(
|
||||
job_id=job_id,
|
||||
status=output.status,
|
||||
confidence=output.confidence,
|
||||
data=output.result,
|
||||
review_flags=[f.value for f in output.result.review_flags],
|
||||
)
|
||||
flags = [f.value for f in output.result.review_flags]
|
||||
JOB_PROCESSING_SECONDS.observe(time.perf_counter() - started)
|
||||
with session_scope() as s:
|
||||
repo = JobRepository(s)
|
||||
repo.mark_completed(
|
||||
job_id,
|
||||
status=output.status,
|
||||
confidence=output.confidence,
|
||||
result=output.result.model_dump(mode="json"),
|
||||
review_flags=flags,
|
||||
)
|
||||
row = repo.get_or_raise(job_id)
|
||||
return _row_to_response(row)
|
||||
|
||||
|
||||
# ---------- GET ----------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{job_id}",
|
||||
response_model=DocumentResponse,
|
||||
)
|
||||
async def get_document(
|
||||
job_id: UUID,
|
||||
session: Annotated[Session, Depends(get_session)],
|
||||
) -> DocumentResponse:
|
||||
repo = JobRepository(session)
|
||||
try:
|
||||
row = repo.get_or_raise(job_id)
|
||||
except JobNotFoundError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||
return _row_to_response(row)
|
||||
|
||||
Reference in New Issue
Block a user