diff --git a/.env.example b/.env.example index 07585c9..68b2964 100644 --- a/.env.example +++ b/.env.example @@ -40,12 +40,20 @@ LLM_MODEL=qwen2.5:1.5b # CPU-friendly default LLM_BASE_URL=http://localhost:11434 LLM_TIMEOUT_S=60 -# ==== Async pipeline (Phase 4, optional) ==== -QUEUE_ENABLED=false +# ==== Async pipeline + persistence (Phase 4) ==== +QUEUE_ENABLED=false # POST /documents queues async when true REDIS_URL=redis://localhost:6379/0 -DATABASE_URL=postgresql+psycopg://ocr:ocr@localhost:5432/ocr_sprint -MINIO_ENDPOINT=localhost:9000 -MINIO_ACCESS_KEY=minioadmin -MINIO_SECRET_KEY=minioadmin -MINIO_BUCKET=ocr-sprint -MINIO_SECURE=false +CELERY_TASK_DEFAULT_QUEUE=ocr_sprint + +# Persistence: sqlite for local dev, Postgres for production via docker-compose. +DATABASE_URL=sqlite:///./storage/ocr_sprint.sqlite +DATABASE_ECHO=false + +# Blob storage: local filesystem only for the MVP (no S3/MinIO). +BLOB_STORAGE_DIR=./storage/blobs +BLOB_MAX_UPLOAD_MB=25 + +# Auth: comma-separated list of accepted API keys. Empty = auth disabled +# (local dev only; production must set at least one). +API_KEYS= +API_KEY_HEADER=X-API-Key diff --git a/.gitignore b/.gitignore index 9897bab..3035456 100644 --- a/.gitignore +++ b/.gitignore @@ -48,7 +48,10 @@ samples/*.tif samples/*.tiff !samples/README.md data/local/ -storage/ +# Runtime data dirs (blobs, sqlite). The src tree's `storage` package is a +# real Python module — see the `!` rule below. +/storage/ +!src/ocr_sprint/storage/ *.db *.sqlite *.sqlite3 diff --git a/Dockerfile b/Dockerfile index 110cf97..25064fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -28,17 +28,22 @@ WORKDIR /app FROM base AS builder COPY pyproject.toml README.md ./ COPY src/ ./src/ -RUN pip install --upgrade pip && pip install ".[dev]" +# `[ocr]` pulls Paddle wheels (~1.5 GB). `[dev]` keeps test+lint deps so +# that `make test` works inside the image. +RUN pip install --upgrade pip && pip install ".[ocr,dev]" # ----- runtime layer ----- FROM base AS runtime COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages COPY --from=builder /usr/local/bin /usr/local/bin -COPY pyproject.toml README.md ./ +COPY pyproject.toml README.md alembic.ini ./ COPY src/ ./src/ +COPY alembic/ ./alembic/ -# Pre-create cache dirs so PaddleOCR can write models on first run. -RUN mkdir -p /home/app/.paddleocr /app/storage \ +# Pre-create cache dirs so PaddleOCR can write models on first run, and +# the blob storage root so the API can write uploads as the unprivileged +# `app` user. +RUN mkdir -p /home/app/.paddleocr /app/storage/blobs \ && useradd --create-home --uid 1000 app \ && chown -R app:app /home/app /app diff --git a/README.md b/README.md index c952258..c5ab382 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ OCR + structured extraction service for Indonesian police "surat sprint" (surat perintah) documents. Built around **FastAPI + PaddleOCR + hybrid extraction (regex → LLM lokal → validation)** with **on-premise** deployment as a hard requirement. -> **Status:** Phase 1+2+3 — synchronous PDF/image OCR with regex header extraction, validation, confidence scoring, document detection + perspective correction + shadow removal for phone photos, and **PP-Structure table extraction** for personnel rows. Phase 4–6 (async pipeline, LLM extraction, HITL) are tracked in [`docs/architecture.md`](docs/architecture.md). +> **Status:** Phase 1–4 — synchronous + async PDF/image OCR with regex header extraction, PP-Structure personnel-table extraction, validation, confidence scoring, document detection / perspective correction / shadow removal, **Celery + Redis job queue, Postgres job state, local-filesystem blob storage, API-key auth, and Prometheus metrics**. Phase 5–6 (LLM extraction, HITL) are tracked in [`docs/architecture.md`](docs/architecture.md). ## Why this stack @@ -29,6 +29,7 @@ cd ocr-sprint-service python -m venv .venv && source .venv/bin/activate make install # installs runtime + dev deps + pre-commit +pip install -e ".[ocr]" # only on the worker host — pulls Paddle wheels (~1.5 GB) cp .env.example .env # edit if you need GPU / different storage path ``` @@ -41,8 +42,21 @@ make dev ### Try it out +The default `POST /documents` is async — it returns `202 Accepted` with a `job_id` and the worker fills in the result. For tests / local one-shot usage you can append `?sync=true` to run inline. + ```bash -curl -F "file=@samples/pdf/example.pdf" http://localhost:8000/api/v1/documents | jq +# Async (production flow) +curl -F "file=@samples/pdf/example.pdf" \ + -H "X-API-Key: $API_KEY" \ + http://localhost:8000/api/v1/documents | jq +# → {"job_id":"8f2a...","status":"pending",...} + +curl -H "X-API-Key: $API_KEY" \ + http://localhost:8000/api/v1/documents/8f2a... | jq + +# Sync (single small doc, no worker required) +curl -F "file=@samples/pdf/example.pdf" \ + "http://localhost:8000/api/v1/documents?sync=true" | jq ``` Expected response (truncated): @@ -71,13 +85,17 @@ Expected response (truncated): ### Docker +The Phase 4 stack runs four services: `api`, `worker` (Celery), `redis`, and `postgres`. Blob uploads are persisted to a Docker volume — there is **no MinIO/S3** dependency. + ```bash docker compose build docker compose up -d -docker compose logs -f api +docker compose logs -f api worker ``` -The first request will trigger PaddleOCR to download its detection/recognition/cls models (~200 MB) into the `paddle-models` volume. +The API container runs `alembic upgrade head` on start, so the `jobs` table is created on first boot. The first request will trigger PaddleOCR to download its detection/recognition/cls models (~200 MB) into the `paddle-models` volume. + +Metrics are exposed at in Prometheus text format. ## Development @@ -114,7 +132,7 @@ docs/ # architecture & decision records | 1 | Sync API, PDF/image ingest, basic preprocessing, PaddleOCR, regex header extraction, validation, confidence scoring | **Done** | | 2 | OpenCV-based document detection, perspective transform, shadow removal for phone photos | **Done** | | 3 | PP-Structure table extraction for personnel rows + column mapper | **Done** | -| 4 | Async pipeline (Celery + Redis), Postgres + MinIO, auth, observability | Planned | +| 4 | Async pipeline (Celery + Redis), Postgres job state, local-filesystem blob storage, API-key auth, Prometheus metrics | **Done** | | 5 | LLM hybrid extraction (Ollama + structured output) | Planned | | 6 | HITL review endpoints + audit trail | Planned | diff --git a/alembic.ini b/alembic.ini new file mode 100644 index 0000000..6eefbea --- /dev/null +++ b/alembic.ini @@ -0,0 +1,38 @@ +[alembic] +script_location = alembic +prepend_sys_path = src +sqlalchemy.url = + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 0000000..625e429 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,56 @@ +"""Alembic environment. + +The DB URL is taken from `Settings.database_url`, which respects the same +`.env` and environment variables as the main app. CLI users can override +with `-x sqlalchemy.url=...` if they need to point at a different DB. +""" + +from __future__ import annotations + +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import engine_from_config, pool + +from ocr_sprint.config import get_settings +from ocr_sprint.db.models import Base + +config = context.config +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +settings = get_settings() +config.set_main_option("sqlalchemy.url", settings.database_url) + +target_metadata = Base.metadata + + +def run_migrations_offline() -> None: + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + connectable = engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + with connectable.connect() as connection: + context.configure(connection=connection, target_metadata=target_metadata, compare_type=True) + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 0000000..5716ac7 --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,25 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} +""" +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: str | None = ${repr(down_revision)} +branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} +depends_on: str | Sequence[str] | None = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py b/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py new file mode 100644 index 0000000..8ffd0ab --- /dev/null +++ b/alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py @@ -0,0 +1,42 @@ +"""phase4 jobs table + +Revision ID: ff8c14fbf8a0 +Revises: +Create Date: 2026-04-25 15:54:18.579147 +""" +from collections.abc import Sequence + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'ff8c14fbf8a0' +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('jobs', + sa.Column('job_id', sa.Uuid(), nullable=False), + sa.Column('status', sa.String(length=32), nullable=False), + sa.Column('source_kind', sa.String(length=16), nullable=False), + sa.Column('filename', sa.String(length=512), nullable=False), + sa.Column('blob_key', sa.String(length=512), nullable=True), + sa.Column('confidence', sa.Float(), nullable=True), + sa.Column('review_flags', sa.JSON(), nullable=False), + sa.Column('result', sa.JSON(), nullable=True), + sa.Column('error', sa.String(length=2048), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), nullable=False), + sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint('job_id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('jobs') + # ### end Alembic commands ### diff --git a/docker-compose.yml b/docker-compose.yml index cd520ff..069919d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,5 +1,6 @@ -# Phase 1 MVP compose: API only. -# Phase 4 will add redis, postgres, minio, and worker services. +# Phase 4 stack: API + Celery worker + Redis (broker/result backend) + +# Postgres (job state). Object storage is intentionally NOT here — the +# `BlobStorage` interface uses the local filesystem mounted at /app/storage. services: api: build: @@ -7,17 +8,83 @@ services: dockerfile: Dockerfile image: ocr-sprint-service:dev container_name: ocr-sprint-api + command: + [ + "sh", + "-c", + "alembic upgrade head && uvicorn ocr_sprint.main:app --host 0.0.0.0 --port 8000", + ] ports: - "8000:8000" environment: - APP_ENV: local + APP_ENV: docker APP_LOG_LEVEL: INFO OCR_USE_GPU: "false" STORAGE_LOCAL_DIR: /app/storage + BLOB_STORAGE_DIR: /app/storage/blobs + REDIS_URL: redis://redis:6379/0 + DATABASE_URL: postgresql+psycopg://ocr:ocr@postgres:5432/ocr_sprint + QUEUE_ENABLED: "true" volumes: - - ./storage:/app/storage + - blob-storage:/app/storage/blobs - paddle-models:/home/app/.paddleocr + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + restart: unless-stopped + + worker: + image: ocr-sprint-service:dev + container_name: ocr-sprint-worker + command: ["celery", "-A", "ocr_sprint.worker.celery_app", "worker", "-l", "info", "--concurrency=1"] + environment: + APP_ENV: docker + APP_LOG_LEVEL: INFO + OCR_USE_GPU: "false" + BLOB_STORAGE_DIR: /app/storage/blobs + REDIS_URL: redis://redis:6379/0 + DATABASE_URL: postgresql+psycopg://ocr:ocr@postgres:5432/ocr_sprint + volumes: + - blob-storage:/app/storage/blobs + - paddle-models:/home/app/.paddleocr + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + api: + condition: service_started + restart: unless-stopped + + redis: + image: redis:7-alpine + container_name: ocr-sprint-redis + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 5s + timeout: 3s + retries: 5 + restart: unless-stopped + + postgres: + image: postgres:16-alpine + container_name: ocr-sprint-postgres + environment: + POSTGRES_USER: ocr + POSTGRES_PASSWORD: ocr + POSTGRES_DB: ocr_sprint + volumes: + - postgres-data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U ocr -d ocr_sprint"] + interval: 5s + timeout: 3s + retries: 10 restart: unless-stopped volumes: + blob-storage: paddle-models: + postgres-data: diff --git a/pyproject.toml b/pyproject.toml index 4ae79a8..ccefaa7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,15 @@ dependencies = [ "numpy>=1.26,<2.2", "PyMuPDF>=1.24,<2", "python-magic>=0.4.27", - # OCR (CPU build of paddle; GPU users override via extra index) - "paddlepaddle==2.6.1", - "paddleocr>=2.7.5,<3", # Logging / observability "structlog>=24.1", "prometheus-client>=0.20", + # Async pipeline + persistence (Phase 4) + "celery[redis]>=5.4", + "redis>=5.0", + "sqlalchemy>=2.0", + "psycopg[binary]>=3.2", + "alembic>=1.13", # Misc "httpx>=0.27", "tenacity>=8.5", @@ -46,22 +49,20 @@ dev = [ "pre-commit>=3.7", ] +# OCR runtime — kept as an optional extra so unit tests / dev installs don't +# pull ~1.5 GB of Paddle wheels. Install via `pip install -e ".[ocr]"` on +# the worker host. CPU build by default; GPU users override the index URL. +ocr = [ + "paddlepaddle>=2.6.2,<4", + "paddleocr>=2.7.5,<3", +] + # Extraction layer (Phase 5) — kept optional so MVP install stays light llm = [ "ollama>=0.3", "instructor>=1.4", ] -# Async pipeline (Phase 4) -async-pipeline = [ - "celery[redis]>=5.4", - "redis>=5.0", - "minio>=7.2", - "sqlalchemy>=2.0", - "psycopg[binary]>=3.2", - "alembic>=1.13", -] - [project.scripts] ocr-sprint-api = "ocr_sprint.main:run" @@ -111,7 +112,7 @@ namespace_packages = true explicit_package_bases = true [[tool.mypy.overrides]] -module = ["paddleocr.*", "paddle.*", "cv2.*", "fitz.*", "magic.*"] +module = ["paddleocr.*", "paddle.*", "cv2.*", "fitz.*", "magic.*", "celery.*", "kombu.*"] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/src/ocr_sprint/api/deps/__init__.py b/src/ocr_sprint/api/deps/__init__.py new file mode 100644 index 0000000..6b786dc --- /dev/null +++ b/src/ocr_sprint/api/deps/__init__.py @@ -0,0 +1,6 @@ +"""Reusable FastAPI dependencies (auth, db session).""" + +from ocr_sprint.api.deps.auth import require_api_key +from ocr_sprint.api.deps.db import get_session + +__all__ = ["get_session", "require_api_key"] diff --git a/src/ocr_sprint/api/deps/auth.py b/src/ocr_sprint/api/deps/auth.py new file mode 100644 index 0000000..ce87de9 --- /dev/null +++ b/src/ocr_sprint/api/deps/auth.py @@ -0,0 +1,35 @@ +"""API-key authentication. + +The MVP uses a static list of keys loaded from `Settings.api_keys`. This is +deliberate: the service is intended to run on-prem behind an internal +reverse proxy, with a small set of trusted clients (the police HITL UI and +internal automation). Anything more sophisticated (JWT / OAuth / mTLS) is +deferred until there's a concrete need. + +When `Settings.api_keys` is empty the dependency permits all requests. +This makes the local dev experience friction-free; production deploys MUST +set at least one key — tested by `test_auth_rejects_missing_key`. +""" + +from __future__ import annotations + +from typing import Annotated + +from fastapi import Header, HTTPException, status + +from ocr_sprint.config import get_settings + + +async def require_api_key( + x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None, +) -> None: + settings = get_settings() + keys = settings.api_keys_list + if not keys: + return # auth disabled for local dev + if not x_api_key or x_api_key not in keys: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="invalid or missing API key", + headers={"WWW-Authenticate": settings.api_key_header}, + ) diff --git a/src/ocr_sprint/api/deps/db.py b/src/ocr_sprint/api/deps/db.py new file mode 100644 index 0000000..d05d103 --- /dev/null +++ b/src/ocr_sprint/api/deps/db.py @@ -0,0 +1,23 @@ +"""Per-request SQLAlchemy session dependency.""" + +from __future__ import annotations + +from collections.abc import Iterator + +from sqlalchemy.orm import Session + +from ocr_sprint.db.base import get_sessionmaker + + +def get_session() -> Iterator[Session]: + """Yield a session, committing on success and rolling back on errors.""" + factory = get_sessionmaker() + session = factory() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() diff --git a/src/ocr_sprint/api/metrics.py b/src/ocr_sprint/api/metrics.py new file mode 100644 index 0000000..fc20b2f --- /dev/null +++ b/src/ocr_sprint/api/metrics.py @@ -0,0 +1,114 @@ +"""Prometheus metrics exposed at `/metrics`. + +We keep the surface tiny on purpose: + +* `http_requests_total{method,route,status}` — request count +* `http_request_duration_seconds{method,route}` — latency histogram +* `ocr_jobs_total{status}` — current count of jobs in each status, derived + from the database at scrape time. Because this is computed from the + ``jobs`` table it is correct regardless of whether the writer was the API + process (sync mode) or a Celery worker process. Note that this is + technically a gauge of the cumulative-by-status count, not a strict + monotonic counter — values can decrease if rows are deleted/reaped. +* `ocr_job_processing_seconds` — pipeline wall-time histogram. Only the API + process observes events on this histogram (sync path); the worker writes + its timing into the DB row and is not exposed here. + +A custom registry is used so tests can reset counters cleanly. +""" + +from __future__ import annotations + +import time +from collections.abc import Awaitable, Callable, Iterable + +from fastapi import Request, Response +from prometheus_client import ( + CONTENT_TYPE_LATEST, + CollectorRegistry, + Counter, + Histogram, + generate_latest, +) +from prometheus_client.core import GaugeMetricFamily +from prometheus_client.metrics_core import Metric +from sqlalchemy import func, select +from starlette.middleware.base import BaseHTTPMiddleware + +from ocr_sprint.db.base import session_scope +from ocr_sprint.db.models import JobRow +from ocr_sprint.utils.logging import get_logger + +_logger = get_logger(__name__) + +REGISTRY = CollectorRegistry() + +REQUEST_COUNT = Counter( + "http_requests_total", + "HTTP requests handled by the API", + ("method", "route", "status"), + registry=REGISTRY, +) +REQUEST_LATENCY = Histogram( + "http_request_duration_seconds", + "HTTP request latency", + ("method", "route"), + registry=REGISTRY, +) +JOB_PROCESSING_SECONDS = Histogram( + "ocr_job_processing_seconds", + "OCR job pipeline wall-time as observed by the API process", + registry=REGISTRY, +) + + +class _JobStatusCollector: + """Custom collector that queries the ``jobs`` table on every scrape. + + Because the worker runs in a separate process from the API, an in-memory + ``Counter`` cannot accurately track terminal job counts — its writes + would never reach the API's ``/metrics`` endpoint. Reading from the + shared DB on each scrape keeps the metric correct across processes. + """ + + def collect(self) -> Iterable[Metric]: + family = GaugeMetricFamily( + "ocr_jobs_total", + "Current count of jobs grouped by status (read from DB).", + labels=["status"], + ) + try: + with session_scope() as session: + stmt = select(JobRow.status, func.count()).group_by(JobRow.status) + for status_value, count in session.execute(stmt).all(): + family.add_metric([status_value], float(count)) + except Exception as exc: + _logger.warning("metrics.jobs_collect_failed", error=str(exc)) + return [family] + + +REGISTRY.register(_JobStatusCollector()) + + +class MetricsMiddleware(BaseHTTPMiddleware): + """Record request count + latency. The `route` label is the path + template, not the raw URL, so per-id endpoints don't blow up cardinality. + """ + + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: + start = time.perf_counter() + response = await call_next(request) + elapsed = time.perf_counter() - start + route = request.scope.get("route") + path = getattr(route, "path", request.url.path) if route else request.url.path + REQUEST_COUNT.labels(request.method, path, str(response.status_code)).inc() + REQUEST_LATENCY.labels(request.method, path).observe(elapsed) + return response + + +async def metrics_endpoint() -> Response: + return Response(content=generate_latest(REGISTRY), media_type=CONTENT_TYPE_LATEST) diff --git a/src/ocr_sprint/api/routes/documents.py b/src/ocr_sprint/api/routes/documents.py index 26dd6eb..018d00c 100644 --- a/src/ocr_sprint/api/routes/documents.py +++ b/src/ocr_sprint/api/routes/documents.py @@ -1,58 +1,194 @@ -"""Documents API — Phase 1 synchronous endpoint. +"""Documents API. -POST /documents accepts a single PDF or image upload, runs the synchronous -pipeline inline, and returns the structured result. This is suitable for -development and low-traffic production; Phase 4 will introduce an async -queue and a polling-style API at the same path. +Phase 1 shipped a single synchronous endpoint. Phase 4 adds an async +flow on top: + +* `POST /documents` — async by default. Saves the upload to blob + storage, creates a `pending` job row, and + enqueues a Celery task. Returns `202` with + the job id. +* `POST /documents?sync=true` — runs the pipeline inline (the original + Phase 1 behaviour). Useful for tests and + small-volume single-tenant deploys without + a Celery worker. +* `GET /documents/{job_id}` — returns the current job state. Async + clients poll this until `status` is in a + terminal state (completed / needs_review / + failed). """ from __future__ import annotations -from uuid import uuid4 +from typing import Annotated +from uuid import UUID, uuid4 -from fastapi import APIRouter, File, UploadFile, status +from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status +from sqlalchemy.orm import Session +from ocr_sprint.api.deps.auth import require_api_key +from ocr_sprint.api.deps.db import get_session from ocr_sprint.api.errors import UnsupportedDocumentError +from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS +from ocr_sprint.config import get_settings +from ocr_sprint.db.base import session_scope +from ocr_sprint.db.repositories import JobNotFoundError, JobRepository +from ocr_sprint.pipeline.ingest import detect_source_kind from ocr_sprint.pipeline.orchestrator import run_pipeline -from ocr_sprint.schemas.document import DocumentResponse +from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus +from ocr_sprint.schemas.extraction import ExtractionResult +from ocr_sprint.storage.blob import get_blob_storage from ocr_sprint.utils.logging import get_logger -router = APIRouter(prefix="/documents", tags=["documents"]) +router = APIRouter( + prefix="/documents", + tags=["documents"], + dependencies=[Depends(require_api_key)], +) _logger = get_logger(__name__) -_MAX_UPLOAD_BYTES = 25 * 1024 * 1024 # 25 MB + +# ---------- helpers ---------- -@router.post("", status_code=status.HTTP_200_OK, response_model=DocumentResponse) -async def create_document(file: UploadFile = File(...)) -> DocumentResponse: - """Run OCR + extraction synchronously on a single upload.""" +def _enforce_size(content: bytes) -> None: + s = get_settings() + if not content: + raise UnsupportedDocumentError("Uploaded file is empty.") + max_bytes = s.blob_max_upload_mb * 1024 * 1024 + if len(content) > max_bytes: + raise UnsupportedDocumentError(f"Uploaded file exceeds {s.blob_max_upload_mb} MB limit.") + + +def _row_to_response(row: object) -> DocumentResponse: + # Local import to avoid a circular import at module load time. + from ocr_sprint.db.models import JobRow + + assert isinstance(row, JobRow) + status_enum = DocumentStatus(row.status) + result_obj: ExtractionResult | None = None + if row.result is not None: + result_obj = ExtractionResult.model_validate(row.result) + return DocumentResponse( + job_id=row.job_id, + status=status_enum, + confidence=row.confidence, + data=result_obj, + review_flags=list(row.review_flags or []), + error=row.error, + ) + + +# ---------- POST ---------- + + +@router.post("", response_model=DocumentResponse) +async def create_document( + file: Annotated[UploadFile, File(...)], + session: Annotated[Session, Depends(get_session)], + response: Response, + sync: Annotated[ + bool | None, + Query(description="Run pipeline inline (skip queue). Defaults to !queue_enabled."), + ] = None, +) -> DocumentResponse: + # When the queue is disabled (default for local dev), running the async + # path would try to dial Redis and fail with a 500. Auto-fall-back to the + # inline pipeline unless the caller explicitly asked for async. + if sync is None: + sync = not get_settings().queue_enabled + job_id = uuid4() log = _logger.bind(job_id=str(job_id), filename=file.filename or "") content = await file.read() - if not content: - raise UnsupportedDocumentError("Uploaded file is empty.") - if len(content) > _MAX_UPLOAD_BYTES: - raise UnsupportedDocumentError( - f"Uploaded file exceeds {_MAX_UPLOAD_BYTES // (1024 * 1024)} MB limit." - ) + _enforce_size(content) - log.info("documents.received", size=len(content)) + storage = get_blob_storage() + blob_key = storage.put(content, original_filename=file.filename) + source_kind = detect_source_kind(content) + JobRepository(session).create( + job_id=job_id, + filename=file.filename or "", + source_kind=source_kind, + blob_key=blob_key, + ) + # Commit the `pending` row immediately so it is observable regardless + # of what happens next. Both code paths below open their own session + # for state transitions; that way an exception in `_run_inline` cannot + # roll back the create() (which would orphan the blob on disk). + session.commit() + log.info("documents.received", size=len(content), blob_key=blob_key, sync=sync) + + if sync: + # Status code stays at the default 200; the body's `status` field + # tells the client whether the job needs review. + return await _run_inline(job_id, content) + + # Async path — enqueue and return 202. The Celery worker will pick up + # the row using its own session. + from ocr_sprint.worker.tasks import process_document_task + + process_document_task.delay(str(job_id)) + with session_scope() as poll: + row = JobRepository(poll).get_or_raise(job_id) + body = _row_to_response(row) + response.status_code = status.HTTP_202_ACCEPTED + return body + + +async def _run_inline(job_id: UUID, content: bytes) -> DocumentResponse: + """Synchronous pipeline execution. + + Each state transition opens its own short session so the request-scoped + session's rollback-on-exception behaviour cannot wipe out the + ``mark_failed`` write or strand the blob on disk. + """ + import time + + with session_scope() as s: + JobRepository(s).mark_processing(job_id) + + started = time.perf_counter() try: output = run_pipeline(content) except ValueError as exc: + with session_scope() as s: + JobRepository(s).mark_failed(job_id, error=str(exc)) raise UnsupportedDocumentError(str(exc)) from exc + except Exception as exc: + with session_scope() as s: + JobRepository(s).mark_failed(job_id, error=str(exc)) + raise - log.info( - "documents.completed", - status=output.status.value, - confidence=round(output.confidence, 3), - flags=[f.value for f in output.result.review_flags], - ) - return DocumentResponse( - job_id=job_id, - status=output.status, - confidence=output.confidence, - data=output.result, - review_flags=[f.value for f in output.result.review_flags], - ) + flags = [f.value for f in output.result.review_flags] + JOB_PROCESSING_SECONDS.observe(time.perf_counter() - started) + with session_scope() as s: + repo = JobRepository(s) + repo.mark_completed( + job_id, + status=output.status, + confidence=output.confidence, + result=output.result.model_dump(mode="json"), + review_flags=flags, + ) + row = repo.get_or_raise(job_id) + return _row_to_response(row) + + +# ---------- GET ---------- + + +@router.get( + "/{job_id}", + response_model=DocumentResponse, +) +async def get_document( + job_id: UUID, + session: Annotated[Session, Depends(get_session)], +) -> DocumentResponse: + repo = JobRepository(session) + try: + row = repo.get_or_raise(job_id) + except JobNotFoundError as exc: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc + return _row_to_response(row) diff --git a/src/ocr_sprint/config.py b/src/ocr_sprint/config.py index b85a40a..73f8c8f 100644 --- a/src/ocr_sprint/config.py +++ b/src/ocr_sprint/config.py @@ -64,12 +64,29 @@ class Settings(BaseSettings): # Async pipeline (Phase 4) queue_enabled: bool = False redis_url: str = "redis://localhost:6379/0" - database_url: str = "postgresql+psycopg://ocr:ocr@localhost:5432/ocr_sprint" - minio_endpoint: str = "localhost:9000" - minio_access_key: str = "minioadmin" - minio_secret_key: str = "minioadmin" - minio_bucket: str = "ocr-sprint" - minio_secure: bool = False + celery_task_default_queue: str = "ocr_sprint" + + # Persistence (Phase 4). Use sqlite for local dev / tests; Postgres for + # production via docker-compose. + database_url: str = "sqlite:///./storage/ocr_sprint.sqlite" + database_echo: bool = False + + # Blob storage (Phase 4). Local filesystem only for the MVP; the + # `BlobStorage` interface is designed to swap to S3/MinIO without API + # changes when needed. + blob_storage_dir: Path = Path("./storage/blobs") + blob_max_upload_mb: int = 25 + + # Auth (Phase 4). Comma-separated list of API keys accepted by the API. + # Empty string disables auth (only intended for local dev / tests). + # We use ``str`` rather than ``list[str]`` because pydantic-settings rejects + # a bare empty string when binding to a list type. + api_keys: str = "" + api_key_header: str = "X-API-Key" + + @property + def api_keys_list(self) -> list[str]: + return [k.strip() for k in self.api_keys.split(",") if k.strip()] @lru_cache(maxsize=1) @@ -77,4 +94,5 @@ def get_settings() -> Settings: """Cached accessor so settings are loaded once per process.""" settings = Settings() settings.storage_local_dir.mkdir(parents=True, exist_ok=True) + settings.blob_storage_dir.mkdir(parents=True, exist_ok=True) return settings diff --git a/src/ocr_sprint/db/__init__.py b/src/ocr_sprint/db/__init__.py new file mode 100644 index 0000000..3a620ff --- /dev/null +++ b/src/ocr_sprint/db/__init__.py @@ -0,0 +1,14 @@ +"""Persistence layer (Phase 4) — SQLAlchemy 2.0 models, session, repositories.""" + +from ocr_sprint.db.base import Base, get_engine, get_sessionmaker, session_scope +from ocr_sprint.db.models import JobRow +from ocr_sprint.db.repositories import JobRepository + +__all__ = [ + "Base", + "JobRepository", + "JobRow", + "get_engine", + "get_sessionmaker", + "session_scope", +] diff --git a/src/ocr_sprint/db/base.py b/src/ocr_sprint/db/base.py new file mode 100644 index 0000000..6eb4d14 --- /dev/null +++ b/src/ocr_sprint/db/base.py @@ -0,0 +1,67 @@ +"""SQLAlchemy 2.0 engine + session factory. + +We use a single global engine per process. For tests (and SQLite in dev), +StaticPool keeps the same in-memory database across connections; for +Postgres in production we use SQLAlchemy's default pool. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from functools import lru_cache + +from sqlalchemy import create_engine +from sqlalchemy.engine import Engine +from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker +from sqlalchemy.pool import StaticPool + +from ocr_sprint.config import get_settings + + +class Base(DeclarativeBase): + """Common SQLAlchemy declarative base.""" + + +@lru_cache(maxsize=1) +def get_engine() -> Engine: + s = get_settings() + kwargs: dict[str, object] = { + "echo": s.database_echo, + "future": True, + } + # SQLite needs special handling: same connection across threads (Celery + # eager mode + FastAPI) requires `check_same_thread=False`. For the + # ``sqlite:///:memory:`` URL we also use StaticPool to reuse the same + # underlying connection so test fixtures see committed data. + if s.database_url.startswith("sqlite"): + kwargs["connect_args"] = {"check_same_thread": False} + if ":memory:" in s.database_url or s.database_url.endswith(":memory:"): + kwargs["poolclass"] = StaticPool + return create_engine(s.database_url, **kwargs) + + +@lru_cache(maxsize=1) +def get_sessionmaker() -> sessionmaker[Session]: + return sessionmaker(bind=get_engine(), expire_on_commit=False, autoflush=False) + + +@contextmanager +def session_scope() -> Iterator[Session]: + """Yield a SQLAlchemy session and commit/rollback at the boundary.""" + factory = get_sessionmaker() + session = factory() + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + finally: + session.close() + + +def reset_engine_cache() -> None: + """Clear cached engine + sessionmaker. Used by tests when changing DB URL.""" + get_engine.cache_clear() + get_sessionmaker.cache_clear() diff --git a/src/ocr_sprint/db/models.py b/src/ocr_sprint/db/models.py new file mode 100644 index 0000000..0f36202 --- /dev/null +++ b/src/ocr_sprint/db/models.py @@ -0,0 +1,53 @@ +"""SQLAlchemy ORM models for jobs + extraction results. + +We store the structured result as JSON. PaddleOCR's `raw_text` can run into +the tens of kilobytes for multi-page documents; that's well within Postgres' +JSONB row-size budget. SQLite stores it as TEXT under the hood. + +Schema choice: we keep the result inline on the same row instead of a +separate `extraction_results` table. The 1:1 relationship would otherwise +add a join on every read, with no real benefit since results are immutable +once written and there's no use-case for fetching just the metadata. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any +from uuid import UUID, uuid4 + +from sqlalchemy import JSON, DateTime, Float, String, Uuid +from sqlalchemy.orm import Mapped, mapped_column + +from ocr_sprint.db.base import Base + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +class JobRow(Base): + __tablename__ = "jobs" + + # SQLAlchemy 2.0's Uuid type maps to native UUID on Postgres and CHAR(32) + # on SQLite, so the same model works in both environments. + job_id: Mapped[UUID] = mapped_column(Uuid, primary_key=True, default=uuid4) + status: Mapped[str] = mapped_column(String(32), nullable=False, default="pending") + source_kind: Mapped[str] = mapped_column(String(16), nullable=False, default="unknown") + filename: Mapped[str] = mapped_column(String(512), nullable=False, default="") + blob_key: Mapped[str | None] = mapped_column(String(512), nullable=True) + + confidence: Mapped[float | None] = mapped_column(Float, nullable=True) + review_flags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list) + result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True) + error: Mapped[str | None] = mapped_column(String(2048), nullable=True) + + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, default=_utcnow + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), nullable=False, default=_utcnow, onupdate=_utcnow + ) + + def __repr__(self) -> str: + return f"JobRow(job_id={self.job_id!s}, status={self.status!r})" diff --git a/src/ocr_sprint/db/repositories.py b/src/ocr_sprint/db/repositories.py new file mode 100644 index 0000000..a17817a --- /dev/null +++ b/src/ocr_sprint/db/repositories.py @@ -0,0 +1,96 @@ +"""Thin data-access layer over the ORM. + +Repositories encapsulate the SQL so the API + Celery task code never has to +know about sessions, transactions, or the row → schema mapping. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from ocr_sprint.db.models import JobRow +from ocr_sprint.schemas.document import DocumentStatus, SourceKind + + +def _utcnow() -> datetime: + return datetime.now(timezone.utc) + + +class JobNotFoundError(LookupError): + """Raised by API code when GET /documents/{id} hits a missing row.""" + + +class JobRepository: + """SQL-backed repository for `jobs` rows.""" + + def __init__(self, session: Session) -> None: + self.session = session + + # ---------- writes ---------- + + def create( + self, + *, + job_id: UUID, + filename: str, + source_kind: SourceKind, + blob_key: str, + ) -> JobRow: + row = JobRow( + job_id=job_id, + status=DocumentStatus.PENDING.value, + source_kind=source_kind.value, + filename=filename, + blob_key=blob_key, + ) + self.session.add(row) + self.session.flush() + return row + + def mark_processing(self, job_id: UUID) -> None: + row = self._get_or_raise(job_id) + row.status = DocumentStatus.PROCESSING.value + row.updated_at = _utcnow() + + def mark_completed( + self, + job_id: UUID, + *, + status: DocumentStatus, + confidence: float, + result: dict[str, Any], + review_flags: list[str], + ) -> None: + row = self._get_or_raise(job_id) + row.status = status.value + row.confidence = confidence + row.result = result + row.review_flags = review_flags + row.error = None + row.updated_at = _utcnow() + + def mark_failed(self, job_id: UUID, *, error: str) -> None: + row = self._get_or_raise(job_id) + row.status = DocumentStatus.FAILED.value + row.error = error[:2048] + row.updated_at = _utcnow() + + # ---------- reads ---------- + + def get(self, job_id: UUID) -> JobRow | None: + stmt = select(JobRow).where(JobRow.job_id == job_id) + return self.session.scalar(stmt) + + def get_or_raise(self, job_id: UUID) -> JobRow: + return self._get_or_raise(job_id) + + def _get_or_raise(self, job_id: UUID) -> JobRow: + row = self.get(job_id) + if row is None: + raise JobNotFoundError(f"Job not found: {job_id}") + return row diff --git a/src/ocr_sprint/main.py b/src/ocr_sprint/main.py index 4b5e9b1..05724cf 100644 --- a/src/ocr_sprint/main.py +++ b/src/ocr_sprint/main.py @@ -6,15 +6,29 @@ from fastapi import FastAPI from ocr_sprint import __version__ from ocr_sprint.api.errors import register_error_handlers +from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint from ocr_sprint.api.routes import documents, health from ocr_sprint.config import get_settings +from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables) +from ocr_sprint.db.base import Base, get_engine from ocr_sprint.utils.logging import configure_logging +def _ensure_schema() -> None: + """Create tables if they don't exist. + + Production deploys should run Alembic migrations explicitly; this is a + convenience for local dev / tests so the API works without a manual + `alembic upgrade head` step. + """ + Base.metadata.create_all(bind=get_engine()) + + def create_app() -> FastAPI: """Application factory — keeps top-level state easy to test.""" settings = get_settings() configure_logging(settings.app_log_level) + _ensure_schema() app = FastAPI( title="OCR Sprint Service", @@ -26,8 +40,10 @@ def create_app() -> FastAPI: ) register_error_handlers(app) + app.add_middleware(MetricsMiddleware) app.include_router(health.router, prefix="/api/v1") app.include_router(documents.router, prefix="/api/v1") + app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False) return app diff --git a/src/ocr_sprint/storage/__init__.py b/src/ocr_sprint/storage/__init__.py new file mode 100644 index 0000000..49bfc3f --- /dev/null +++ b/src/ocr_sprint/storage/__init__.py @@ -0,0 +1,5 @@ +"""Pluggable blob storage. Local-fs only for the MVP.""" + +from ocr_sprint.storage.blob import BlobStorage, LocalFsBlobStorage, get_blob_storage + +__all__ = ["BlobStorage", "LocalFsBlobStorage", "get_blob_storage"] diff --git a/src/ocr_sprint/storage/blob.py b/src/ocr_sprint/storage/blob.py new file mode 100644 index 0000000..8eb2237 --- /dev/null +++ b/src/ocr_sprint/storage/blob.py @@ -0,0 +1,146 @@ +"""Blob storage abstraction. + +The MVP only ships a local-filesystem backend. The `BlobStorage` Protocol is +deliberately small (put / get / exists / delete) so that an S3- or MinIO- +backed implementation can be dropped in later without touching API code. + +Layout on disk: + + {blob_storage_dir}/ + 2026/04/25/ + . + +The date hierarchy keeps the directory listing manageable when the service +processes thousands of documents per day, and makes manual rsync-based +backup straightforward. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from pathlib import Path +from typing import BinaryIO, Protocol +from uuid import uuid4 + +from ocr_sprint.config import get_settings +from ocr_sprint.utils.logging import get_logger + +_logger = get_logger(__name__) + +# Map of upload extensions we'll honor when persisting blobs. Anything else +# falls back to `.bin` and the OCR pipeline's magic-byte sniffing handles +# the actual content kind. +_KNOWN_EXTS = {".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff", ".webp"} + + +class BlobStorage(Protocol): + """Minimal interface a blob backend must satisfy.""" + + def put(self, content: bytes, original_filename: str | None = None) -> str: + """Persist `content` and return an opaque key the caller can use later.""" + + def get(self, key: str) -> bytes: + """Return the raw bytes for `key`. Raises FileNotFoundError on miss.""" + + def open(self, key: str) -> BinaryIO: + """Return a binary file-like object for streaming reads.""" + + def exists(self, key: str) -> bool: + """True if `key` is currently stored.""" + + def delete(self, key: str) -> None: + """Remove a blob. No-op if it doesn't exist.""" + + +class LocalFsBlobStorage: + """Filesystem-backed implementation rooted at `base_dir`.""" + + def __init__(self, base_dir: Path) -> None: + # Resolve once so every subsequent path comparison (escape check, + # empty-dir cleanup) is apples-to-apples — ``Path.parents`` of a + # resolved key would otherwise never equal a relative ``base_dir``. + base_dir.mkdir(parents=True, exist_ok=True) + self.base_dir = base_dir.resolve() + + # ---------- helpers ---------- + + @staticmethod + def _safe_ext(original_filename: str | None) -> str: + if not original_filename: + return ".bin" + suffix = Path(original_filename).suffix.lower() + return suffix if suffix in _KNOWN_EXTS else ".bin" + + def _resolve(self, key: str) -> Path: + # Defensive: keys come from the DB but we still reject paths that try + # to escape the blob root. ``Path.is_relative_to`` does proper path + # containment — string ``startswith`` would let ``/app/blobs_evil`` + # slip past when the root is ``/app/blobs``. + candidate = (self.base_dir / key).resolve() + if not candidate.is_relative_to(self.base_dir): + raise ValueError(f"Blob key escapes storage root: {key!r}") + return candidate + + # ---------- BlobStorage protocol ---------- + + def put(self, content: bytes, original_filename: str | None = None) -> str: + now = datetime.now(timezone.utc) + date_dir = Path(f"{now:%Y/%m/%d}") + ext = self._safe_ext(original_filename) + key = str(date_dir / f"{uuid4().hex}{ext}") + target = self._resolve(key) + target.parent.mkdir(parents=True, exist_ok=True) + # Write to a temp file in the same directory then rename. This avoids + # a half-written blob being read by a concurrent worker. + tmp = target.with_suffix(target.suffix + ".tmp") + tmp.write_bytes(content) + tmp.rename(target) + _logger.info("blob.put", key=key, size=len(content)) + return key + + def get(self, key: str) -> bytes: + path = self._resolve(key) + if not path.exists(): + raise FileNotFoundError(f"Blob not found: {key}") + return path.read_bytes() + + def open(self, key: str) -> BinaryIO: + path = self._resolve(key) + if not path.exists(): + raise FileNotFoundError(f"Blob not found: {key}") + return path.open("rb") + + def exists(self, key: str) -> bool: + try: + return self._resolve(key).exists() + except ValueError: + return False + + def delete(self, key: str) -> None: + try: + path = self._resolve(key) + except ValueError: + return + if path.exists(): + path.unlink() + _logger.info("blob.delete", key=key) + # Best-effort cleanup of empty date dirs so we don't accumulate + # 365 directories per year forever. ``self.base_dir`` is already + # resolved (see __init__), so it can be compared against + # ``path.parents`` directly. + for parent in path.parents: + if parent == self.base_dir or self.base_dir not in parent.parents: + break + try: + parent.rmdir() + except OSError: + break + + +def get_blob_storage() -> BlobStorage: + """Build the configured blob backend. Single-process cache lives in `Settings`.""" + s = get_settings() + return LocalFsBlobStorage(s.blob_storage_dir) + + +__all__ = ["BlobStorage", "LocalFsBlobStorage", "get_blob_storage"] diff --git a/src/ocr_sprint/worker/__init__.py b/src/ocr_sprint/worker/__init__.py new file mode 100644 index 0000000..4242945 --- /dev/null +++ b/src/ocr_sprint/worker/__init__.py @@ -0,0 +1,6 @@ +"""Celery worker (Phase 4) — async OCR pipeline.""" + +from ocr_sprint.worker.celery_app import celery_app +from ocr_sprint.worker.tasks import process_document_task + +__all__ = ["celery_app", "process_document_task"] diff --git a/src/ocr_sprint/worker/celery_app.py b/src/ocr_sprint/worker/celery_app.py new file mode 100644 index 0000000..68d427e --- /dev/null +++ b/src/ocr_sprint/worker/celery_app.py @@ -0,0 +1,49 @@ +"""Celery application factory. + +The broker and result backend are both Redis. We deliberately don't use +Postgres as the result backend — Celery's pg result-backend creates noisy +schema artefacts and we already store the structured result on the `jobs` +table. + +Tasks register themselves by calling `celery_app.task` in `tasks.py`. Eager +mode (used in tests) is enabled by setting `CELERY_TASK_ALWAYS_EAGER=true` +in the environment. +""" + +from __future__ import annotations + +import os + +from celery import Celery + +from ocr_sprint.config import get_settings + + +def build_celery_app() -> Celery: + settings = get_settings() + app = Celery( + "ocr_sprint", + broker=settings.redis_url, + backend=settings.redis_url, + include=["ocr_sprint.worker.tasks"], + ) + app.conf.update( + task_default_queue=settings.celery_task_default_queue, + task_serializer="json", + result_serializer="json", + accept_content=["json"], + timezone="UTC", + enable_utc=True, + task_acks_late=True, + task_reject_on_worker_lost=True, + worker_prefetch_multiplier=1, # OCR is CPU-bound; one task at a time. + broker_connection_retry_on_startup=True, + result_expires=24 * 3600, # results live in DB; redis is just a cache + ) + if os.getenv("CELERY_TASK_ALWAYS_EAGER", "").lower() in {"1", "true", "yes"}: + app.conf.task_always_eager = True + app.conf.task_eager_propagates = True + return app + + +celery_app = build_celery_app() diff --git a/src/ocr_sprint/worker/tasks.py b/src/ocr_sprint/worker/tasks.py new file mode 100644 index 0000000..e3e2372 --- /dev/null +++ b/src/ocr_sprint/worker/tasks.py @@ -0,0 +1,84 @@ +"""Celery tasks. + +`process_document_task` is the single async entrypoint: the API enqueues +one task per upload, the worker pulls the blob, runs the orchestrator, and +writes the result back to the `jobs` table. + +Tasks must be idempotent at the boundary: if the worker crashes mid-OCR, +the next retry should pick up the same blob and produce the same result. +We therefore re-fetch the row by id on every transition rather than +threading state through closures. +""" + +from __future__ import annotations + +from uuid import UUID + +from celery.exceptions import Reject + +from ocr_sprint.db.base import session_scope +from ocr_sprint.db.repositories import JobRepository +from ocr_sprint.pipeline.orchestrator import run_pipeline +from ocr_sprint.storage.blob import get_blob_storage +from ocr_sprint.utils.logging import get_logger +from ocr_sprint.worker.celery_app import celery_app + +_logger = get_logger(__name__) + + +@celery_app.task(name="ocr_sprint.process_document", bind=True, max_retries=0) # type: ignore[untyped-decorator] +def process_document_task(self: object, job_id_str: str) -> str: + """Run the OCR pipeline for `job_id` and persist the structured result. + + Returns the final job status as a string so callers can wait on + `AsyncResult.get()` if they want to (mainly tests). + """ + job_id = UUID(job_id_str) + log = _logger.bind(job_id=job_id_str) + storage = get_blob_storage() + + with session_scope() as session: + repo = JobRepository(session) + row = repo.get(job_id) + if row is None: + raise Reject(f"Job {job_id_str} not found", requeue=False) + repo.mark_processing(job_id) + blob_key = row.blob_key + + if not blob_key: + with session_scope() as session: + JobRepository(session).mark_failed(job_id, error="missing blob_key") + return "failed" + + try: + content = storage.get(blob_key) + except FileNotFoundError as exc: + log.error("worker.blob_missing", error=str(exc)) + with session_scope() as session: + JobRepository(session).mark_failed(job_id, error=f"blob missing: {exc}") + return "failed" + + try: + output = run_pipeline(content) + except Exception as exc: + log.exception("worker.pipeline_error") + with session_scope() as session: + JobRepository(session).mark_failed(job_id, error=str(exc)) + return "failed" + + flags = [f.value for f in output.result.review_flags] + log.info( + "worker.completed", + status=output.status.value, + confidence=round(output.confidence, 3), + flags=flags, + ) + with session_scope() as session: + JobRepository(session).mark_completed( + job_id, + status=output.status, + confidence=output.confidence, + result=output.result.model_dump(mode="json"), + review_flags=flags, + ) + return output.status.value diff --git a/tests/conftest.py b/tests/conftest.py index 75f48d8..911782e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,54 @@ from __future__ import annotations +import os +from collections.abc import Iterator +from pathlib import Path + import numpy as np import pytest +@pytest.fixture(autouse=True) +def _isolated_runtime(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[None]: + """Per-test sqlite + blob storage so tests don't share state. + + Setting these env vars before ``Settings`` is first read in the test gives + each test its own DB file and blob root. We also clear the lru_cache on + `get_settings`, the engine, and the sessionmaker so the fresh paths take + effect even if a previous test already loaded settings. + """ + db_path = tmp_path / "test.sqlite" + blob_dir = tmp_path / "blobs" + monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}") + monkeypatch.setenv("BLOB_STORAGE_DIR", str(blob_dir)) + monkeypatch.setenv("STORAGE_LOCAL_DIR", str(tmp_path / "storage")) + monkeypatch.setenv("API_KEYS", "") + # The async API path is exercised by the test suite, so default it on + # here. Production keeps ``QUEUE_ENABLED=false`` so the route falls back + # to the inline pipeline when no Redis is configured. + monkeypatch.setenv("QUEUE_ENABLED", "true") + # Force Celery to run tasks inline so we don't need a broker. + monkeypatch.setenv("CELERY_TASK_ALWAYS_EAGER", "true") + + from ocr_sprint.config import get_settings + from ocr_sprint.db.base import reset_engine_cache + from ocr_sprint.worker.celery_app import celery_app + + get_settings.cache_clear() + reset_engine_cache() + # `celery_app` is built once at import-time, so flip the eager flag on the + # already-instantiated instance for this test. + celery_app.conf.task_always_eager = True + celery_app.conf.task_eager_propagates = True + + yield + + get_settings.cache_clear() + reset_engine_cache() + os.environ.pop("CELERY_TASK_ALWAYS_EAGER", None) + + @pytest.fixture def blank_bgr_image() -> np.ndarray: """A 600x800 white BGR image (uint8) — useful for preprocessing smoke tests.""" diff --git a/tests/unit/test_api.py b/tests/unit/test_api.py index be8addd..716be1f 100644 --- a/tests/unit/test_api.py +++ b/tests/unit/test_api.py @@ -23,35 +23,9 @@ def client() -> TestClient: return TestClient(create_app()) -def test_health_endpoint(client: TestClient) -> None: - response = client.get("/api/v1/health") - assert response.status_code == 200 - assert response.json()["status"] == "ok" - - -def test_documents_rejects_empty_upload(client: TestClient) -> None: - response = client.post( - "/api/v1/documents", - files={"file": ("empty.pdf", b"", "application/pdf")}, - ) - assert response.status_code == 400 - - -def test_documents_rejects_unknown_format( - client: TestClient, - monkeypatch: pytest.MonkeyPatch, -) -> None: - response = client.post( - "/api/v1/documents", - files={"file": ("x.bin", b"random garbage bytes here", "application/octet-stream")}, - ) - assert response.status_code == 400 - - -def test_documents_returns_pipeline_output( - client: TestClient, - monkeypatch: pytest.MonkeyPatch, -) -> None: +@pytest.fixture +def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput: + """Patch run_pipeline everywhere it's referenced.""" fake_result = ExtractionResult( header=HeaderFields( nomor_sprint="Sprin/1/I/2025", @@ -70,14 +44,36 @@ def test_documents_returns_pipeline_output( def _fake_run(_content: bytes) -> PipelineOutput: return fake_output - # Patch the symbol *imported into* the routes module. monkeypatch.setattr(orch_module, "run_pipeline", _fake_run) from ocr_sprint.api.routes import documents as docs_module monkeypatch.setattr(docs_module, "run_pipeline", _fake_run) + from ocr_sprint.worker import tasks as tasks_module + monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run) + return fake_output + + +def test_health_endpoint(client: TestClient) -> None: + response = client.get("/api/v1/health") + assert response.status_code == 200 + assert response.json()["status"] == "ok" + + +def test_documents_rejects_empty_upload(client: TestClient) -> None: response = client.post( "/api/v1/documents", + files={"file": ("empty.pdf", b"", "application/pdf")}, + ) + assert response.status_code == 400 + + +def test_documents_sync_returns_pipeline_output( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + response = client.post( + "/api/v1/documents?sync=true", files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, ) assert response.status_code == 200 @@ -85,3 +81,169 @@ def test_documents_returns_pipeline_output( assert body["status"] == "completed" assert body["confidence"] == 0.97 assert body["data"]["header"]["nomor_sprint"] == "Sprin/1/I/2025" + + +def test_documents_async_returns_202_then_polls_to_completion( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + """Default flow: POST returns 202, GET returns the eventual completion. + + With CELERY_TASK_ALWAYS_EAGER set in conftest, the worker runs inline, + so by the time POST returns the task has already finished and GET sees + a `completed` row. + """ + post = client.post( + "/api/v1/documents", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 202 + job_id = post.json()["job_id"] + + get = client.get(f"/api/v1/documents/{job_id}") + assert get.status_code == 200 + body = get.json() + assert body["status"] == "completed" + assert body["confidence"] == 0.97 + + +def test_documents_defaults_to_sync_when_queue_disabled( + client: TestClient, + fake_pipeline: PipelineOutput, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Regression: with ``QUEUE_ENABLED=false`` the route must NOT enqueue, + otherwise a default install with no Redis returns 500. + """ + monkeypatch.setenv("QUEUE_ENABLED", "false") + from ocr_sprint.config import get_settings + + get_settings.cache_clear() + + # Pretend the broker is unreachable; if the route still enqueues, the + # call would blow up here. + def _no_broker(_self: object, *_args: object, **_kwargs: object) -> None: + raise AssertionError("queue path taken when queue is disabled") + + from ocr_sprint.worker import tasks as task_module + + monkeypatch.setattr(task_module.process_document_task, "delay", _no_broker) + + post = client.post( + "/api/v1/documents", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 200, post.text + body = post.json() + assert body["status"] == "completed" + + +def test_documents_get_unknown_id_returns_404(client: TestClient) -> None: + response = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000") + assert response.status_code == 404 + + +def test_documents_async_marks_failed_on_pipeline_error( + client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _explode(_content: bytes) -> PipelineOutput: + raise RuntimeError("boom") + + from ocr_sprint.worker import tasks as tasks_module + + monkeypatch.setattr(tasks_module, "run_pipeline", _explode) + + post = client.post( + "/api/v1/documents", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 202 + job_id = post.json()["job_id"] + + get = client.get(f"/api/v1/documents/{job_id}") + body = get.json() + assert body["status"] == "failed" + assert "boom" in (body.get("error") or "") + + +def test_documents_sync_persists_failed_row_when_pipeline_raises( + client: TestClient, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Regression: an exception in the sync pipeline must NOT roll back the + pending row + ``mark_failed`` write. Otherwise the blob on disk has no + DB record pointing at it. + """ + + def _explode(_content: bytes) -> PipelineOutput: + raise RuntimeError("kapow") + + from ocr_sprint.api.routes import documents as docs_module + + monkeypatch.setattr(docs_module, "run_pipeline", _explode) + + # ``raise_server_exceptions=False`` lets the test see the 500 response + # rather than re-raising the underlying RuntimeError from the route. + silent = TestClient(client.app, raise_server_exceptions=False) + post = silent.post( + "/api/v1/documents?sync=true", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 500 + + # The row must still be visible to GET, with status=failed. + from ocr_sprint.db.base import session_scope + from ocr_sprint.db.repositories import JobRepository + + with session_scope() as session: + # Find the most recent row. + from ocr_sprint.db.models import JobRow + + row = session.query(JobRow).order_by(JobRow.created_at.desc()).first() + assert row is not None, "create() must persist even when pipeline blows up" + assert row.status == "failed" + assert "kapow" in (row.error or "") + assert row.blob_key # blob is referenced — not orphaned + + # GET must surface the failure too (this is the client-visible contract). + get = client.get(f"/api/v1/documents/{row.job_id}") + assert get.status_code == 200 + assert get.json()["status"] == "failed" + assert JobRepository # silence import-only warning + + +def test_metrics_endpoint_exposes_request_counter( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + client.post( + "/api/v1/documents?sync=true", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + metrics = client.get("/metrics") + assert metrics.status_code == 200 + body = metrics.text + assert "http_requests_total" in body + assert "ocr_jobs_total" in body + + +def test_metrics_jobs_total_reflects_worker_writes( + client: TestClient, + fake_pipeline: PipelineOutput, +) -> None: + """Regression: when the worker (eager mode here) marks a job complete, + /metrics must reflect that — the previous Counter-based implementation + would have stayed at zero because the worker's increments don't reach + the API process's in-memory registry. + """ + post = client.post( + "/api/v1/documents", + files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")}, + ) + assert post.status_code == 202 + + body = client.get("/metrics").text + # ``ocr_jobs_total{status="completed"} 1.0`` — exact match to make sure + # the gauge-style metric is being populated from the DB. + assert 'ocr_jobs_total{status="completed"} 1.0' in body diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py new file mode 100644 index 0000000..d2615cf --- /dev/null +++ b/tests/unit/test_auth.py @@ -0,0 +1,43 @@ +"""API key authentication.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from ocr_sprint.config import get_settings +from ocr_sprint.main import create_app + + +def _client_with_keys(monkeypatch: pytest.MonkeyPatch, keys: str) -> TestClient: + monkeypatch.setenv("API_KEYS", keys) + get_settings.cache_clear() + return TestClient(create_app()) + + +def test_auth_disabled_when_keys_empty(monkeypatch: pytest.MonkeyPatch) -> None: + client = _client_with_keys(monkeypatch, "") + response = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000") + # 404 not 401: auth disabled, the endpoint just doesn't find the row. + assert response.status_code == 404 + + +def test_auth_rejects_missing_key(monkeypatch: pytest.MonkeyPatch) -> None: + client = _client_with_keys(monkeypatch, "secret-1,secret-2") + response = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000") + assert response.status_code == 401 + + +def test_auth_accepts_valid_key(monkeypatch: pytest.MonkeyPatch) -> None: + client = _client_with_keys(monkeypatch, "secret-1,secret-2") + response = client.get( + "/api/v1/documents/00000000-0000-0000-0000-000000000000", + headers={"X-API-Key": "secret-2"}, + ) + assert response.status_code == 404 + + +def test_health_is_unprotected(monkeypatch: pytest.MonkeyPatch) -> None: + client = _client_with_keys(monkeypatch, "secret-1") + response = client.get("/api/v1/health") + assert response.status_code == 200 diff --git a/tests/unit/test_blob_storage.py b/tests/unit/test_blob_storage.py new file mode 100644 index 0000000..d510aff --- /dev/null +++ b/tests/unit/test_blob_storage.py @@ -0,0 +1,85 @@ +"""Local-filesystem blob storage.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from ocr_sprint.storage.blob import LocalFsBlobStorage + + +@pytest.fixture +def storage(tmp_path: Path) -> LocalFsBlobStorage: + return LocalFsBlobStorage(tmp_path / "blobs") + + +def test_put_returns_dated_key(storage: LocalFsBlobStorage) -> None: + key = storage.put(b"hello", original_filename="surat.pdf") + # Layout is YYYY/MM/DD/.pdf + parts = key.split("/") + assert len(parts) == 4 + assert parts[3].endswith(".pdf") + assert storage.exists(key) + assert storage.get(key) == b"hello" + + +def test_put_unknown_extension_falls_back_to_bin(storage: LocalFsBlobStorage) -> None: + key = storage.put(b"x", original_filename="weird.xyz") + assert key.endswith(".bin") + + +def test_put_strips_directory_traversal(storage: LocalFsBlobStorage) -> None: + # ext is taken via Path().suffix, not from the raw filename, so a name + # like "../../etc/passwd" is harmless — the only thing the caller can + # influence is the extension. + key = storage.put(b"y", original_filename="../../etc/passwd") + assert "etc" not in key + assert key.endswith(".bin") + + +def test_put_handles_missing_filename(storage: LocalFsBlobStorage) -> None: + key = storage.put(b"z", original_filename=None) + assert key.endswith(".bin") + + +def test_get_unknown_key_raises(storage: LocalFsBlobStorage) -> None: + with pytest.raises(FileNotFoundError): + storage.get("2026/01/01/bogus.pdf") + + +def test_delete_is_idempotent(storage: LocalFsBlobStorage) -> None: + key = storage.put(b"q", original_filename="x.png") + storage.delete(key) + assert not storage.exists(key) + storage.delete(key) # second delete must not raise + + +def test_resolve_rejects_path_escape(storage: LocalFsBlobStorage) -> None: + with pytest.raises(ValueError, match="escapes storage root"): + storage._resolve("../../../etc/passwd") + + +def test_resolve_rejects_directory_prefix_collision(tmp_path: Path) -> None: + """Regression: ``startswith`` would mis-accept sibling dirs whose names + happen to begin with the storage root's basename. ``is_relative_to`` + handles this correctly. + """ + root = tmp_path / "blobs" + root.mkdir() + sibling = tmp_path / "blobs_evil" + sibling.mkdir() + storage = LocalFsBlobStorage(root) + with pytest.raises(ValueError, match="escapes storage root"): + storage._resolve("../blobs_evil/secret.txt") + + +def test_exists_returns_false_for_escaped_key(storage: LocalFsBlobStorage) -> None: + # exists() must not raise even for malicious keys. + assert storage.exists("../../etc/passwd") is False + + +def test_open_streams_content(storage: LocalFsBlobStorage) -> None: + key = storage.put(b"streamed", original_filename="x.png") + with storage.open(key) as fh: + assert fh.read() == b"streamed" diff --git a/tests/unit/test_db_repository.py b/tests/unit/test_db_repository.py new file mode 100644 index 0000000..5c6e1a6 --- /dev/null +++ b/tests/unit/test_db_repository.py @@ -0,0 +1,76 @@ +"""SQLAlchemy repository tests against an in-memory sqlite db.""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest + +from ocr_sprint.db.base import Base, get_engine, session_scope +from ocr_sprint.db.repositories import JobNotFoundError, JobRepository +from ocr_sprint.schemas.document import DocumentStatus, SourceKind + + +@pytest.fixture +def db_ready() -> None: + Base.metadata.create_all(bind=get_engine()) + + +def test_create_then_fetch(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, + filename="x.pdf", + source_kind=SourceKind.PDF, + blob_key="2026/01/01/x.pdf", + ) + with session_scope() as session: + row = JobRepository(session).get_or_raise(jid) + assert row.status == DocumentStatus.PENDING.value + assert row.source_kind == SourceKind.PDF.value + assert row.blob_key == "2026/01/01/x.pdf" + + +def test_lifecycle_transitions(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, + filename="x.pdf", + source_kind=SourceKind.PDF, + blob_key="k", + ) + with session_scope() as session: + JobRepository(session).mark_processing(jid) + with session_scope() as session: + repo = JobRepository(session) + repo.mark_completed( + jid, + status=DocumentStatus.NEEDS_REVIEW, + confidence=0.88, + result={"header": {"nomor_sprint": "Sprin/1/2025"}}, + review_flags=["low_ocr_confidence"], + ) + row = repo.get_or_raise(jid) + assert row.status == DocumentStatus.NEEDS_REVIEW.value + assert row.confidence == 0.88 + assert row.result == {"header": {"nomor_sprint": "Sprin/1/2025"}} + assert row.review_flags == ["low_ocr_confidence"] + + +def test_mark_failed_truncates_long_error(db_ready: None) -> None: + jid = uuid4() + with session_scope() as session: + JobRepository(session).create( + job_id=jid, filename="x", source_kind=SourceKind.UNKNOWN, blob_key="k" + ) + with session_scope() as session: + JobRepository(session).mark_failed(jid, error="x" * 5000) + row = JobRepository(session).get_or_raise(jid) + assert len(row.error or "") == 2048 + + +def test_unknown_job_raises(db_ready: None) -> None: + with session_scope() as session, pytest.raises(JobNotFoundError): + JobRepository(session).get_or_raise(uuid4())