Phase 4: async pipeline (Celery+Redis), Postgres job state, local-fs blob storage, API-key auth, Prometheus metrics (#3)
* Phase 4: async pipeline (Celery+Redis), Postgres job state, local-fs blob storage, API-key auth, Prometheus metrics Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> * Phase 4: fix sync-mode rollback orphaning blobs + use is_relative_to for path-escape check Devin Review on PR #3 found two real bugs: 1. Sync path mark_failed was rolled back by the request-scoped session. When the pipeline raised an exception in ?sync=true mode, _run_inline modified the FastAPI session and re-raised; get_session caught the exception, called session.rollback(), and wiped both the create() and the mark_failed() writes. The blob was already on disk, so it was permanently orphaned with no DB record. Fix: commit the pending row immediately after create(), and run all subsequent state transitions in independent session_scope blocks (matching the worker task pattern). 2. _resolve used str.startswith for path-escape detection, which lets a sibling directory whose name begins with the storage root pass (e.g. /app/blobs_evil vs /app/blobs). Switched to Path.is_relative_to. Added regression tests for both. Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> * Phase 4: honor queue_enabled setting + resolve base_dir for path comparisons Two more bugs found by Devin Review: 3. queue_enabled was declared in config and documented in .env.example but never read by the route. A fresh dev install with QUEUE_ENABLED=false (the default) would still enqueue, then fail with a Redis connection error. Fixed by making the ?sync= query param default to None and resolving to (not queue_enabled) inside the route. Tests now set QUEUE_ENABLED=true so the async flow stays exercised, and a new test verifies the inline fallback when the queue is disabled. 4. LocalFsBlobStorage stored base_dir as-is. _resolve resolved its candidate paths, so the empty-dir cleanup loop in delete() compared a resolved candidate against an unresolved base_dir and broke on the first iteration (no cleanup ever happened). Fixed by resolving base_dir once in __init__ so every path comparison is apples-to-apples. Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> * Phase 4: derive ocr_jobs_total from DB so worker writes are visible at /metrics Devin Review correctly noted the Counter-based JOBS_TOTAL would never increment in production because the worker runs in a separate process from the API and the registry is process-local. Replaced JOBS_TOTAL with a custom Collector that issues SELECT status, COUNT(*) FROM jobs GROUP BY status on every /metrics scrape. Result: the metric stays accurate regardless of which process wrote the row. Also corrected the metrics.py docstring (the old comment claimed the counter was 'incremented by the worker', which was the bug). Removed the JOBS_TOTAL.inc() calls from the sync route — the DB collector covers both paths now. JOB_PROCESSING_SECONDS stays as an API-process histogram with an updated docstring noting its scope; cross-process latency belongs to derived dashboards over jobs.created_at/updated_at. Added regression test test_metrics_jobs_total_reflects_worker_writes. Co-Authored-By: adrian kuman firmansah <adriancuman@gmail.com> --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: adrian kuman firmansah <adriancuman@gmail.com>
This commit is contained in:
committed by
GitHub
parent
33b38aacc7
commit
2112023b6e
24
.env.example
24
.env.example
@@ -40,12 +40,20 @@ LLM_MODEL=qwen2.5:1.5b # CPU-friendly default
|
|||||||
LLM_BASE_URL=http://localhost:11434
|
LLM_BASE_URL=http://localhost:11434
|
||||||
LLM_TIMEOUT_S=60
|
LLM_TIMEOUT_S=60
|
||||||
|
|
||||||
# ==== Async pipeline (Phase 4, optional) ====
|
# ==== Async pipeline + persistence (Phase 4) ====
|
||||||
QUEUE_ENABLED=false
|
QUEUE_ENABLED=false # POST /documents queues async when true
|
||||||
REDIS_URL=redis://localhost:6379/0
|
REDIS_URL=redis://localhost:6379/0
|
||||||
DATABASE_URL=postgresql+psycopg://ocr:ocr@localhost:5432/ocr_sprint
|
CELERY_TASK_DEFAULT_QUEUE=ocr_sprint
|
||||||
MINIO_ENDPOINT=localhost:9000
|
|
||||||
MINIO_ACCESS_KEY=minioadmin
|
# Persistence: sqlite for local dev, Postgres for production via docker-compose.
|
||||||
MINIO_SECRET_KEY=minioadmin
|
DATABASE_URL=sqlite:///./storage/ocr_sprint.sqlite
|
||||||
MINIO_BUCKET=ocr-sprint
|
DATABASE_ECHO=false
|
||||||
MINIO_SECURE=false
|
|
||||||
|
# Blob storage: local filesystem only for the MVP (no S3/MinIO).
|
||||||
|
BLOB_STORAGE_DIR=./storage/blobs
|
||||||
|
BLOB_MAX_UPLOAD_MB=25
|
||||||
|
|
||||||
|
# Auth: comma-separated list of accepted API keys. Empty = auth disabled
|
||||||
|
# (local dev only; production must set at least one).
|
||||||
|
API_KEYS=
|
||||||
|
API_KEY_HEADER=X-API-Key
|
||||||
|
|||||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -48,7 +48,10 @@ samples/*.tif
|
|||||||
samples/*.tiff
|
samples/*.tiff
|
||||||
!samples/README.md
|
!samples/README.md
|
||||||
data/local/
|
data/local/
|
||||||
storage/
|
# Runtime data dirs (blobs, sqlite). The src tree's `storage` package is a
|
||||||
|
# real Python module — see the `!` rule below.
|
||||||
|
/storage/
|
||||||
|
!src/ocr_sprint/storage/
|
||||||
*.db
|
*.db
|
||||||
*.sqlite
|
*.sqlite
|
||||||
*.sqlite3
|
*.sqlite3
|
||||||
|
|||||||
13
Dockerfile
13
Dockerfile
@@ -28,17 +28,22 @@ WORKDIR /app
|
|||||||
FROM base AS builder
|
FROM base AS builder
|
||||||
COPY pyproject.toml README.md ./
|
COPY pyproject.toml README.md ./
|
||||||
COPY src/ ./src/
|
COPY src/ ./src/
|
||||||
RUN pip install --upgrade pip && pip install ".[dev]"
|
# `[ocr]` pulls Paddle wheels (~1.5 GB). `[dev]` keeps test+lint deps so
|
||||||
|
# that `make test` works inside the image.
|
||||||
|
RUN pip install --upgrade pip && pip install ".[ocr,dev]"
|
||||||
|
|
||||||
# ----- runtime layer -----
|
# ----- runtime layer -----
|
||||||
FROM base AS runtime
|
FROM base AS runtime
|
||||||
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
COPY --from=builder /usr/local/lib/python3.11/site-packages /usr/local/lib/python3.11/site-packages
|
||||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||||
COPY pyproject.toml README.md ./
|
COPY pyproject.toml README.md alembic.ini ./
|
||||||
COPY src/ ./src/
|
COPY src/ ./src/
|
||||||
|
COPY alembic/ ./alembic/
|
||||||
|
|
||||||
# Pre-create cache dirs so PaddleOCR can write models on first run.
|
# Pre-create cache dirs so PaddleOCR can write models on first run, and
|
||||||
RUN mkdir -p /home/app/.paddleocr /app/storage \
|
# the blob storage root so the API can write uploads as the unprivileged
|
||||||
|
# `app` user.
|
||||||
|
RUN mkdir -p /home/app/.paddleocr /app/storage/blobs \
|
||||||
&& useradd --create-home --uid 1000 app \
|
&& useradd --create-home --uid 1000 app \
|
||||||
&& chown -R app:app /home/app /app
|
&& chown -R app:app /home/app /app
|
||||||
|
|
||||||
|
|||||||
28
README.md
28
README.md
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
OCR + structured extraction service for Indonesian police "surat sprint" (surat perintah) documents. Built around **FastAPI + PaddleOCR + hybrid extraction (regex → LLM lokal → validation)** with **on-premise** deployment as a hard requirement.
|
OCR + structured extraction service for Indonesian police "surat sprint" (surat perintah) documents. Built around **FastAPI + PaddleOCR + hybrid extraction (regex → LLM lokal → validation)** with **on-premise** deployment as a hard requirement.
|
||||||
|
|
||||||
> **Status:** Phase 1+2+3 — synchronous PDF/image OCR with regex header extraction, validation, confidence scoring, document detection + perspective correction + shadow removal for phone photos, and **PP-Structure table extraction** for personnel rows. Phase 4–6 (async pipeline, LLM extraction, HITL) are tracked in [`docs/architecture.md`](docs/architecture.md).
|
> **Status:** Phase 1–4 — synchronous + async PDF/image OCR with regex header extraction, PP-Structure personnel-table extraction, validation, confidence scoring, document detection / perspective correction / shadow removal, **Celery + Redis job queue, Postgres job state, local-filesystem blob storage, API-key auth, and Prometheus metrics**. Phase 5–6 (LLM extraction, HITL) are tracked in [`docs/architecture.md`](docs/architecture.md).
|
||||||
|
|
||||||
## Why this stack
|
## Why this stack
|
||||||
|
|
||||||
@@ -29,6 +29,7 @@ cd ocr-sprint-service
|
|||||||
|
|
||||||
python -m venv .venv && source .venv/bin/activate
|
python -m venv .venv && source .venv/bin/activate
|
||||||
make install # installs runtime + dev deps + pre-commit
|
make install # installs runtime + dev deps + pre-commit
|
||||||
|
pip install -e ".[ocr]" # only on the worker host — pulls Paddle wheels (~1.5 GB)
|
||||||
cp .env.example .env # edit if you need GPU / different storage path
|
cp .env.example .env # edit if you need GPU / different storage path
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -41,8 +42,21 @@ make dev
|
|||||||
|
|
||||||
### Try it out
|
### Try it out
|
||||||
|
|
||||||
|
The default `POST /documents` is async — it returns `202 Accepted` with a `job_id` and the worker fills in the result. For tests / local one-shot usage you can append `?sync=true` to run inline.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
curl -F "file=@samples/pdf/example.pdf" http://localhost:8000/api/v1/documents | jq
|
# Async (production flow)
|
||||||
|
curl -F "file=@samples/pdf/example.pdf" \
|
||||||
|
-H "X-API-Key: $API_KEY" \
|
||||||
|
http://localhost:8000/api/v1/documents | jq
|
||||||
|
# → {"job_id":"8f2a...","status":"pending",...}
|
||||||
|
|
||||||
|
curl -H "X-API-Key: $API_KEY" \
|
||||||
|
http://localhost:8000/api/v1/documents/8f2a... | jq
|
||||||
|
|
||||||
|
# Sync (single small doc, no worker required)
|
||||||
|
curl -F "file=@samples/pdf/example.pdf" \
|
||||||
|
"http://localhost:8000/api/v1/documents?sync=true" | jq
|
||||||
```
|
```
|
||||||
|
|
||||||
Expected response (truncated):
|
Expected response (truncated):
|
||||||
@@ -71,13 +85,17 @@ Expected response (truncated):
|
|||||||
|
|
||||||
### Docker
|
### Docker
|
||||||
|
|
||||||
|
The Phase 4 stack runs four services: `api`, `worker` (Celery), `redis`, and `postgres`. Blob uploads are persisted to a Docker volume — there is **no MinIO/S3** dependency.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker compose build
|
docker compose build
|
||||||
docker compose up -d
|
docker compose up -d
|
||||||
docker compose logs -f api
|
docker compose logs -f api worker
|
||||||
```
|
```
|
||||||
|
|
||||||
The first request will trigger PaddleOCR to download its detection/recognition/cls models (~200 MB) into the `paddle-models` volume.
|
The API container runs `alembic upgrade head` on start, so the `jobs` table is created on first boot. The first request will trigger PaddleOCR to download its detection/recognition/cls models (~200 MB) into the `paddle-models` volume.
|
||||||
|
|
||||||
|
Metrics are exposed at <http://localhost:8000/metrics> in Prometheus text format.
|
||||||
|
|
||||||
## Development
|
## Development
|
||||||
|
|
||||||
@@ -114,7 +132,7 @@ docs/ # architecture & decision records
|
|||||||
| 1 | Sync API, PDF/image ingest, basic preprocessing, PaddleOCR, regex header extraction, validation, confidence scoring | **Done** |
|
| 1 | Sync API, PDF/image ingest, basic preprocessing, PaddleOCR, regex header extraction, validation, confidence scoring | **Done** |
|
||||||
| 2 | OpenCV-based document detection, perspective transform, shadow removal for phone photos | **Done** |
|
| 2 | OpenCV-based document detection, perspective transform, shadow removal for phone photos | **Done** |
|
||||||
| 3 | PP-Structure table extraction for personnel rows + column mapper | **Done** |
|
| 3 | PP-Structure table extraction for personnel rows + column mapper | **Done** |
|
||||||
| 4 | Async pipeline (Celery + Redis), Postgres + MinIO, auth, observability | Planned |
|
| 4 | Async pipeline (Celery + Redis), Postgres job state, local-filesystem blob storage, API-key auth, Prometheus metrics | **Done** |
|
||||||
| 5 | LLM hybrid extraction (Ollama + structured output) | Planned |
|
| 5 | LLM hybrid extraction (Ollama + structured output) | Planned |
|
||||||
| 6 | HITL review endpoints + audit trail | Planned |
|
| 6 | HITL review endpoints + audit trail | Planned |
|
||||||
|
|
||||||
|
|||||||
38
alembic.ini
Normal file
38
alembic.ini
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = src
|
||||||
|
sqlalchemy.url =
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
56
alembic/env.py
Normal file
56
alembic/env.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Alembic environment.
|
||||||
|
|
||||||
|
The DB URL is taken from `Settings.database_url`, which respects the same
|
||||||
|
`.env` and environment variables as the main app. CLI users can override
|
||||||
|
with `-x sqlalchemy.url=...` if they need to point at a different DB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.db.models import Base
|
||||||
|
|
||||||
|
config = context.config
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
settings = get_settings()
|
||||||
|
config.set_main_option("sqlalchemy.url", settings.database_url)
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
url = config.get_main_option("sqlalchemy.url")
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
connectable = engine_from_config(
|
||||||
|
config.get_section(config.config_ini_section, {}),
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(connection=connection, target_metadata=target_metadata, compare_type=True)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
25
alembic/script.py.mako
Normal file
25
alembic/script.py.mako
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
"""
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: str | None = ${repr(down_revision)}
|
||||||
|
branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
|
||||||
|
depends_on: str | Sequence[str] | None = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
42
alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py
Normal file
42
alembic/versions/ff8c14fbf8a0_phase4_jobs_table.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""phase4 jobs table
|
||||||
|
|
||||||
|
Revision ID: ff8c14fbf8a0
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-04-25 15:54:18.579147
|
||||||
|
"""
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'ff8c14fbf8a0'
|
||||||
|
down_revision: str | None = None
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('jobs',
|
||||||
|
sa.Column('job_id', sa.Uuid(), nullable=False),
|
||||||
|
sa.Column('status', sa.String(length=32), nullable=False),
|
||||||
|
sa.Column('source_kind', sa.String(length=16), nullable=False),
|
||||||
|
sa.Column('filename', sa.String(length=512), nullable=False),
|
||||||
|
sa.Column('blob_key', sa.String(length=512), nullable=True),
|
||||||
|
sa.Column('confidence', sa.Float(), nullable=True),
|
||||||
|
sa.Column('review_flags', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('result', sa.JSON(), nullable=True),
|
||||||
|
sa.Column('error', sa.String(length=2048), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('job_id')
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table('jobs')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
# Phase 1 MVP compose: API only.
|
# Phase 4 stack: API + Celery worker + Redis (broker/result backend) +
|
||||||
# Phase 4 will add redis, postgres, minio, and worker services.
|
# Postgres (job state). Object storage is intentionally NOT here — the
|
||||||
|
# `BlobStorage` interface uses the local filesystem mounted at /app/storage.
|
||||||
services:
|
services:
|
||||||
api:
|
api:
|
||||||
build:
|
build:
|
||||||
@@ -7,17 +8,83 @@ services:
|
|||||||
dockerfile: Dockerfile
|
dockerfile: Dockerfile
|
||||||
image: ocr-sprint-service:dev
|
image: ocr-sprint-service:dev
|
||||||
container_name: ocr-sprint-api
|
container_name: ocr-sprint-api
|
||||||
|
command:
|
||||||
|
[
|
||||||
|
"sh",
|
||||||
|
"-c",
|
||||||
|
"alembic upgrade head && uvicorn ocr_sprint.main:app --host 0.0.0.0 --port 8000",
|
||||||
|
]
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8000:8000"
|
||||||
environment:
|
environment:
|
||||||
APP_ENV: local
|
APP_ENV: docker
|
||||||
APP_LOG_LEVEL: INFO
|
APP_LOG_LEVEL: INFO
|
||||||
OCR_USE_GPU: "false"
|
OCR_USE_GPU: "false"
|
||||||
STORAGE_LOCAL_DIR: /app/storage
|
STORAGE_LOCAL_DIR: /app/storage
|
||||||
|
BLOB_STORAGE_DIR: /app/storage/blobs
|
||||||
|
REDIS_URL: redis://redis:6379/0
|
||||||
|
DATABASE_URL: postgresql+psycopg://ocr:ocr@postgres:5432/ocr_sprint
|
||||||
|
QUEUE_ENABLED: "true"
|
||||||
volumes:
|
volumes:
|
||||||
- ./storage:/app/storage
|
- blob-storage:/app/storage/blobs
|
||||||
- paddle-models:/home/app/.paddleocr
|
- paddle-models:/home/app/.paddleocr
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
worker:
|
||||||
|
image: ocr-sprint-service:dev
|
||||||
|
container_name: ocr-sprint-worker
|
||||||
|
command: ["celery", "-A", "ocr_sprint.worker.celery_app", "worker", "-l", "info", "--concurrency=1"]
|
||||||
|
environment:
|
||||||
|
APP_ENV: docker
|
||||||
|
APP_LOG_LEVEL: INFO
|
||||||
|
OCR_USE_GPU: "false"
|
||||||
|
BLOB_STORAGE_DIR: /app/storage/blobs
|
||||||
|
REDIS_URL: redis://redis:6379/0
|
||||||
|
DATABASE_URL: postgresql+psycopg://ocr:ocr@postgres:5432/ocr_sprint
|
||||||
|
volumes:
|
||||||
|
- blob-storage:/app/storage/blobs
|
||||||
|
- paddle-models:/home/app/.paddleocr
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
api:
|
||||||
|
condition: service_started
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
container_name: ocr-sprint-redis
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
postgres:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
container_name: ocr-sprint-postgres
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: ocr
|
||||||
|
POSTGRES_PASSWORD: ocr
|
||||||
|
POSTGRES_DB: ocr_sprint
|
||||||
|
volumes:
|
||||||
|
- postgres-data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U ocr -d ocr_sprint"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 10
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
|
blob-storage:
|
||||||
paddle-models:
|
paddle-models:
|
||||||
|
postgres-data:
|
||||||
|
|||||||
@@ -24,12 +24,15 @@ dependencies = [
|
|||||||
"numpy>=1.26,<2.2",
|
"numpy>=1.26,<2.2",
|
||||||
"PyMuPDF>=1.24,<2",
|
"PyMuPDF>=1.24,<2",
|
||||||
"python-magic>=0.4.27",
|
"python-magic>=0.4.27",
|
||||||
# OCR (CPU build of paddle; GPU users override via extra index)
|
|
||||||
"paddlepaddle==2.6.1",
|
|
||||||
"paddleocr>=2.7.5,<3",
|
|
||||||
# Logging / observability
|
# Logging / observability
|
||||||
"structlog>=24.1",
|
"structlog>=24.1",
|
||||||
"prometheus-client>=0.20",
|
"prometheus-client>=0.20",
|
||||||
|
# Async pipeline + persistence (Phase 4)
|
||||||
|
"celery[redis]>=5.4",
|
||||||
|
"redis>=5.0",
|
||||||
|
"sqlalchemy>=2.0",
|
||||||
|
"psycopg[binary]>=3.2",
|
||||||
|
"alembic>=1.13",
|
||||||
# Misc
|
# Misc
|
||||||
"httpx>=0.27",
|
"httpx>=0.27",
|
||||||
"tenacity>=8.5",
|
"tenacity>=8.5",
|
||||||
@@ -46,22 +49,20 @@ dev = [
|
|||||||
"pre-commit>=3.7",
|
"pre-commit>=3.7",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# OCR runtime — kept as an optional extra so unit tests / dev installs don't
|
||||||
|
# pull ~1.5 GB of Paddle wheels. Install via `pip install -e ".[ocr]"` on
|
||||||
|
# the worker host. CPU build by default; GPU users override the index URL.
|
||||||
|
ocr = [
|
||||||
|
"paddlepaddle>=2.6.2,<4",
|
||||||
|
"paddleocr>=2.7.5,<3",
|
||||||
|
]
|
||||||
|
|
||||||
# Extraction layer (Phase 5) — kept optional so MVP install stays light
|
# Extraction layer (Phase 5) — kept optional so MVP install stays light
|
||||||
llm = [
|
llm = [
|
||||||
"ollama>=0.3",
|
"ollama>=0.3",
|
||||||
"instructor>=1.4",
|
"instructor>=1.4",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Async pipeline (Phase 4)
|
|
||||||
async-pipeline = [
|
|
||||||
"celery[redis]>=5.4",
|
|
||||||
"redis>=5.0",
|
|
||||||
"minio>=7.2",
|
|
||||||
"sqlalchemy>=2.0",
|
|
||||||
"psycopg[binary]>=3.2",
|
|
||||||
"alembic>=1.13",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
ocr-sprint-api = "ocr_sprint.main:run"
|
ocr-sprint-api = "ocr_sprint.main:run"
|
||||||
|
|
||||||
@@ -111,7 +112,7 @@ namespace_packages = true
|
|||||||
explicit_package_bases = true
|
explicit_package_bases = true
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
[[tool.mypy.overrides]]
|
||||||
module = ["paddleocr.*", "paddle.*", "cv2.*", "fitz.*", "magic.*"]
|
module = ["paddleocr.*", "paddle.*", "cv2.*", "fitz.*", "magic.*", "celery.*", "kombu.*"]
|
||||||
ignore_missing_imports = true
|
ignore_missing_imports = true
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
|
|||||||
6
src/ocr_sprint/api/deps/__init__.py
Normal file
6
src/ocr_sprint/api/deps/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Reusable FastAPI dependencies (auth, db session)."""
|
||||||
|
|
||||||
|
from ocr_sprint.api.deps.auth import require_api_key
|
||||||
|
from ocr_sprint.api.deps.db import get_session
|
||||||
|
|
||||||
|
__all__ = ["get_session", "require_api_key"]
|
||||||
35
src/ocr_sprint/api/deps/auth.py
Normal file
35
src/ocr_sprint/api/deps/auth.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""API-key authentication.
|
||||||
|
|
||||||
|
The MVP uses a static list of keys loaded from `Settings.api_keys`. This is
|
||||||
|
deliberate: the service is intended to run on-prem behind an internal
|
||||||
|
reverse proxy, with a small set of trusted clients (the police HITL UI and
|
||||||
|
internal automation). Anything more sophisticated (JWT / OAuth / mTLS) is
|
||||||
|
deferred until there's a concrete need.
|
||||||
|
|
||||||
|
When `Settings.api_keys` is empty the dependency permits all requests.
|
||||||
|
This makes the local dev experience friction-free; production deploys MUST
|
||||||
|
set at least one key — tested by `test_auth_rejects_missing_key`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import Header, HTTPException, status
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
async def require_api_key(
|
||||||
|
x_api_key: Annotated[str | None, Header(alias="X-API-Key")] = None,
|
||||||
|
) -> None:
|
||||||
|
settings = get_settings()
|
||||||
|
keys = settings.api_keys_list
|
||||||
|
if not keys:
|
||||||
|
return # auth disabled for local dev
|
||||||
|
if not x_api_key or x_api_key not in keys:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="invalid or missing API key",
|
||||||
|
headers={"WWW-Authenticate": settings.api_key_header},
|
||||||
|
)
|
||||||
23
src/ocr_sprint/api/deps/db.py
Normal file
23
src/ocr_sprint/api/deps/db.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
"""Per-request SQLAlchemy session dependency."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import get_sessionmaker
|
||||||
|
|
||||||
|
|
||||||
|
def get_session() -> Iterator[Session]:
|
||||||
|
"""Yield a session, committing on success and rolling back on errors."""
|
||||||
|
factory = get_sessionmaker()
|
||||||
|
session = factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
114
src/ocr_sprint/api/metrics.py
Normal file
114
src/ocr_sprint/api/metrics.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""Prometheus metrics exposed at `/metrics`.
|
||||||
|
|
||||||
|
We keep the surface tiny on purpose:
|
||||||
|
|
||||||
|
* `http_requests_total{method,route,status}` — request count
|
||||||
|
* `http_request_duration_seconds{method,route}` — latency histogram
|
||||||
|
* `ocr_jobs_total{status}` — current count of jobs in each status, derived
|
||||||
|
from the database at scrape time. Because this is computed from the
|
||||||
|
``jobs`` table it is correct regardless of whether the writer was the API
|
||||||
|
process (sync mode) or a Celery worker process. Note that this is
|
||||||
|
technically a gauge of the cumulative-by-status count, not a strict
|
||||||
|
monotonic counter — values can decrease if rows are deleted/reaped.
|
||||||
|
* `ocr_job_processing_seconds` — pipeline wall-time histogram. Only the API
|
||||||
|
process observes events on this histogram (sync path); the worker writes
|
||||||
|
its timing into the DB row and is not exposed here.
|
||||||
|
|
||||||
|
A custom registry is used so tests can reset counters cleanly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from collections.abc import Awaitable, Callable, Iterable
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from prometheus_client import (
|
||||||
|
CONTENT_TYPE_LATEST,
|
||||||
|
CollectorRegistry,
|
||||||
|
Counter,
|
||||||
|
Histogram,
|
||||||
|
generate_latest,
|
||||||
|
)
|
||||||
|
from prometheus_client.core import GaugeMetricFamily
|
||||||
|
from prometheus_client.metrics_core import Metric
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import session_scope
|
||||||
|
from ocr_sprint.db.models import JobRow
|
||||||
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
|
||||||
|
_logger = get_logger(__name__)
|
||||||
|
|
||||||
|
REGISTRY = CollectorRegistry()
|
||||||
|
|
||||||
|
REQUEST_COUNT = Counter(
|
||||||
|
"http_requests_total",
|
||||||
|
"HTTP requests handled by the API",
|
||||||
|
("method", "route", "status"),
|
||||||
|
registry=REGISTRY,
|
||||||
|
)
|
||||||
|
REQUEST_LATENCY = Histogram(
|
||||||
|
"http_request_duration_seconds",
|
||||||
|
"HTTP request latency",
|
||||||
|
("method", "route"),
|
||||||
|
registry=REGISTRY,
|
||||||
|
)
|
||||||
|
JOB_PROCESSING_SECONDS = Histogram(
|
||||||
|
"ocr_job_processing_seconds",
|
||||||
|
"OCR job pipeline wall-time as observed by the API process",
|
||||||
|
registry=REGISTRY,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _JobStatusCollector:
|
||||||
|
"""Custom collector that queries the ``jobs`` table on every scrape.
|
||||||
|
|
||||||
|
Because the worker runs in a separate process from the API, an in-memory
|
||||||
|
``Counter`` cannot accurately track terminal job counts — its writes
|
||||||
|
would never reach the API's ``/metrics`` endpoint. Reading from the
|
||||||
|
shared DB on each scrape keeps the metric correct across processes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def collect(self) -> Iterable[Metric]:
|
||||||
|
family = GaugeMetricFamily(
|
||||||
|
"ocr_jobs_total",
|
||||||
|
"Current count of jobs grouped by status (read from DB).",
|
||||||
|
labels=["status"],
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with session_scope() as session:
|
||||||
|
stmt = select(JobRow.status, func.count()).group_by(JobRow.status)
|
||||||
|
for status_value, count in session.execute(stmt).all():
|
||||||
|
family.add_metric([status_value], float(count))
|
||||||
|
except Exception as exc:
|
||||||
|
_logger.warning("metrics.jobs_collect_failed", error=str(exc))
|
||||||
|
return [family]
|
||||||
|
|
||||||
|
|
||||||
|
REGISTRY.register(_JobStatusCollector())
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Record request count + latency. The `route` label is the path
|
||||||
|
template, not the raw URL, so per-id endpoints don't blow up cardinality.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
call_next: Callable[[Request], Awaitable[Response]],
|
||||||
|
) -> Response:
|
||||||
|
start = time.perf_counter()
|
||||||
|
response = await call_next(request)
|
||||||
|
elapsed = time.perf_counter() - start
|
||||||
|
route = request.scope.get("route")
|
||||||
|
path = getattr(route, "path", request.url.path) if route else request.url.path
|
||||||
|
REQUEST_COUNT.labels(request.method, path, str(response.status_code)).inc()
|
||||||
|
REQUEST_LATENCY.labels(request.method, path).observe(elapsed)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
async def metrics_endpoint() -> Response:
|
||||||
|
return Response(content=generate_latest(REGISTRY), media_type=CONTENT_TYPE_LATEST)
|
||||||
@@ -1,58 +1,194 @@
|
|||||||
"""Documents API — Phase 1 synchronous endpoint.
|
"""Documents API.
|
||||||
|
|
||||||
POST /documents accepts a single PDF or image upload, runs the synchronous
|
Phase 1 shipped a single synchronous endpoint. Phase 4 adds an async
|
||||||
pipeline inline, and returns the structured result. This is suitable for
|
flow on top:
|
||||||
development and low-traffic production; Phase 4 will introduce an async
|
|
||||||
queue and a polling-style API at the same path.
|
* `POST /documents` — async by default. Saves the upload to blob
|
||||||
|
storage, creates a `pending` job row, and
|
||||||
|
enqueues a Celery task. Returns `202` with
|
||||||
|
the job id.
|
||||||
|
* `POST /documents?sync=true` — runs the pipeline inline (the original
|
||||||
|
Phase 1 behaviour). Useful for tests and
|
||||||
|
small-volume single-tenant deploys without
|
||||||
|
a Celery worker.
|
||||||
|
* `GET /documents/{job_id}` — returns the current job state. Async
|
||||||
|
clients poll this until `status` is in a
|
||||||
|
terminal state (completed / needs_review /
|
||||||
|
failed).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from uuid import uuid4
|
from typing import Annotated
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
from fastapi import APIRouter, File, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Query, Response, UploadFile, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ocr_sprint.api.deps.auth import require_api_key
|
||||||
|
from ocr_sprint.api.deps.db import get_session
|
||||||
from ocr_sprint.api.errors import UnsupportedDocumentError
|
from ocr_sprint.api.errors import UnsupportedDocumentError
|
||||||
|
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.db.base import session_scope
|
||||||
|
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
|
||||||
|
from ocr_sprint.pipeline.ingest import detect_source_kind
|
||||||
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
||||||
from ocr_sprint.schemas.document import DocumentResponse
|
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
|
||||||
|
from ocr_sprint.schemas.extraction import ExtractionResult
|
||||||
|
from ocr_sprint.storage.blob import get_blob_storage
|
||||||
from ocr_sprint.utils.logging import get_logger
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
|
||||||
router = APIRouter(prefix="/documents", tags=["documents"])
|
router = APIRouter(
|
||||||
|
prefix="/documents",
|
||||||
|
tags=["documents"],
|
||||||
|
dependencies=[Depends(require_api_key)],
|
||||||
|
)
|
||||||
_logger = get_logger(__name__)
|
_logger = get_logger(__name__)
|
||||||
|
|
||||||
_MAX_UPLOAD_BYTES = 25 * 1024 * 1024 # 25 MB
|
|
||||||
|
# ---------- helpers ----------
|
||||||
|
|
||||||
|
|
||||||
@router.post("", status_code=status.HTTP_200_OK, response_model=DocumentResponse)
|
def _enforce_size(content: bytes) -> None:
|
||||||
async def create_document(file: UploadFile = File(...)) -> DocumentResponse:
|
s = get_settings()
|
||||||
"""Run OCR + extraction synchronously on a single upload."""
|
if not content:
|
||||||
|
raise UnsupportedDocumentError("Uploaded file is empty.")
|
||||||
|
max_bytes = s.blob_max_upload_mb * 1024 * 1024
|
||||||
|
if len(content) > max_bytes:
|
||||||
|
raise UnsupportedDocumentError(f"Uploaded file exceeds {s.blob_max_upload_mb} MB limit.")
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_response(row: object) -> DocumentResponse:
|
||||||
|
# Local import to avoid a circular import at module load time.
|
||||||
|
from ocr_sprint.db.models import JobRow
|
||||||
|
|
||||||
|
assert isinstance(row, JobRow)
|
||||||
|
status_enum = DocumentStatus(row.status)
|
||||||
|
result_obj: ExtractionResult | None = None
|
||||||
|
if row.result is not None:
|
||||||
|
result_obj = ExtractionResult.model_validate(row.result)
|
||||||
|
return DocumentResponse(
|
||||||
|
job_id=row.job_id,
|
||||||
|
status=status_enum,
|
||||||
|
confidence=row.confidence,
|
||||||
|
data=result_obj,
|
||||||
|
review_flags=list(row.review_flags or []),
|
||||||
|
error=row.error,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- POST ----------
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=DocumentResponse)
|
||||||
|
async def create_document(
|
||||||
|
file: Annotated[UploadFile, File(...)],
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
response: Response,
|
||||||
|
sync: Annotated[
|
||||||
|
bool | None,
|
||||||
|
Query(description="Run pipeline inline (skip queue). Defaults to !queue_enabled."),
|
||||||
|
] = None,
|
||||||
|
) -> DocumentResponse:
|
||||||
|
# When the queue is disabled (default for local dev), running the async
|
||||||
|
# path would try to dial Redis and fail with a 500. Auto-fall-back to the
|
||||||
|
# inline pipeline unless the caller explicitly asked for async.
|
||||||
|
if sync is None:
|
||||||
|
sync = not get_settings().queue_enabled
|
||||||
|
|
||||||
job_id = uuid4()
|
job_id = uuid4()
|
||||||
log = _logger.bind(job_id=str(job_id), filename=file.filename or "")
|
log = _logger.bind(job_id=str(job_id), filename=file.filename or "")
|
||||||
|
|
||||||
content = await file.read()
|
content = await file.read()
|
||||||
if not content:
|
_enforce_size(content)
|
||||||
raise UnsupportedDocumentError("Uploaded file is empty.")
|
|
||||||
if len(content) > _MAX_UPLOAD_BYTES:
|
|
||||||
raise UnsupportedDocumentError(
|
|
||||||
f"Uploaded file exceeds {_MAX_UPLOAD_BYTES // (1024 * 1024)} MB limit."
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info("documents.received", size=len(content))
|
storage = get_blob_storage()
|
||||||
|
blob_key = storage.put(content, original_filename=file.filename)
|
||||||
|
source_kind = detect_source_kind(content)
|
||||||
|
JobRepository(session).create(
|
||||||
|
job_id=job_id,
|
||||||
|
filename=file.filename or "",
|
||||||
|
source_kind=source_kind,
|
||||||
|
blob_key=blob_key,
|
||||||
|
)
|
||||||
|
# Commit the `pending` row immediately so it is observable regardless
|
||||||
|
# of what happens next. Both code paths below open their own session
|
||||||
|
# for state transitions; that way an exception in `_run_inline` cannot
|
||||||
|
# roll back the create() (which would orphan the blob on disk).
|
||||||
|
session.commit()
|
||||||
|
log.info("documents.received", size=len(content), blob_key=blob_key, sync=sync)
|
||||||
|
|
||||||
|
if sync:
|
||||||
|
# Status code stays at the default 200; the body's `status` field
|
||||||
|
# tells the client whether the job needs review.
|
||||||
|
return await _run_inline(job_id, content)
|
||||||
|
|
||||||
|
# Async path — enqueue and return 202. The Celery worker will pick up
|
||||||
|
# the row using its own session.
|
||||||
|
from ocr_sprint.worker.tasks import process_document_task
|
||||||
|
|
||||||
|
process_document_task.delay(str(job_id))
|
||||||
|
with session_scope() as poll:
|
||||||
|
row = JobRepository(poll).get_or_raise(job_id)
|
||||||
|
body = _row_to_response(row)
|
||||||
|
response.status_code = status.HTTP_202_ACCEPTED
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_inline(job_id: UUID, content: bytes) -> DocumentResponse:
|
||||||
|
"""Synchronous pipeline execution.
|
||||||
|
|
||||||
|
Each state transition opens its own short session so the request-scoped
|
||||||
|
session's rollback-on-exception behaviour cannot wipe out the
|
||||||
|
``mark_failed`` write or strand the blob on disk.
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
|
||||||
|
with session_scope() as s:
|
||||||
|
JobRepository(s).mark_processing(job_id)
|
||||||
|
|
||||||
|
started = time.perf_counter()
|
||||||
try:
|
try:
|
||||||
output = run_pipeline(content)
|
output = run_pipeline(content)
|
||||||
except ValueError as exc:
|
except ValueError as exc:
|
||||||
|
with session_scope() as s:
|
||||||
|
JobRepository(s).mark_failed(job_id, error=str(exc))
|
||||||
raise UnsupportedDocumentError(str(exc)) from exc
|
raise UnsupportedDocumentError(str(exc)) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
with session_scope() as s:
|
||||||
|
JobRepository(s).mark_failed(job_id, error=str(exc))
|
||||||
|
raise
|
||||||
|
|
||||||
log.info(
|
flags = [f.value for f in output.result.review_flags]
|
||||||
"documents.completed",
|
JOB_PROCESSING_SECONDS.observe(time.perf_counter() - started)
|
||||||
status=output.status.value,
|
with session_scope() as s:
|
||||||
confidence=round(output.confidence, 3),
|
repo = JobRepository(s)
|
||||||
flags=[f.value for f in output.result.review_flags],
|
repo.mark_completed(
|
||||||
)
|
job_id,
|
||||||
return DocumentResponse(
|
status=output.status,
|
||||||
job_id=job_id,
|
confidence=output.confidence,
|
||||||
status=output.status,
|
result=output.result.model_dump(mode="json"),
|
||||||
confidence=output.confidence,
|
review_flags=flags,
|
||||||
data=output.result,
|
)
|
||||||
review_flags=[f.value for f in output.result.review_flags],
|
row = repo.get_or_raise(job_id)
|
||||||
)
|
return _row_to_response(row)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- GET ----------
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/{job_id}",
|
||||||
|
response_model=DocumentResponse,
|
||||||
|
)
|
||||||
|
async def get_document(
|
||||||
|
job_id: UUID,
|
||||||
|
session: Annotated[Session, Depends(get_session)],
|
||||||
|
) -> DocumentResponse:
|
||||||
|
repo = JobRepository(session)
|
||||||
|
try:
|
||||||
|
row = repo.get_or_raise(job_id)
|
||||||
|
except JobNotFoundError as exc:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
|
||||||
|
return _row_to_response(row)
|
||||||
|
|||||||
@@ -64,12 +64,29 @@ class Settings(BaseSettings):
|
|||||||
# Async pipeline (Phase 4)
|
# Async pipeline (Phase 4)
|
||||||
queue_enabled: bool = False
|
queue_enabled: bool = False
|
||||||
redis_url: str = "redis://localhost:6379/0"
|
redis_url: str = "redis://localhost:6379/0"
|
||||||
database_url: str = "postgresql+psycopg://ocr:ocr@localhost:5432/ocr_sprint"
|
celery_task_default_queue: str = "ocr_sprint"
|
||||||
minio_endpoint: str = "localhost:9000"
|
|
||||||
minio_access_key: str = "minioadmin"
|
# Persistence (Phase 4). Use sqlite for local dev / tests; Postgres for
|
||||||
minio_secret_key: str = "minioadmin"
|
# production via docker-compose.
|
||||||
minio_bucket: str = "ocr-sprint"
|
database_url: str = "sqlite:///./storage/ocr_sprint.sqlite"
|
||||||
minio_secure: bool = False
|
database_echo: bool = False
|
||||||
|
|
||||||
|
# Blob storage (Phase 4). Local filesystem only for the MVP; the
|
||||||
|
# `BlobStorage` interface is designed to swap to S3/MinIO without API
|
||||||
|
# changes when needed.
|
||||||
|
blob_storage_dir: Path = Path("./storage/blobs")
|
||||||
|
blob_max_upload_mb: int = 25
|
||||||
|
|
||||||
|
# Auth (Phase 4). Comma-separated list of API keys accepted by the API.
|
||||||
|
# Empty string disables auth (only intended for local dev / tests).
|
||||||
|
# We use ``str`` rather than ``list[str]`` because pydantic-settings rejects
|
||||||
|
# a bare empty string when binding to a list type.
|
||||||
|
api_keys: str = ""
|
||||||
|
api_key_header: str = "X-API-Key"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def api_keys_list(self) -> list[str]:
|
||||||
|
return [k.strip() for k in self.api_keys.split(",") if k.strip()]
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
@@ -77,4 +94,5 @@ def get_settings() -> Settings:
|
|||||||
"""Cached accessor so settings are loaded once per process."""
|
"""Cached accessor so settings are loaded once per process."""
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
settings.storage_local_dir.mkdir(parents=True, exist_ok=True)
|
settings.storage_local_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
settings.blob_storage_dir.mkdir(parents=True, exist_ok=True)
|
||||||
return settings
|
return settings
|
||||||
|
|||||||
14
src/ocr_sprint/db/__init__.py
Normal file
14
src/ocr_sprint/db/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Persistence layer (Phase 4) — SQLAlchemy 2.0 models, session, repositories."""
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import Base, get_engine, get_sessionmaker, session_scope
|
||||||
|
from ocr_sprint.db.models import JobRow
|
||||||
|
from ocr_sprint.db.repositories import JobRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Base",
|
||||||
|
"JobRepository",
|
||||||
|
"JobRow",
|
||||||
|
"get_engine",
|
||||||
|
"get_sessionmaker",
|
||||||
|
"session_scope",
|
||||||
|
]
|
||||||
67
src/ocr_sprint/db/base.py
Normal file
67
src/ocr_sprint/db/base.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""SQLAlchemy 2.0 engine + session factory.
|
||||||
|
|
||||||
|
We use a single global engine per process. For tests (and SQLite in dev),
|
||||||
|
StaticPool keeps the same in-memory database across connections; for
|
||||||
|
Postgres in production we use SQLAlchemy's default pool.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Common SQLAlchemy declarative base."""
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_engine() -> Engine:
|
||||||
|
s = get_settings()
|
||||||
|
kwargs: dict[str, object] = {
|
||||||
|
"echo": s.database_echo,
|
||||||
|
"future": True,
|
||||||
|
}
|
||||||
|
# SQLite needs special handling: same connection across threads (Celery
|
||||||
|
# eager mode + FastAPI) requires `check_same_thread=False`. For the
|
||||||
|
# ``sqlite:///:memory:`` URL we also use StaticPool to reuse the same
|
||||||
|
# underlying connection so test fixtures see committed data.
|
||||||
|
if s.database_url.startswith("sqlite"):
|
||||||
|
kwargs["connect_args"] = {"check_same_thread": False}
|
||||||
|
if ":memory:" in s.database_url or s.database_url.endswith(":memory:"):
|
||||||
|
kwargs["poolclass"] = StaticPool
|
||||||
|
return create_engine(s.database_url, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_sessionmaker() -> sessionmaker[Session]:
|
||||||
|
return sessionmaker(bind=get_engine(), expire_on_commit=False, autoflush=False)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def session_scope() -> Iterator[Session]:
|
||||||
|
"""Yield a SQLAlchemy session and commit/rollback at the boundary."""
|
||||||
|
factory = get_sessionmaker()
|
||||||
|
session = factory()
|
||||||
|
try:
|
||||||
|
yield session
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
session.rollback()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
|
||||||
|
|
||||||
|
def reset_engine_cache() -> None:
|
||||||
|
"""Clear cached engine + sessionmaker. Used by tests when changing DB URL."""
|
||||||
|
get_engine.cache_clear()
|
||||||
|
get_sessionmaker.cache_clear()
|
||||||
53
src/ocr_sprint/db/models.py
Normal file
53
src/ocr_sprint/db/models.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""SQLAlchemy ORM models for jobs + extraction results.
|
||||||
|
|
||||||
|
We store the structured result as JSON. PaddleOCR's `raw_text` can run into
|
||||||
|
the tens of kilobytes for multi-page documents; that's well within Postgres'
|
||||||
|
JSONB row-size budget. SQLite stores it as TEXT under the hood.
|
||||||
|
|
||||||
|
Schema choice: we keep the result inline on the same row instead of a
|
||||||
|
separate `extraction_results` table. The 1:1 relationship would otherwise
|
||||||
|
add a join on every read, with no real benefit since results are immutable
|
||||||
|
once written and there's no use-case for fetching just the metadata.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from sqlalchemy import JSON, DateTime, Float, String, Uuid
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class JobRow(Base):
|
||||||
|
__tablename__ = "jobs"
|
||||||
|
|
||||||
|
# SQLAlchemy 2.0's Uuid type maps to native UUID on Postgres and CHAR(32)
|
||||||
|
# on SQLite, so the same model works in both environments.
|
||||||
|
job_id: Mapped[UUID] = mapped_column(Uuid, primary_key=True, default=uuid4)
|
||||||
|
status: Mapped[str] = mapped_column(String(32), nullable=False, default="pending")
|
||||||
|
source_kind: Mapped[str] = mapped_column(String(16), nullable=False, default="unknown")
|
||||||
|
filename: Mapped[str] = mapped_column(String(512), nullable=False, default="")
|
||||||
|
blob_key: Mapped[str | None] = mapped_column(String(512), nullable=True)
|
||||||
|
|
||||||
|
confidence: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
review_flags: Mapped[list[str]] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
result: Mapped[dict[str, Any] | None] = mapped_column(JSON, nullable=True)
|
||||||
|
error: Mapped[str | None] = mapped_column(String(2048), nullable=True)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, default=_utcnow
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, default=_utcnow, onupdate=_utcnow
|
||||||
|
)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"JobRow(job_id={self.job_id!s}, status={self.status!r})"
|
||||||
96
src/ocr_sprint/db/repositories.py
Normal file
96
src/ocr_sprint/db/repositories.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""Thin data-access layer over the ORM.
|
||||||
|
|
||||||
|
Repositories encapsulate the SQL so the API + Celery task code never has to
|
||||||
|
know about sessions, transactions, or the row → schema mapping.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from ocr_sprint.db.models import JobRow
|
||||||
|
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||||
|
|
||||||
|
|
||||||
|
def _utcnow() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class JobNotFoundError(LookupError):
|
||||||
|
"""Raised by API code when GET /documents/{id} hits a missing row."""
|
||||||
|
|
||||||
|
|
||||||
|
class JobRepository:
|
||||||
|
"""SQL-backed repository for `jobs` rows."""
|
||||||
|
|
||||||
|
def __init__(self, session: Session) -> None:
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
# ---------- writes ----------
|
||||||
|
|
||||||
|
def create(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
job_id: UUID,
|
||||||
|
filename: str,
|
||||||
|
source_kind: SourceKind,
|
||||||
|
blob_key: str,
|
||||||
|
) -> JobRow:
|
||||||
|
row = JobRow(
|
||||||
|
job_id=job_id,
|
||||||
|
status=DocumentStatus.PENDING.value,
|
||||||
|
source_kind=source_kind.value,
|
||||||
|
filename=filename,
|
||||||
|
blob_key=blob_key,
|
||||||
|
)
|
||||||
|
self.session.add(row)
|
||||||
|
self.session.flush()
|
||||||
|
return row
|
||||||
|
|
||||||
|
def mark_processing(self, job_id: UUID) -> None:
|
||||||
|
row = self._get_or_raise(job_id)
|
||||||
|
row.status = DocumentStatus.PROCESSING.value
|
||||||
|
row.updated_at = _utcnow()
|
||||||
|
|
||||||
|
def mark_completed(
|
||||||
|
self,
|
||||||
|
job_id: UUID,
|
||||||
|
*,
|
||||||
|
status: DocumentStatus,
|
||||||
|
confidence: float,
|
||||||
|
result: dict[str, Any],
|
||||||
|
review_flags: list[str],
|
||||||
|
) -> None:
|
||||||
|
row = self._get_or_raise(job_id)
|
||||||
|
row.status = status.value
|
||||||
|
row.confidence = confidence
|
||||||
|
row.result = result
|
||||||
|
row.review_flags = review_flags
|
||||||
|
row.error = None
|
||||||
|
row.updated_at = _utcnow()
|
||||||
|
|
||||||
|
def mark_failed(self, job_id: UUID, *, error: str) -> None:
|
||||||
|
row = self._get_or_raise(job_id)
|
||||||
|
row.status = DocumentStatus.FAILED.value
|
||||||
|
row.error = error[:2048]
|
||||||
|
row.updated_at = _utcnow()
|
||||||
|
|
||||||
|
# ---------- reads ----------
|
||||||
|
|
||||||
|
def get(self, job_id: UUID) -> JobRow | None:
|
||||||
|
stmt = select(JobRow).where(JobRow.job_id == job_id)
|
||||||
|
return self.session.scalar(stmt)
|
||||||
|
|
||||||
|
def get_or_raise(self, job_id: UUID) -> JobRow:
|
||||||
|
return self._get_or_raise(job_id)
|
||||||
|
|
||||||
|
def _get_or_raise(self, job_id: UUID) -> JobRow:
|
||||||
|
row = self.get(job_id)
|
||||||
|
if row is None:
|
||||||
|
raise JobNotFoundError(f"Job not found: {job_id}")
|
||||||
|
return row
|
||||||
@@ -6,15 +6,29 @@ from fastapi import FastAPI
|
|||||||
|
|
||||||
from ocr_sprint import __version__
|
from ocr_sprint import __version__
|
||||||
from ocr_sprint.api.errors import register_error_handlers
|
from ocr_sprint.api.errors import register_error_handlers
|
||||||
|
from ocr_sprint.api.metrics import MetricsMiddleware, metrics_endpoint
|
||||||
from ocr_sprint.api.routes import documents, health
|
from ocr_sprint.api.routes import documents, health
|
||||||
from ocr_sprint.config import get_settings
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.db import models as _models # noqa: F401 (register ORM tables)
|
||||||
|
from ocr_sprint.db.base import Base, get_engine
|
||||||
from ocr_sprint.utils.logging import configure_logging
|
from ocr_sprint.utils.logging import configure_logging
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_schema() -> None:
|
||||||
|
"""Create tables if they don't exist.
|
||||||
|
|
||||||
|
Production deploys should run Alembic migrations explicitly; this is a
|
||||||
|
convenience for local dev / tests so the API works without a manual
|
||||||
|
`alembic upgrade head` step.
|
||||||
|
"""
|
||||||
|
Base.metadata.create_all(bind=get_engine())
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
"""Application factory — keeps top-level state easy to test."""
|
"""Application factory — keeps top-level state easy to test."""
|
||||||
settings = get_settings()
|
settings = get_settings()
|
||||||
configure_logging(settings.app_log_level)
|
configure_logging(settings.app_log_level)
|
||||||
|
_ensure_schema()
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="OCR Sprint Service",
|
title="OCR Sprint Service",
|
||||||
@@ -26,8 +40,10 @@ def create_app() -> FastAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
register_error_handlers(app)
|
register_error_handlers(app)
|
||||||
|
app.add_middleware(MetricsMiddleware)
|
||||||
app.include_router(health.router, prefix="/api/v1")
|
app.include_router(health.router, prefix="/api/v1")
|
||||||
app.include_router(documents.router, prefix="/api/v1")
|
app.include_router(documents.router, prefix="/api/v1")
|
||||||
|
app.add_api_route("/metrics", metrics_endpoint, methods=["GET"], include_in_schema=False)
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
5
src/ocr_sprint/storage/__init__.py
Normal file
5
src/ocr_sprint/storage/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Pluggable blob storage. Local-fs only for the MVP."""
|
||||||
|
|
||||||
|
from ocr_sprint.storage.blob import BlobStorage, LocalFsBlobStorage, get_blob_storage
|
||||||
|
|
||||||
|
__all__ = ["BlobStorage", "LocalFsBlobStorage", "get_blob_storage"]
|
||||||
146
src/ocr_sprint/storage/blob.py
Normal file
146
src/ocr_sprint/storage/blob.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Blob storage abstraction.
|
||||||
|
|
||||||
|
The MVP only ships a local-filesystem backend. The `BlobStorage` Protocol is
|
||||||
|
deliberately small (put / get / exists / delete) so that an S3- or MinIO-
|
||||||
|
backed implementation can be dropped in later without touching API code.
|
||||||
|
|
||||||
|
Layout on disk:
|
||||||
|
|
||||||
|
{blob_storage_dir}/
|
||||||
|
2026/04/25/
|
||||||
|
<uuid4>.<ext>
|
||||||
|
|
||||||
|
The date hierarchy keeps the directory listing manageable when the service
|
||||||
|
processes thousands of documents per day, and makes manual rsync-based
|
||||||
|
backup straightforward.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import BinaryIO, Protocol
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
|
||||||
|
_logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Map of upload extensions we'll honor when persisting blobs. Anything else
|
||||||
|
# falls back to `.bin` and the OCR pipeline's magic-byte sniffing handles
|
||||||
|
# the actual content kind.
|
||||||
|
_KNOWN_EXTS = {".pdf", ".png", ".jpg", ".jpeg", ".tif", ".tiff", ".webp"}
|
||||||
|
|
||||||
|
|
||||||
|
class BlobStorage(Protocol):
|
||||||
|
"""Minimal interface a blob backend must satisfy."""
|
||||||
|
|
||||||
|
def put(self, content: bytes, original_filename: str | None = None) -> str:
|
||||||
|
"""Persist `content` and return an opaque key the caller can use later."""
|
||||||
|
|
||||||
|
def get(self, key: str) -> bytes:
|
||||||
|
"""Return the raw bytes for `key`. Raises FileNotFoundError on miss."""
|
||||||
|
|
||||||
|
def open(self, key: str) -> BinaryIO:
|
||||||
|
"""Return a binary file-like object for streaming reads."""
|
||||||
|
|
||||||
|
def exists(self, key: str) -> bool:
|
||||||
|
"""True if `key` is currently stored."""
|
||||||
|
|
||||||
|
def delete(self, key: str) -> None:
|
||||||
|
"""Remove a blob. No-op if it doesn't exist."""
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFsBlobStorage:
|
||||||
|
"""Filesystem-backed implementation rooted at `base_dir`."""
|
||||||
|
|
||||||
|
def __init__(self, base_dir: Path) -> None:
|
||||||
|
# Resolve once so every subsequent path comparison (escape check,
|
||||||
|
# empty-dir cleanup) is apples-to-apples — ``Path.parents`` of a
|
||||||
|
# resolved key would otherwise never equal a relative ``base_dir``.
|
||||||
|
base_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.base_dir = base_dir.resolve()
|
||||||
|
|
||||||
|
# ---------- helpers ----------
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_ext(original_filename: str | None) -> str:
|
||||||
|
if not original_filename:
|
||||||
|
return ".bin"
|
||||||
|
suffix = Path(original_filename).suffix.lower()
|
||||||
|
return suffix if suffix in _KNOWN_EXTS else ".bin"
|
||||||
|
|
||||||
|
def _resolve(self, key: str) -> Path:
|
||||||
|
# Defensive: keys come from the DB but we still reject paths that try
|
||||||
|
# to escape the blob root. ``Path.is_relative_to`` does proper path
|
||||||
|
# containment — string ``startswith`` would let ``/app/blobs_evil``
|
||||||
|
# slip past when the root is ``/app/blobs``.
|
||||||
|
candidate = (self.base_dir / key).resolve()
|
||||||
|
if not candidate.is_relative_to(self.base_dir):
|
||||||
|
raise ValueError(f"Blob key escapes storage root: {key!r}")
|
||||||
|
return candidate
|
||||||
|
|
||||||
|
# ---------- BlobStorage protocol ----------
|
||||||
|
|
||||||
|
def put(self, content: bytes, original_filename: str | None = None) -> str:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
date_dir = Path(f"{now:%Y/%m/%d}")
|
||||||
|
ext = self._safe_ext(original_filename)
|
||||||
|
key = str(date_dir / f"{uuid4().hex}{ext}")
|
||||||
|
target = self._resolve(key)
|
||||||
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Write to a temp file in the same directory then rename. This avoids
|
||||||
|
# a half-written blob being read by a concurrent worker.
|
||||||
|
tmp = target.with_suffix(target.suffix + ".tmp")
|
||||||
|
tmp.write_bytes(content)
|
||||||
|
tmp.rename(target)
|
||||||
|
_logger.info("blob.put", key=key, size=len(content))
|
||||||
|
return key
|
||||||
|
|
||||||
|
def get(self, key: str) -> bytes:
|
||||||
|
path = self._resolve(key)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Blob not found: {key}")
|
||||||
|
return path.read_bytes()
|
||||||
|
|
||||||
|
def open(self, key: str) -> BinaryIO:
|
||||||
|
path = self._resolve(key)
|
||||||
|
if not path.exists():
|
||||||
|
raise FileNotFoundError(f"Blob not found: {key}")
|
||||||
|
return path.open("rb")
|
||||||
|
|
||||||
|
def exists(self, key: str) -> bool:
|
||||||
|
try:
|
||||||
|
return self._resolve(key).exists()
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
def delete(self, key: str) -> None:
|
||||||
|
try:
|
||||||
|
path = self._resolve(key)
|
||||||
|
except ValueError:
|
||||||
|
return
|
||||||
|
if path.exists():
|
||||||
|
path.unlink()
|
||||||
|
_logger.info("blob.delete", key=key)
|
||||||
|
# Best-effort cleanup of empty date dirs so we don't accumulate
|
||||||
|
# 365 directories per year forever. ``self.base_dir`` is already
|
||||||
|
# resolved (see __init__), so it can be compared against
|
||||||
|
# ``path.parents`` directly.
|
||||||
|
for parent in path.parents:
|
||||||
|
if parent == self.base_dir or self.base_dir not in parent.parents:
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
parent.rmdir()
|
||||||
|
except OSError:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def get_blob_storage() -> BlobStorage:
|
||||||
|
"""Build the configured blob backend. Single-process cache lives in `Settings`."""
|
||||||
|
s = get_settings()
|
||||||
|
return LocalFsBlobStorage(s.blob_storage_dir)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["BlobStorage", "LocalFsBlobStorage", "get_blob_storage"]
|
||||||
6
src/ocr_sprint/worker/__init__.py
Normal file
6
src/ocr_sprint/worker/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
"""Celery worker (Phase 4) — async OCR pipeline."""
|
||||||
|
|
||||||
|
from ocr_sprint.worker.celery_app import celery_app
|
||||||
|
from ocr_sprint.worker.tasks import process_document_task
|
||||||
|
|
||||||
|
__all__ = ["celery_app", "process_document_task"]
|
||||||
49
src/ocr_sprint/worker/celery_app.py
Normal file
49
src/ocr_sprint/worker/celery_app.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Celery application factory.
|
||||||
|
|
||||||
|
The broker and result backend are both Redis. We deliberately don't use
|
||||||
|
Postgres as the result backend — Celery's pg result-backend creates noisy
|
||||||
|
schema artefacts and we already store the structured result on the `jobs`
|
||||||
|
table.
|
||||||
|
|
||||||
|
Tasks register themselves by calling `celery_app.task` in `tasks.py`. Eager
|
||||||
|
mode (used in tests) is enabled by setting `CELERY_TASK_ALWAYS_EAGER=true`
|
||||||
|
in the environment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from celery import Celery
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
|
||||||
|
def build_celery_app() -> Celery:
|
||||||
|
settings = get_settings()
|
||||||
|
app = Celery(
|
||||||
|
"ocr_sprint",
|
||||||
|
broker=settings.redis_url,
|
||||||
|
backend=settings.redis_url,
|
||||||
|
include=["ocr_sprint.worker.tasks"],
|
||||||
|
)
|
||||||
|
app.conf.update(
|
||||||
|
task_default_queue=settings.celery_task_default_queue,
|
||||||
|
task_serializer="json",
|
||||||
|
result_serializer="json",
|
||||||
|
accept_content=["json"],
|
||||||
|
timezone="UTC",
|
||||||
|
enable_utc=True,
|
||||||
|
task_acks_late=True,
|
||||||
|
task_reject_on_worker_lost=True,
|
||||||
|
worker_prefetch_multiplier=1, # OCR is CPU-bound; one task at a time.
|
||||||
|
broker_connection_retry_on_startup=True,
|
||||||
|
result_expires=24 * 3600, # results live in DB; redis is just a cache
|
||||||
|
)
|
||||||
|
if os.getenv("CELERY_TASK_ALWAYS_EAGER", "").lower() in {"1", "true", "yes"}:
|
||||||
|
app.conf.task_always_eager = True
|
||||||
|
app.conf.task_eager_propagates = True
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
celery_app = build_celery_app()
|
||||||
84
src/ocr_sprint/worker/tasks.py
Normal file
84
src/ocr_sprint/worker/tasks.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
"""Celery tasks.
|
||||||
|
|
||||||
|
`process_document_task` is the single async entrypoint: the API enqueues
|
||||||
|
one task per upload, the worker pulls the blob, runs the orchestrator, and
|
||||||
|
writes the result back to the `jobs` table.
|
||||||
|
|
||||||
|
Tasks must be idempotent at the boundary: if the worker crashes mid-OCR,
|
||||||
|
the next retry should pick up the same blob and produce the same result.
|
||||||
|
We therefore re-fetch the row by id on every transition rather than
|
||||||
|
threading state through closures.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from celery.exceptions import Reject
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import session_scope
|
||||||
|
from ocr_sprint.db.repositories import JobRepository
|
||||||
|
from ocr_sprint.pipeline.orchestrator import run_pipeline
|
||||||
|
from ocr_sprint.storage.blob import get_blob_storage
|
||||||
|
from ocr_sprint.utils.logging import get_logger
|
||||||
|
from ocr_sprint.worker.celery_app import celery_app
|
||||||
|
|
||||||
|
_logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@celery_app.task(name="ocr_sprint.process_document", bind=True, max_retries=0) # type: ignore[untyped-decorator]
|
||||||
|
def process_document_task(self: object, job_id_str: str) -> str:
|
||||||
|
"""Run the OCR pipeline for `job_id` and persist the structured result.
|
||||||
|
|
||||||
|
Returns the final job status as a string so callers can wait on
|
||||||
|
`AsyncResult.get()` if they want to (mainly tests).
|
||||||
|
"""
|
||||||
|
job_id = UUID(job_id_str)
|
||||||
|
log = _logger.bind(job_id=job_id_str)
|
||||||
|
storage = get_blob_storage()
|
||||||
|
|
||||||
|
with session_scope() as session:
|
||||||
|
repo = JobRepository(session)
|
||||||
|
row = repo.get(job_id)
|
||||||
|
if row is None:
|
||||||
|
raise Reject(f"Job {job_id_str} not found", requeue=False)
|
||||||
|
repo.mark_processing(job_id)
|
||||||
|
blob_key = row.blob_key
|
||||||
|
|
||||||
|
if not blob_key:
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_failed(job_id, error="missing blob_key")
|
||||||
|
return "failed"
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = storage.get(blob_key)
|
||||||
|
except FileNotFoundError as exc:
|
||||||
|
log.error("worker.blob_missing", error=str(exc))
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_failed(job_id, error=f"blob missing: {exc}")
|
||||||
|
return "failed"
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = run_pipeline(content)
|
||||||
|
except Exception as exc:
|
||||||
|
log.exception("worker.pipeline_error")
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_failed(job_id, error=str(exc))
|
||||||
|
return "failed"
|
||||||
|
|
||||||
|
flags = [f.value for f in output.result.review_flags]
|
||||||
|
log.info(
|
||||||
|
"worker.completed",
|
||||||
|
status=output.status.value,
|
||||||
|
confidence=round(output.confidence, 3),
|
||||||
|
flags=flags,
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_completed(
|
||||||
|
job_id,
|
||||||
|
status=output.status,
|
||||||
|
confidence=output.confidence,
|
||||||
|
result=output.result.model_dump(mode="json"),
|
||||||
|
review_flags=flags,
|
||||||
|
)
|
||||||
|
return output.status.value
|
||||||
@@ -2,10 +2,54 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _isolated_runtime(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
|
||||||
|
"""Per-test sqlite + blob storage so tests don't share state.
|
||||||
|
|
||||||
|
Setting these env vars before ``Settings`` is first read in the test gives
|
||||||
|
each test its own DB file and blob root. We also clear the lru_cache on
|
||||||
|
`get_settings`, the engine, and the sessionmaker so the fresh paths take
|
||||||
|
effect even if a previous test already loaded settings.
|
||||||
|
"""
|
||||||
|
db_path = tmp_path / "test.sqlite"
|
||||||
|
blob_dir = tmp_path / "blobs"
|
||||||
|
monkeypatch.setenv("DATABASE_URL", f"sqlite:///{db_path}")
|
||||||
|
monkeypatch.setenv("BLOB_STORAGE_DIR", str(blob_dir))
|
||||||
|
monkeypatch.setenv("STORAGE_LOCAL_DIR", str(tmp_path / "storage"))
|
||||||
|
monkeypatch.setenv("API_KEYS", "")
|
||||||
|
# The async API path is exercised by the test suite, so default it on
|
||||||
|
# here. Production keeps ``QUEUE_ENABLED=false`` so the route falls back
|
||||||
|
# to the inline pipeline when no Redis is configured.
|
||||||
|
monkeypatch.setenv("QUEUE_ENABLED", "true")
|
||||||
|
# Force Celery to run tasks inline so we don't need a broker.
|
||||||
|
monkeypatch.setenv("CELERY_TASK_ALWAYS_EAGER", "true")
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.db.base import reset_engine_cache
|
||||||
|
from ocr_sprint.worker.celery_app import celery_app
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
reset_engine_cache()
|
||||||
|
# `celery_app` is built once at import-time, so flip the eager flag on the
|
||||||
|
# already-instantiated instance for this test.
|
||||||
|
celery_app.conf.task_always_eager = True
|
||||||
|
celery_app.conf.task_eager_propagates = True
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
reset_engine_cache()
|
||||||
|
os.environ.pop("CELERY_TASK_ALWAYS_EAGER", None)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def blank_bgr_image() -> np.ndarray:
|
def blank_bgr_image() -> np.ndarray:
|
||||||
"""A 600x800 white BGR image (uint8) — useful for preprocessing smoke tests."""
|
"""A 600x800 white BGR image (uint8) — useful for preprocessing smoke tests."""
|
||||||
|
|||||||
@@ -23,35 +23,9 @@ def client() -> TestClient:
|
|||||||
return TestClient(create_app())
|
return TestClient(create_app())
|
||||||
|
|
||||||
|
|
||||||
def test_health_endpoint(client: TestClient) -> None:
|
@pytest.fixture
|
||||||
response = client.get("/api/v1/health")
|
def fake_pipeline(monkeypatch: pytest.MonkeyPatch) -> PipelineOutput:
|
||||||
assert response.status_code == 200
|
"""Patch run_pipeline everywhere it's referenced."""
|
||||||
assert response.json()["status"] == "ok"
|
|
||||||
|
|
||||||
|
|
||||||
def test_documents_rejects_empty_upload(client: TestClient) -> None:
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/documents",
|
|
||||||
files={"file": ("empty.pdf", b"", "application/pdf")},
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
|
|
||||||
def test_documents_rejects_unknown_format(
|
|
||||||
client: TestClient,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
response = client.post(
|
|
||||||
"/api/v1/documents",
|
|
||||||
files={"file": ("x.bin", b"random garbage bytes here", "application/octet-stream")},
|
|
||||||
)
|
|
||||||
assert response.status_code == 400
|
|
||||||
|
|
||||||
|
|
||||||
def test_documents_returns_pipeline_output(
|
|
||||||
client: TestClient,
|
|
||||||
monkeypatch: pytest.MonkeyPatch,
|
|
||||||
) -> None:
|
|
||||||
fake_result = ExtractionResult(
|
fake_result = ExtractionResult(
|
||||||
header=HeaderFields(
|
header=HeaderFields(
|
||||||
nomor_sprint="Sprin/1/I/2025",
|
nomor_sprint="Sprin/1/I/2025",
|
||||||
@@ -70,14 +44,36 @@ def test_documents_returns_pipeline_output(
|
|||||||
def _fake_run(_content: bytes) -> PipelineOutput:
|
def _fake_run(_content: bytes) -> PipelineOutput:
|
||||||
return fake_output
|
return fake_output
|
||||||
|
|
||||||
# Patch the symbol *imported into* the routes module.
|
|
||||||
monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
|
monkeypatch.setattr(orch_module, "run_pipeline", _fake_run)
|
||||||
from ocr_sprint.api.routes import documents as docs_module
|
from ocr_sprint.api.routes import documents as docs_module
|
||||||
|
|
||||||
monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
|
monkeypatch.setattr(docs_module, "run_pipeline", _fake_run)
|
||||||
|
from ocr_sprint.worker import tasks as tasks_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(tasks_module, "run_pipeline", _fake_run)
|
||||||
|
return fake_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_endpoint(client: TestClient) -> None:
|
||||||
|
response = client.get("/api/v1/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["status"] == "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_rejects_empty_upload(client: TestClient) -> None:
|
||||||
response = client.post(
|
response = client.post(
|
||||||
"/api/v1/documents",
|
"/api/v1/documents",
|
||||||
|
files={"file": ("empty.pdf", b"", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_sync_returns_pipeline_output(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/documents?sync=true",
|
||||||
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
)
|
)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -85,3 +81,169 @@ def test_documents_returns_pipeline_output(
|
|||||||
assert body["status"] == "completed"
|
assert body["status"] == "completed"
|
||||||
assert body["confidence"] == 0.97
|
assert body["confidence"] == 0.97
|
||||||
assert body["data"]["header"]["nomor_sprint"] == "Sprin/1/I/2025"
|
assert body["data"]["header"]["nomor_sprint"] == "Sprin/1/I/2025"
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_async_returns_202_then_polls_to_completion(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
"""Default flow: POST returns 202, GET returns the eventual completion.
|
||||||
|
|
||||||
|
With CELERY_TASK_ALWAYS_EAGER set in conftest, the worker runs inline,
|
||||||
|
so by the time POST returns the task has already finished and GET sees
|
||||||
|
a `completed` row.
|
||||||
|
"""
|
||||||
|
post = client.post(
|
||||||
|
"/api/v1/documents",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert post.status_code == 202
|
||||||
|
job_id = post.json()["job_id"]
|
||||||
|
|
||||||
|
get = client.get(f"/api/v1/documents/{job_id}")
|
||||||
|
assert get.status_code == 200
|
||||||
|
body = get.json()
|
||||||
|
assert body["status"] == "completed"
|
||||||
|
assert body["confidence"] == 0.97
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_defaults_to_sync_when_queue_disabled(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Regression: with ``QUEUE_ENABLED=false`` the route must NOT enqueue,
|
||||||
|
otherwise a default install with no Redis returns 500.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("QUEUE_ENABLED", "false")
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
|
||||||
|
get_settings.cache_clear()
|
||||||
|
|
||||||
|
# Pretend the broker is unreachable; if the route still enqueues, the
|
||||||
|
# call would blow up here.
|
||||||
|
def _no_broker(_self: object, *_args: object, **_kwargs: object) -> None:
|
||||||
|
raise AssertionError("queue path taken when queue is disabled")
|
||||||
|
|
||||||
|
from ocr_sprint.worker import tasks as task_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(task_module.process_document_task, "delay", _no_broker)
|
||||||
|
|
||||||
|
post = client.post(
|
||||||
|
"/api/v1/documents",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert post.status_code == 200, post.text
|
||||||
|
body = post.json()
|
||||||
|
assert body["status"] == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_get_unknown_id_returns_404(client: TestClient) -> None:
|
||||||
|
response = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000")
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_async_marks_failed_on_pipeline_error(
|
||||||
|
client: TestClient,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
def _explode(_content: bytes) -> PipelineOutput:
|
||||||
|
raise RuntimeError("boom")
|
||||||
|
|
||||||
|
from ocr_sprint.worker import tasks as tasks_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(tasks_module, "run_pipeline", _explode)
|
||||||
|
|
||||||
|
post = client.post(
|
||||||
|
"/api/v1/documents",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert post.status_code == 202
|
||||||
|
job_id = post.json()["job_id"]
|
||||||
|
|
||||||
|
get = client.get(f"/api/v1/documents/{job_id}")
|
||||||
|
body = get.json()
|
||||||
|
assert body["status"] == "failed"
|
||||||
|
assert "boom" in (body.get("error") or "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_documents_sync_persists_failed_row_when_pipeline_raises(
|
||||||
|
client: TestClient,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""Regression: an exception in the sync pipeline must NOT roll back the
|
||||||
|
pending row + ``mark_failed`` write. Otherwise the blob on disk has no
|
||||||
|
DB record pointing at it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _explode(_content: bytes) -> PipelineOutput:
|
||||||
|
raise RuntimeError("kapow")
|
||||||
|
|
||||||
|
from ocr_sprint.api.routes import documents as docs_module
|
||||||
|
|
||||||
|
monkeypatch.setattr(docs_module, "run_pipeline", _explode)
|
||||||
|
|
||||||
|
# ``raise_server_exceptions=False`` lets the test see the 500 response
|
||||||
|
# rather than re-raising the underlying RuntimeError from the route.
|
||||||
|
silent = TestClient(client.app, raise_server_exceptions=False)
|
||||||
|
post = silent.post(
|
||||||
|
"/api/v1/documents?sync=true",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert post.status_code == 500
|
||||||
|
|
||||||
|
# The row must still be visible to GET, with status=failed.
|
||||||
|
from ocr_sprint.db.base import session_scope
|
||||||
|
from ocr_sprint.db.repositories import JobRepository
|
||||||
|
|
||||||
|
with session_scope() as session:
|
||||||
|
# Find the most recent row.
|
||||||
|
from ocr_sprint.db.models import JobRow
|
||||||
|
|
||||||
|
row = session.query(JobRow).order_by(JobRow.created_at.desc()).first()
|
||||||
|
assert row is not None, "create() must persist even when pipeline blows up"
|
||||||
|
assert row.status == "failed"
|
||||||
|
assert "kapow" in (row.error or "")
|
||||||
|
assert row.blob_key # blob is referenced — not orphaned
|
||||||
|
|
||||||
|
# GET must surface the failure too (this is the client-visible contract).
|
||||||
|
get = client.get(f"/api/v1/documents/{row.job_id}")
|
||||||
|
assert get.status_code == 200
|
||||||
|
assert get.json()["status"] == "failed"
|
||||||
|
assert JobRepository # silence import-only warning
|
||||||
|
|
||||||
|
|
||||||
|
def test_metrics_endpoint_exposes_request_counter(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
client.post(
|
||||||
|
"/api/v1/documents?sync=true",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
metrics = client.get("/metrics")
|
||||||
|
assert metrics.status_code == 200
|
||||||
|
body = metrics.text
|
||||||
|
assert "http_requests_total" in body
|
||||||
|
assert "ocr_jobs_total" in body
|
||||||
|
|
||||||
|
|
||||||
|
def test_metrics_jobs_total_reflects_worker_writes(
|
||||||
|
client: TestClient,
|
||||||
|
fake_pipeline: PipelineOutput,
|
||||||
|
) -> None:
|
||||||
|
"""Regression: when the worker (eager mode here) marks a job complete,
|
||||||
|
/metrics must reflect that — the previous Counter-based implementation
|
||||||
|
would have stayed at zero because the worker's increments don't reach
|
||||||
|
the API process's in-memory registry.
|
||||||
|
"""
|
||||||
|
post = client.post(
|
||||||
|
"/api/v1/documents",
|
||||||
|
files={"file": ("x.pdf", b"%PDF-1.4\n%fake", "application/pdf")},
|
||||||
|
)
|
||||||
|
assert post.status_code == 202
|
||||||
|
|
||||||
|
body = client.get("/metrics").text
|
||||||
|
# ``ocr_jobs_total{status="completed"} 1.0`` — exact match to make sure
|
||||||
|
# the gauge-style metric is being populated from the DB.
|
||||||
|
assert 'ocr_jobs_total{status="completed"} 1.0' in body
|
||||||
|
|||||||
43
tests/unit/test_auth.py
Normal file
43
tests/unit/test_auth.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
"""API key authentication."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from ocr_sprint.config import get_settings
|
||||||
|
from ocr_sprint.main import create_app
|
||||||
|
|
||||||
|
|
||||||
|
def _client_with_keys(monkeypatch: pytest.MonkeyPatch, keys: str) -> TestClient:
|
||||||
|
monkeypatch.setenv("API_KEYS", keys)
|
||||||
|
get_settings.cache_clear()
|
||||||
|
return TestClient(create_app())
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_when_keys_empty(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
client = _client_with_keys(monkeypatch, "")
|
||||||
|
response = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000")
|
||||||
|
# 404 not 401: auth disabled, the endpoint just doesn't find the row.
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_rejects_missing_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
client = _client_with_keys(monkeypatch, "secret-1,secret-2")
|
||||||
|
response = client.get("/api/v1/documents/00000000-0000-0000-0000-000000000000")
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_accepts_valid_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
client = _client_with_keys(monkeypatch, "secret-1,secret-2")
|
||||||
|
response = client.get(
|
||||||
|
"/api/v1/documents/00000000-0000-0000-0000-000000000000",
|
||||||
|
headers={"X-API-Key": "secret-2"},
|
||||||
|
)
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_is_unprotected(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
client = _client_with_keys(monkeypatch, "secret-1")
|
||||||
|
response = client.get("/api/v1/health")
|
||||||
|
assert response.status_code == 200
|
||||||
85
tests/unit/test_blob_storage.py
Normal file
85
tests/unit/test_blob_storage.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Local-filesystem blob storage."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ocr_sprint.storage.blob import LocalFsBlobStorage
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def storage(tmp_path: Path) -> LocalFsBlobStorage:
|
||||||
|
return LocalFsBlobStorage(tmp_path / "blobs")
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_returns_dated_key(storage: LocalFsBlobStorage) -> None:
|
||||||
|
key = storage.put(b"hello", original_filename="surat.pdf")
|
||||||
|
# Layout is YYYY/MM/DD/<uuid>.pdf
|
||||||
|
parts = key.split("/")
|
||||||
|
assert len(parts) == 4
|
||||||
|
assert parts[3].endswith(".pdf")
|
||||||
|
assert storage.exists(key)
|
||||||
|
assert storage.get(key) == b"hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_unknown_extension_falls_back_to_bin(storage: LocalFsBlobStorage) -> None:
|
||||||
|
key = storage.put(b"x", original_filename="weird.xyz")
|
||||||
|
assert key.endswith(".bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_strips_directory_traversal(storage: LocalFsBlobStorage) -> None:
|
||||||
|
# ext is taken via Path().suffix, not from the raw filename, so a name
|
||||||
|
# like "../../etc/passwd" is harmless — the only thing the caller can
|
||||||
|
# influence is the extension.
|
||||||
|
key = storage.put(b"y", original_filename="../../etc/passwd")
|
||||||
|
assert "etc" not in key
|
||||||
|
assert key.endswith(".bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_handles_missing_filename(storage: LocalFsBlobStorage) -> None:
|
||||||
|
key = storage.put(b"z", original_filename=None)
|
||||||
|
assert key.endswith(".bin")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_unknown_key_raises(storage: LocalFsBlobStorage) -> None:
|
||||||
|
with pytest.raises(FileNotFoundError):
|
||||||
|
storage.get("2026/01/01/bogus.pdf")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_is_idempotent(storage: LocalFsBlobStorage) -> None:
|
||||||
|
key = storage.put(b"q", original_filename="x.png")
|
||||||
|
storage.delete(key)
|
||||||
|
assert not storage.exists(key)
|
||||||
|
storage.delete(key) # second delete must not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_rejects_path_escape(storage: LocalFsBlobStorage) -> None:
|
||||||
|
with pytest.raises(ValueError, match="escapes storage root"):
|
||||||
|
storage._resolve("../../../etc/passwd")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_rejects_directory_prefix_collision(tmp_path: Path) -> None:
|
||||||
|
"""Regression: ``startswith`` would mis-accept sibling dirs whose names
|
||||||
|
happen to begin with the storage root's basename. ``is_relative_to``
|
||||||
|
handles this correctly.
|
||||||
|
"""
|
||||||
|
root = tmp_path / "blobs"
|
||||||
|
root.mkdir()
|
||||||
|
sibling = tmp_path / "blobs_evil"
|
||||||
|
sibling.mkdir()
|
||||||
|
storage = LocalFsBlobStorage(root)
|
||||||
|
with pytest.raises(ValueError, match="escapes storage root"):
|
||||||
|
storage._resolve("../blobs_evil/secret.txt")
|
||||||
|
|
||||||
|
|
||||||
|
def test_exists_returns_false_for_escaped_key(storage: LocalFsBlobStorage) -> None:
|
||||||
|
# exists() must not raise even for malicious keys.
|
||||||
|
assert storage.exists("../../etc/passwd") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_open_streams_content(storage: LocalFsBlobStorage) -> None:
|
||||||
|
key = storage.put(b"streamed", original_filename="x.png")
|
||||||
|
with storage.open(key) as fh:
|
||||||
|
assert fh.read() == b"streamed"
|
||||||
76
tests/unit/test_db_repository.py
Normal file
76
tests/unit/test_db_repository.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""SQLAlchemy repository tests against an in-memory sqlite db."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ocr_sprint.db.base import Base, get_engine, session_scope
|
||||||
|
from ocr_sprint.db.repositories import JobNotFoundError, JobRepository
|
||||||
|
from ocr_sprint.schemas.document import DocumentStatus, SourceKind
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_ready() -> None:
|
||||||
|
Base.metadata.create_all(bind=get_engine())
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_then_fetch(db_ready: None) -> None:
|
||||||
|
jid = uuid4()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).create(
|
||||||
|
job_id=jid,
|
||||||
|
filename="x.pdf",
|
||||||
|
source_kind=SourceKind.PDF,
|
||||||
|
blob_key="2026/01/01/x.pdf",
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
row = JobRepository(session).get_or_raise(jid)
|
||||||
|
assert row.status == DocumentStatus.PENDING.value
|
||||||
|
assert row.source_kind == SourceKind.PDF.value
|
||||||
|
assert row.blob_key == "2026/01/01/x.pdf"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifecycle_transitions(db_ready: None) -> None:
|
||||||
|
jid = uuid4()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).create(
|
||||||
|
job_id=jid,
|
||||||
|
filename="x.pdf",
|
||||||
|
source_kind=SourceKind.PDF,
|
||||||
|
blob_key="k",
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_processing(jid)
|
||||||
|
with session_scope() as session:
|
||||||
|
repo = JobRepository(session)
|
||||||
|
repo.mark_completed(
|
||||||
|
jid,
|
||||||
|
status=DocumentStatus.NEEDS_REVIEW,
|
||||||
|
confidence=0.88,
|
||||||
|
result={"header": {"nomor_sprint": "Sprin/1/2025"}},
|
||||||
|
review_flags=["low_ocr_confidence"],
|
||||||
|
)
|
||||||
|
row = repo.get_or_raise(jid)
|
||||||
|
assert row.status == DocumentStatus.NEEDS_REVIEW.value
|
||||||
|
assert row.confidence == 0.88
|
||||||
|
assert row.result == {"header": {"nomor_sprint": "Sprin/1/2025"}}
|
||||||
|
assert row.review_flags == ["low_ocr_confidence"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_mark_failed_truncates_long_error(db_ready: None) -> None:
|
||||||
|
jid = uuid4()
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).create(
|
||||||
|
job_id=jid, filename="x", source_kind=SourceKind.UNKNOWN, blob_key="k"
|
||||||
|
)
|
||||||
|
with session_scope() as session:
|
||||||
|
JobRepository(session).mark_failed(jid, error="x" * 5000)
|
||||||
|
row = JobRepository(session).get_or_raise(jid)
|
||||||
|
assert len(row.error or "") == 2048
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_job_raises(db_ready: None) -> None:
|
||||||
|
with session_scope() as session, pytest.raises(JobNotFoundError):
|
||||||
|
JobRepository(session).get_or_raise(uuid4())
|
||||||
Reference in New Issue
Block a user