Phase 4: async pipeline (Celery+Redis), Postgres job state, local-fs blob storage, API-key auth, Prometheus metrics (#3)

* Phase 4: async pipeline (Celery+Redis), Postgres job state, local-fs blob storage, API-key auth, Prometheus metrics

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>

* Phase 4: fix sync-mode rollback orphaning blobs + use is_relative_to for path-escape check

Devin Review on PR #3 found two real bugs:

1. Sync path mark_failed was rolled back by the request-scoped session.
   When the pipeline raised an exception in ?sync=true mode, _run_inline
   modified the FastAPI session and re-raised; get_session caught the
   exception, called session.rollback(), and wiped both the create() and
   the mark_failed() writes. The blob was already on disk, so it was
   permanently orphaned with no DB record. Fix: commit the pending row
   immediately after create(), and run all subsequent state transitions in
   independent session_scope blocks (matching the worker task pattern).

2. _resolve used str.startswith for path-escape detection, which lets a
   sibling directory whose name begins with the storage root pass (e.g.
   /app/blobs_evil vs /app/blobs). Switched to Path.is_relative_to.

Added regression tests for both.

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>

* Phase 4: honor queue_enabled setting + resolve base_dir for path comparisons

Two more bugs found by Devin Review:

3. queue_enabled was declared in config and documented in .env.example but
   never read by the route. A fresh dev install with QUEUE_ENABLED=false
   (the default) would still enqueue, then fail with a Redis connection
   error. Fixed by making the ?sync= query param default to None and
   resolving to (not queue_enabled) inside the route. Tests now set
   QUEUE_ENABLED=true so the async flow stays exercised, and a new test
   verifies the inline fallback when the queue is disabled.

4. LocalFsBlobStorage stored base_dir as-is. _resolve resolved its
   candidate paths, so the empty-dir cleanup loop in delete() compared a
   resolved candidate against an unresolved base_dir and broke on the
   first iteration (no cleanup ever happened). Fixed by resolving base_dir
   once in __init__ so every path comparison is apples-to-apples.

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>

* Phase 4: derive ocr_jobs_total from DB so worker writes are visible at /metrics

Devin Review correctly noted the Counter-based JOBS_TOTAL would never
increment in production because the worker runs in a separate process from
the API and the registry is process-local. Replaced JOBS_TOTAL with a
custom Collector that issues SELECT status, COUNT(*) FROM jobs GROUP BY
status on every /metrics scrape. Result: the metric stays accurate
regardless of which process wrote the row.

Also corrected the metrics.py docstring (the old comment claimed the
counter was 'incremented by the worker', which was the bug).

Removed the JOBS_TOTAL.inc() calls from the sync route — the DB collector
covers both paths now. JOB_PROCESSING_SECONDS stays as an API-process
histogram with an updated docstring noting its scope; cross-process
latency belongs to derived dashboards over jobs.created_at/updated_at.

Added regression test test_metrics_jobs_total_reflects_worker_writes.

Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com>

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
devin-ai-integration[bot]
2026-04-25 16:50:51 +00:00
committed by GitHub
parent 33b38aacc7
commit 2112023b6e
31 changed files with 1646 additions and 105 deletions

View File

@@ -0,0 +1,6 @@
"""Reusable FastAPI dependencies (auth, db session)."""
from ocr_sprint.api.deps.auth import require_api_key
from ocr_sprint.api.deps.db import get_session
__all__ = ["get_session", "require_api_key"]

View File

@@ -0,0 +1,35 @@
"""API-key authentication.
The MVP uses a static list of keys loaded from `Settings.api_keys`. This is
deliberate: the service is intended to run on-prem behind an internal
reverse proxy, with a small set of trusted clients (the police HITL UI and
internal automation). Anything more sophisticated (JWT / OAuth / mTLS) is
deferred until there's a concrete need.
When `Settings.api_keys` is empty the dependency permits all requests.
This makes the local dev experience friction-free; production deploys MUST
set at least one key — tested by `test_auth_rejects_missing_key`.
"""
from __future__ import annotations
from typing import Annotated
from fastapi import Header, HTTPException, status
from ocr_sprint.config import get_settings
async def require_api_key(
x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
) -> None:
settings = get_settings()
keys = settings.api_keys_list
if not keys:
return # auth disabled for local dev
if not x_api_key or x_api_key not in keys:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="invalid or missing API key",
headers={"WWW-Authenticate": settings.api_key_header},
)

View File

@@ -0,0 +1,23 @@
"""Per-request SQLAlchemy session dependency."""
from __future__ import annotations
from collections.abc import Iterator
from sqlalchemy.orm import Session
from ocr_sprint.db.base import get_sessionmaker
def get_session() -> Iterator[Session]:
"""Yield a session, committing on success and rolling back on errors."""
factory = get_sessionmaker()
session = factory()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()

View File

@@ -0,0 +1,114 @@
"""Prometheus metrics exposed at `/metrics`.
We keep the surface tiny on purpose:
* `http_requests_total{method,route,status}` — request count
* `http_request_duration_seconds{method,route}` — latency histogram
* `ocr_jobs_total{status}` — current count of jobs in each status, derived
from the database at scrape time. Because this is computed from the
``jobs`` table it is correct regardless of whether the writer was the API
process (sync mode) or a Celery worker process. Note that this is
technically a gauge of the cumulative-by-status count, not a strict
monotonic counter — values can decrease if rows are deleted/reaped.
* `ocr_job_processing_seconds` — pipeline wall-time histogram. Only the API
process observes events on this histogram (sync path); the worker writes
its timing into the DB row and is not exposed here.
A custom registry is used so tests can reset counters cleanly.
"""
from __future__ import annotations
import time
from collections.abc import Awaitable, Callable, Iterable
from fastapi import Request, Response
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
Counter,
Histogram,
generate_latest,
)
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.metrics_core import Metric
from sqlalchemy import func, select
from starlette.middleware.base import BaseHTTPMiddleware
from ocr_sprint.db.base import session_scope
from ocr_sprint.db.models import JobRow
from ocr_sprint.utils.logging import get_logger
_logger = get_logger(__name__)
REGISTRY = CollectorRegistry()
REQUEST_COUNT = Counter(
"http_requests_total",
"HTTP requests handled by the API",
("method", "route", "status"),
registry=REGISTRY,
)
REQUEST_LATENCY = Histogram(
"http_request_duration_seconds",
"HTTP request latency",
("method", "route"),
registry=REGISTRY,
)
JOB_PROCESSING_SECONDS = Histogram(
"ocr_job_processing_seconds",
"OCR job pipeline wall-time as observed by the API process",
registry=REGISTRY,
)
class _JobStatusCollector:
"""Custom collector that queries the ``jobs`` table on every scrape.
Because the worker runs in a separate process from the API, an in-memory
``Counter`` cannot accurately track terminal job counts — its writes
would never reach the API's ``/metrics`` endpoint. Reading from the
shared DB on each scrape keeps the metric correct across processes.
"""
def collect(self) -> Iterable[Metric]:
family = GaugeMetricFamily(
"ocr_jobs_total",
"Current count of jobs grouped by status (read from DB).",
labels=["status"],
)
try:
with session_scope() as session:
stmt = select(JobRow.status, func.count()).group_by(JobRow.status)
for status_value, count in session.execute(stmt).all():
family.add_metric([status_value], float(count))
except Exception as exc:
_logger.warning("metrics.jobs_collect_failed", error=str(exc))
return [family]
REGISTRY.register(_JobStatusCollector())
class MetricsMiddleware(BaseHTTPMiddleware):
"""Record request count + latency. The `route` label is the path
template, not the raw URL, so per-id endpoints don't blow up cardinality.
"""
async def dispatch(
self,
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
start = time.perf_counter()
response = await call_next(request)
elapsed = time.perf_counter() - start
route = request.scope.get("route")
path = getattr(route, "path", request.url.path) if route else request.url.path
REQUEST_COUNT.labels(request.method, path, str(response.status_code)).inc()
REQUEST_LATENCY.labels(request.method, path).observe(elapsed)
return response
async def metrics_endpoint() -> Response:
return Response(content=generate_latest(REGISTRY), media_type=CONTENT_TYPE_LATEST)

View File

@@ -1,58 +1,194 @@
"""Documents API — Phase 1 synchronous endpoint.
"""Documents API.
POST /documents accepts a single PDF or image upload, runs the synchronous
pipeline inline, and returns the structured result. This is suitable for
development and low-traffic production; Phase 4 will introduce an async
queue and a polling-style API at the same path.
Phase 1 shipped a single synchronous endpoint. Phase 4 adds an async
flow on top:
* `POST /documents` — async by default. Saves the upload to blob
storage, creates a `pending` job row, and
enqueues a Celery task. Returns `202` with
the job id.
* `POST /documents?sync=true` — runs the pipeline inline (the original
Phase 1 behaviour). Useful for tests and
small-volume single-tenant deploys without
a Celery worker.
* `GET /documents/{job_id}` — returns the current job state. Async
clients poll this until `status` is in a
terminal state (completed / needs_review /
failed).
"""
from __future__ import annotations
from uuid import uuid4
from typing import Annotated
from uuid import UUID, uuid4
from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
from sqlalchemy.orm import Session
from ocr_sprint.api.deps.auth import require_api_key
from ocr_sprint.api.deps.db import get_session
from ocr_sprint.api.errors import UnsupportedDocumentError
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
from ocr_sprint.config import get_settings
from ocr_sprint.db.base import session_scope
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
from ocr_sprint.pipeline.ingest import detect_source_kind
from ocr_sprint.pipeline.orchestrator import run_pipeline
from ocr_sprint.schemas.document import DocumentResponse
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
from ocr_sprint.schemas.extraction import ExtractionResult
from ocr_sprint.storage.blob import get_blob_storage
from ocr_sprint.utils.logging import get_logger
router = APIRouter(prefix="/documents", tags=["documents"])
router = APIRouter(
prefix="/documents",
tags=["documents"],
dependencies=[Depends(require_api_key)],
)
_logger = get_logger(__name__)
_MAX_UPLOAD_BYTES = 25 * 1024 * 1024 # 25 MB
# ---------- helpers ----------
@router.post("", status_code=status.HTTP_200_OK, response_model=DocumentResponse)
async def create_document(file: UploadFile = File(...)) -> DocumentResponse:
"""Run OCR + extraction synchronously on a single upload."""
def _enforce_size(content: bytes) -> None:
s = get_settings()
if not content:
raise UnsupportedDocumentError("Uploaded file is empty.")
max_bytes = s.blob_max_upload_mb * 1024 * 1024
if len(content) > max_bytes:
raise UnsupportedDocumentError(f"Uploaded file exceeds {s.blob_max_upload_mb} MB limit.")
def _row_to_response(row: object) -> DocumentResponse:
# Local import to avoid a circular import at module load time.
from ocr_sprint.db.models import JobRow
assert isinstance(row, JobRow)
status_enum = DocumentStatus(row.status)
result_obj: ExtractionResult | None = None
if row.result is not None:
result_obj = ExtractionResult.model_validate(row.result)
return DocumentResponse(
job_id=row.job_id,
status=status_enum,
confidence=row.confidence,
data=result_obj,
review_flags=list(row.review_flags or []),
error=row.error,
)
# ---------- POST ----------
@router.post("", response_model=DocumentResponse)
async def create_document(
file: Annotated[UploadFile, File(...)],
session: Annotated[Session, Depends(get_session)],
response: Response,
sync: Annotated[
bool | None,
Query(description="Run pipeline inline (skip queue). Defaults to !queue_enabled."),
] = None,
) -> DocumentResponse:
# When the queue is disabled (default for local dev), running the async
# path would try to dial Redis and fail with a 500. Auto-fall-back to the
# inline pipeline unless the caller explicitly asked for async.
if sync is None:
sync = not get_settings().queue_enabled
job_id = uuid4()
log = _logger.bind(job_id=str(job_id), filename=file.filename or "")
content = await file.read()
if not content:
raise UnsupportedDocumentError("Uploaded file is empty.")
if len(content) > _MAX_UPLOAD_BYTES:
raise UnsupportedDocumentError(
f"Uploaded file exceeds {_MAX_UPLOAD_BYTES // (1024 * 1024)} MB limit."
)
_enforce_size(content)
log.info("documents.received", size=len(content))
storage = get_blob_storage()
blob_key = storage.put(content, original_filename=file.filename)
source_kind = detect_source_kind(content)
JobRepository(session).create(
job_id=job_id,
filename=file.filename or "",
source_kind=source_kind,
blob_key=blob_key,
)
# Commit the `pending` row immediately so it is observable regardless
# of what happens next. Both code paths below open their own session
# for state transitions; that way an exception in `_run_inline` cannot
# roll back the create() (which would orphan the blob on disk).
session.commit()
log.info("documents.received", size=len(content), blob_key=blob_key, sync=sync)
if sync:
# Status code stays at the default 200; the body's `status` field
# tells the client whether the job needs review.
return await _run_inline(job_id, content)
# Async path — enqueue and return 202. The Celery worker will pick up
# the row using its own session.
from ocr_sprint.worker.tasks import process_document_task
process_document_task.delay(str(job_id))
with session_scope() as poll:
row = JobRepository(poll).get_or_raise(job_id)
body = _row_to_response(row)
response.status_code = status.HTTP_202_ACCEPTED
return body
async def _run_inline(job_id: UUID, content: bytes) -> DocumentResponse:
"""Synchronous pipeline execution.
Each state transition opens its own short session so the request-scoped
session's rollback-on-exception behaviour cannot wipe out the
``mark_failed`` write or strand the blob on disk.
"""
import time
with session_scope() as s:
JobRepository(s).mark_processing(job_id)
started = time.perf_counter()
try:
output = run_pipeline(content)
except ValueError as exc:
with session_scope() as s:
JobRepository(s).mark_failed(job_id, error=str(exc))
raise UnsupportedDocumentError(str(exc)) from exc
except Exception as exc:
with session_scope() as s:
JobRepository(s).mark_failed(job_id, error=str(exc))
raise
log.info(
"documents.completed",
status=output.status.value,
confidence=round(output.confidence, 3),
flags=[f.value for f in output.result.review_flags],
)
return DocumentResponse(
job_id=job_id,
status=output.status,
confidence=output.confidence,
data=output.result,
review_flags=[f.value for f in output.result.review_flags],
)
flags = [f.value for f in output.result.review_flags]
JOB_PROCESSING_SECONDS.observe(time.perf_counter() - started)
with session_scope() as s:
repo = JobRepository(s)
repo.mark_completed(
job_id,
status=output.status,
confidence=output.confidence,
result=output.result.model_dump(mode="json"),
review_flags=flags,
)
row = repo.get_or_raise(job_id)
return _row_to_response(row)
# ---------- GET ----------
@router.get(
"/{job_id}",
response_model=DocumentResponse,
)
async def get_document(
job_id: UUID,
session: Annotated[Session, Depends(get_session)],
) -> DocumentResponse:
repo = JobRepository(session)
try:
row = repo.get_or_raise(job_id)
except JobNotFoundError as exc:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
return _row_to_response(row)

View File

@@ -64,12 +64,29 @@ class Settings(BaseSettings):
# Async pipeline (Phase 4)
queue_enabled: bool = False
redis_url: str = "redis://localhost:6379/0"
database_url: str = "postgresql+psycopg://ocr:ocr@localhost:5432/ocr_sprint"
minio_endpoint: str = "localhost:9000"
minio_access_key: str = "minioadmin"
minio_secret_key: str = "minioadmin"
minio_bucket: str = "ocr-sprint"
minio_secure: bool = False
celery_task_default_queue: str = "ocr_sprint"
# Persistence (Phase 4). Use sqlite for local dev / tests; Postgres for
# production via docker-compose.
database_url: str = "sqlite:///./storage/ocr_sprint.sqlite"
database_echo: bool = False
# Blob storage (Phase 4). Local filesystem only for the MVP; the
# `BlobStorage` interface is designed to swap to S3/MinIO without API
# changes when needed.
blob_storage_dir: Path = Path("./storage/blobs")
blob_max_upload_mb: int = 25
# Auth (Phase 4). Comma-separated list of API keys accepted by the API.
# Empty string disables auth (only intended for local dev / tests).
# We use ``str`` rather than ``list[str]`` because pydantic-settings rejects
# a bare empty string when binding to a list type.
api_keys: str = ""
api_key_header: str = "X-API-Key"
@property
def api_keys_list(self) -> list[str]:
return [k.strip() for k in self.api_keys.split(",") if k.strip()]
@lru_cache(maxsize=1)
@@ -77,4 +94,5 @@ def get_settings() -> Settings:
"""Cached accessor so settings are loaded once per process."""
settings = Settings()
settings.storage_local_dir.mkdir(parents=True, exist_ok=True)
settings.blob_storage_dir.mkdir(parents=True, exist_ok=True)
return settings

View File

@@ -0,0 +1,14 @@
"""Persistence layer (Phase 4) — SQLAlchemy 2.0 models, session, repositories."""
from ocr_sprint.db.base import Base, get_engine, get_sessionmaker, session_scope
from ocr_sprint.db.models import JobRow
from ocr_sprint.db.repositories import JobRepository
__all__ = [
"Base",
"JobRepository",
"JobRow",
"get_engine",
"get_sessionmaker",
"session_scope",
]

67
src/ocr_sprint/db/base.py Normal file
View File

@@ -0,0 +1,67 @@
"""SQLAlchemy 2.0 engine + session factory.
We use a single global engine per process. For tests (and SQLite in dev),
StaticPool keeps the same in-memory database across connections; for
Postgres in production we use SQLAlchemy's default pool.
"""
from __future__ import annotations
from collections.abc import Iterator
from contextlib import contextmanager
from functools import lru_cache
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from sqlalchemy.pool import StaticPool
from ocr_sprint.config import get_settings
class Base(DeclarativeBase):
"""Common SQLAlchemy declarative base."""
@lru_cache(maxsize=1)
def get_engine() -> Engine:
s = get_settings()
kwargs: dict[str, object] = {
"echo": s.database_echo,
"future": True,
}
# SQLite needs special handling: same connection across threads (Celery
# eager mode + FastAPI) requires `check_same_thread=False`. For the
# ``sqlite:///:memory:`` URL we also use StaticPool to reuse the same
# underlying connection so test fixtures see committed data.
if s.database_url.startswith("sqlite"):
kwargs["connect_args"] = {"check_same_thread": False}
if ":memory:" in s.database_url or s.database_url.endswith(":memory:"):
kwargs["poolclass"] = StaticPool
return create_engine(s.database_url, **kwargs)
@lru_cache(maxsize=1)
def get_sessionmaker() -> sessionmaker[Session]:
return sessionmaker(bind=get_engine(), expire_on_commit=False, autoflush=False)
@contextmanager
def session_scope() -> Iterator[Session]:
"""Yield a SQLAlchemy session and commit/rollback at the boundary."""
factory = get_sessionmaker()
session = factory()
try:
yield session
session.commit()
except Exception:
session.rollback()
raise
finally:
session.close()
def reset_engine_cache() -> None:
"""Clear cached engine + sessionmaker. Used by tests when changing DB URL."""
get_engine.cache_clear()
get_sessionmaker.cache_clear()

View File

@@ -0,0 +1,53 @@
"""SQLAlchemy ORM models for jobs + extraction results.
We store the structured result as JSON. PaddleOCR's `raw_text` can run into
the tens of kilobytes for multi-page documents; that's well within Postgres'
JSONB row-size budget. SQLite stores it as TEXT under the hood.
Schema choice: we keep the result inline on the same row instead of a
separate `extraction_results` table. The 1:1 relationship would otherwise
add a join on every read, with no real benefit since results are immutable
once written and there's no use-case for fetching just the metadata.
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from uuid import UUID, uuid4
from sqlalchemy import JSON, DateTime, Float, String, Uuid
from sqlalchemy.orm import Mapped, mapped_column
from ocr_sprint.db.base import Base
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
class JobRow(Base):
__tablename__ = "jobs"
# SQLAlchemy 2.0's Uuid type maps to native UUID on Postgres and CHAR(32)
# on SQLite, so the same model works in both environments.
job_id: Mapped[UUID] = mapped_column(Uuid, primary_key=True, default=uuid4)
status: Mapped[str] = mapped_column(String(32), nullable=False, default="pending")
source_kind: Mapped[str] = mapped_column(String(16), nullable=False, default="unknown")
filename: Mapped[str] = mapped_column(String(512), nullable=False, default="")
blob_key: Mapped[str | None] = mapped_column(String(512), nullable=True)
confidence: Mapped[float | None] = mapped_column(Float, nullable=True)
review_flags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=_utcnow
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, default=_utcnow, onupdate=_utcnow
)
def __repr__(self) -> str:
return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"

View File

@@ -0,0 +1,96 @@
"""Thin data-access layer over the ORM.
Repositories encapsulate the SQL so the API + Celery task code never has to
know about sessions, transactions, or the row → schema mapping.
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from ocr_sprint.db.models import JobRow
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
def _utcnow() -> datetime:
return datetime.now(timezone.utc)
class JobNotFoundError(LookupError):
"""Raised by API code when GET /documents/{id} hits a missing row."""
class JobRepository:
"""SQL-backed repository for `jobs` rows."""
def __init__(self, session: Session) -> None:
self.session = session
# ---------- writes ----------
def create(
self,
*,
job_id: UUID,
filename: str,
source_kind: SourceKind,
blob_key: str,
) -> JobRow:
row = JobRow(
job_id=job_id,
status=DocumentStatus.PENDING.value,
source_kind=source_kind.value,
filename=filename,
blob_key=blob_key,
)
self.session.add(row)
self.session.flush()
return row
def mark_processing(self, job_id: UUID) -> None:
row = self._get_or_raise(job_id)
row.status = DocumentStatus.PROCESSING.value
row.updated_at = _utcnow()
def mark_completed(
self,
job_id: UUID,
*,
status: DocumentStatus,
confidence: float,
result: dict[str, Any],
review_flags: list[str],
) -> None:
row = self._get_or_raise(job_id)
row.status = status.value
row.confidence = confidence
row.result = result
row.review_flags = review_flags
row.error = None
row.updated_at = _utcnow()
def mark_failed(self, job_id: UUID, *, error: str) -> None:
row = self._get_or_raise(job_id)
row.status = DocumentStatus.FAILED.value
row.error = error[:2048]
row.updated_at = _utcnow()
# ---------- reads ----------
def get(self, job_id: UUID) -> JobRow | None:
stmt = select(JobRow).where(JobRow.job_id == job_id)
return self.session.scalar(stmt)
def get_or_raise(self, job_id: UUID) -> JobRow:
return self._get_or_raise(job_id)
def _get_or_raise(self, job_id: UUID) -> JobRow:
row = self.get(job_id)
if row is None:
raise JobNotFoundError(f"Job not found: {job_id}")
return row

View File

@@ -6,15 +6,29 @@ from fastapi import FastAPI
from ocr_sprint import __version__
from ocr_sprint.api.errors import register_error_handlers
from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint
from ocr_sprint.api.routes import documents, health
from ocr_sprint.config import get_settings
from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables)
from ocr_sprint.db.base import Base, get_engine
from ocr_sprint.utils.logging import configure_logging
def _ensure_schema() -> None:
"""Create tables if they don't exist.
Production deploys should run Alembic migrations explicitly; this is a
convenience for local dev / tests so the API works without a manual
`alembic upgrade head` step.
"""
Base.metadata.create_all(bind=get_engine())
def create_app() -> FastAPI:
"""Application factory — keeps top-level state easy to test."""
settings = get_settings()
configure_logging(settings.app_log_level)
_ensure_schema()
app = FastAPI(
title="OCR Sprint Service",
@@ -26,8 +40,10 @@ def create_app() -> FastAPI:
)
register_error_handlers(app)
app.add_middleware(MetricsMiddleware)
app.include_router(health.router, prefix="/api/v1")
app.include_router(documents.router, prefix="/api/v1")
app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False)
return app

View File

@@ -0,0 +1,5 @@
"""Pluggable blob storage. Local-fs only for the MVP."""
from ocr_sprint.storage.blob import BlobStorage, LocalFsBlobStorage, get_blob_storage
__all__ = ["BlobStorage", "LocalFsBlobStorage", "get_blob_storage"]

View File

@@ -0,0 +1,146 @@
"""Blob storage abstraction.
The MVP only ships a local-filesystem backend. The `BlobStorage` Protocol is
deliberately small (put / get / exists / delete) so that an S3- or MinIO-
backed implementation can be dropped in later without touching API code.
Layout on disk:
{blob_storage_dir}/
2026/04/25/
<uuid4>.<ext>
The date hierarchy keeps the directory listing manageable when the service
processes thousands of documents per day, and makes manual rsync-based
backup straightforward.
"""
from __future__ import annotations
from datetime import datetime, timezone
from pathlib import Path
from typing import BinaryIO, Protocol
from uuid import uuid4
from ocr_sprint.config import get_settings
from ocr_sprint.utils.logging import get_logger
_logger = get_logger(__name__)
# Map of upload extensions we'll honor when persisting blobs. Anything else
# falls back to `.bin` and the OCR pipeline's magic-byte sniffing handles
# the actual content kind.
_KNOWN_EXTS = {".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff", ".webp"}
class BlobStorage(Protocol):
"""Minimal interface a blob backend must satisfy."""
def put(self, content: bytes, original_filename: str | None = None) -> str:
"""Persist `content` and return an opaque key the caller can use later."""
def get(self, key: str) -> bytes:
"""Return the raw bytes for `key`. Raises FileNotFoundError on miss."""
def open(self, key: str) -> BinaryIO:
"""Return a binary file-like object for streaming reads."""
def exists(self, key: str) -> bool:
"""True if `key` is currently stored."""
def delete(self, key: str) -> None:
"""Remove a blob. No-op if it doesn't exist."""
class LocalFsBlobStorage:
"""Filesystem-backed implementation rooted at `base_dir`."""
def __init__(self, base_dir: Path) -> None:
# Resolve once so every subsequent path comparison (escape check,
# empty-dir cleanup) is apples-to-apples — ``Path.parents`` of a
# resolved key would otherwise never equal a relative ``base_dir``.
base_dir.mkdir(parents=True, exist_ok=True)
self.base_dir = base_dir.resolve()
# ---------- helpers ----------
@staticmethod
def _safe_ext(original_filename: str | None) -> str:
if not original_filename:
return ".bin"
suffix = Path(original_filename).suffix.lower()
return suffix if suffix in _KNOWN_EXTS else ".bin"
def _resolve(self, key: str) -> Path:
# Defensive: keys come from the DB but we still reject paths that try
# to escape the blob root. ``Path.is_relative_to`` does proper path
# containment — string ``startswith`` would let ``/app/blobs_evil``
# slip past when the root is ``/app/blobs``.
candidate = (self.base_dir / key).resolve()
if not candidate.is_relative_to(self.base_dir):
raise ValueError(f"Blob key escapes storage root: {key!r}")
return candidate
# ---------- BlobStorage protocol ----------
def put(self, content: bytes, original_filename: str | None = None) -> str:
now = datetime.now(timezone.utc)
date_dir = Path(f"{now:%Y/%m/%d}")
ext = self._safe_ext(original_filename)
key = str(date_dir / f"{uuid4().hex}{ext}")
target = self._resolve(key)
target.parent.mkdir(parents=True, exist_ok=True)
# Write to a temp file in the same directory then rename. This avoids
# a half-written blob being read by a concurrent worker.
tmp = target.with_suffix(target.suffix + ".tmp")
tmp.write_bytes(content)
tmp.rename(target)
_logger.info("blob.put", key=key, size=len(content))
return key
def get(self, key: str) -> bytes:
path = self._resolve(key)
if not path.exists():
raise FileNotFoundError(f"Blob not found: {key}")
return path.read_bytes()
def open(self, key: str) -> BinaryIO:
path = self._resolve(key)
if not path.exists():
raise FileNotFoundError(f"Blob not found: {key}")
return path.open("rb")
def exists(self, key: str) -> bool:
try:
return self._resolve(key).exists()
except ValueError:
return False
def delete(self, key: str) -> None:
try:
path = self._resolve(key)
except ValueError:
return
if path.exists():
path.unlink()
_logger.info("blob.delete", key=key)
# Best-effort cleanup of empty date dirs so we don't accumulate
# 365 directories per year forever. ``self.base_dir`` is already
# resolved (see __init__), so it can be compared against
# ``path.parents`` directly.
for parent in path.parents:
if parent == self.base_dir or self.base_dir not in parent.parents:
break
try:
parent.rmdir()
except OSError:
break
def get_blob_storage() -> BlobStorage:
"""Build the configured blob backend. Single-process cache lives in `Settings`."""
s = get_settings()
return LocalFsBlobStorage(s.blob_storage_dir)
__all__ = ["BlobStorage", "LocalFsBlobStorage", "get_blob_storage"]

View File

@@ -0,0 +1,6 @@
"""Celery worker (Phase 4) — async OCR pipeline."""
from ocr_sprint.worker.celery_app import celery_app
from ocr_sprint.worker.tasks import process_document_task
__all__ = ["celery_app", "process_document_task"]

View File

@@ -0,0 +1,49 @@
"""Celery application factory.
The broker and result backend are both Redis. We deliberately don't use
Postgres as the result backend — Celery's pg result-backend creates noisy
schema artefacts and we already store the structured result on the `jobs`
table.
Tasks register themselves by calling `celery_app.task` in `tasks.py`. Eager
mode (used in tests) is enabled by setting `CELERY_TASK_ALWAYS_EAGER=true`
in the environment.
"""
from __future__ import annotations
import os
from celery import Celery
from ocr_sprint.config import get_settings
def build_celery_app() -> Celery:
settings = get_settings()
app = Celery(
"ocr_sprint",
broker=settings.redis_url,
backend=settings.redis_url,
include=["ocr_sprint.worker.tasks"],
)
app.conf.update(
task_default_queue=settings.celery_task_default_queue,
task_serializer="json",
result_serializer="json",
accept_content=["json"],
timezone="UTC",
enable_utc=True,
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_prefetch_multiplier=1, # OCR is CPU-bound; one task at a time.
broker_connection_retry_on_startup=True,
result_expires=24 * 3600, # results live in DB; redis is just a cache
)
if os.getenv("CELERY_TASK_ALWAYS_EAGER", "").lower() in {"1", "true", "yes"}:
app.conf.task_always_eager = True
app.conf.task_eager_propagates = True
return app
celery_app = build_celery_app()

View File

@@ -0,0 +1,84 @@
"""Celery tasks.
`process_document_task` is the single async entrypoint: the API enqueues
one task per upload, the worker pulls the blob, runs the orchestrator, and
writes the result back to the `jobs` table.
Tasks must be idempotent at the boundary: if the worker crashes mid-OCR,
the next retry should pick up the same blob and produce the same result.
We therefore re-fetch the row by id on every transition rather than
threading state through closures.
"""
from __future__ import annotations
from uuid import UUID
from celery.exceptions import Reject
from ocr_sprint.db.base import session_scope
from ocr_sprint.db.repositories import JobRepository
from ocr_sprint.pipeline.orchestrator import run_pipeline
from ocr_sprint.storage.blob import get_blob_storage
from ocr_sprint.utils.logging import get_logger
from ocr_sprint.worker.celery_app import celery_app
_logger = get_logger(__name__)
@celery_app.task(name="ocr_sprint.process_document", bind=True, max_retries=0) # type: ignore[untyped-decorator]
def process_document_task(self: object, job_id_str: str) -> str:
"""Run the OCR pipeline for `job_id` and persist the structured result.
Returns the final job status as a string so callers can wait on
`AsyncResult.get()` if they want to (mainly tests).
"""
job_id = UUID(job_id_str)
log = _logger.bind(job_id=job_id_str)
storage = get_blob_storage()
with session_scope() as session:
repo = JobRepository(session)
row = repo.get(job_id)
if row is None:
raise Reject(f"Job {job_id_str} not found", requeue=False)
repo.mark_processing(job_id)
blob_key = row.blob_key
if not blob_key:
with session_scope() as session:
JobRepository(session).mark_failed(job_id, error="missing blob_key")
return "failed"
try:
content = storage.get(blob_key)
except FileNotFoundError as exc:
log.error("worker.blob_missing", error=str(exc))
with session_scope() as session:
JobRepository(session).mark_failed(job_id, error=f"blob missing: {exc}")
return "failed"
try:
output = run_pipeline(content)
except Exception as exc:
log.exception("worker.pipeline_error")
with session_scope() as session:
JobRepository(session).mark_failed(job_id, error=str(exc))
return "failed"
flags = [f.value for f in output.result.review_flags]
log.info(
"worker.completed",
status=output.status.value,
confidence=round(output.confidence, 3),
flags=flags,
)
with session_scope() as session:
JobRepository(session).mark_completed(
job_id,
status=output.status,
confidence=output.confidence,
result=output.result.model_dump(mode="json"),
review_flags=flags,
)
return output.status.value