"""Documents API.

Phase 1 shipped a single synchronous endpoint. Phase 4 adds an async flow on
top, and Phase 6 adds human-in-the-loop (HITL) review endpoints:

* `POST /documents` — saves the upload to blob storage and creates a
  `pending` job row. When the `sync` query parameter is omitted it defaults
  to ``not queue_enabled``: with the queue enabled the job is enqueued as a
  Celery task and the response is `202` carrying the job id; with the queue
  disabled the pipeline runs inline.
* `POST /documents?sync=true` — runs the pipeline inline (the original
  Phase 1 behaviour). Useful for tests and small-volume single-tenant
  deploys without a Celery worker.
* `GET /documents/{job_id}` — returns the current job state. Async clients
  poll this until `status` is in a terminal state
  (completed / needs_review / failed).
* `PATCH /documents/{job_id}` — applies field-level corrections and returns
  the updated document.
* `GET /documents/{job_id}/history` — lists the recorded corrections.
* `POST /documents/{job_id}/approve` — locks a job's final version.
"""

from __future__ import annotations

import time
from typing import Annotated
from uuid import UUID, uuid4

from fastapi import (
    APIRouter,
    Depends,
    File,
    Header,
    HTTPException,
    Query,
    Response,
    UploadFile,
    status,
)
from sqlalchemy.orm import Session

from ocr_sprint.api.deps.auth import require_api_key
from ocr_sprint.api.deps.db import get_session
from ocr_sprint.api.errors import UnsupportedDocumentError
from ocr_sprint.api.metrics import JOB_PROCESSING_SECONDS
from ocr_sprint.config import get_settings
from ocr_sprint.db.base import session_scope
from ocr_sprint.db.repositories import (
    InvalidFieldPathError,
    JobAlreadyApprovedError,
    JobNotCompletedError,
    JobNotFoundError,
    JobRepository,
)
from ocr_sprint.pipeline.ingest import detect_source_kind
from ocr_sprint.pipeline.orchestrator import run_pipeline
from ocr_sprint.schemas.document import DocumentResponse, DocumentStatus
from ocr_sprint.schemas.extraction import ExtractionResult
from ocr_sprint.schemas.review import (
    ApprovalResponse,
    CorrectionEventResponse,
    CorrectionRequest,
)
from ocr_sprint.storage.blob import get_blob_storage
from ocr_sprint.utils.logging import get_logger

# Every route on this router requires a valid API key.
router = APIRouter(
    prefix="/documents",
    tags=["documents"],
    dependencies=[Depends(require_api_key)],
)
_logger = get_logger(__name__)


# ---------- helpers ----------


def _enforce_size(content: bytes) -> None:
    """Reject empty or oversized uploads.

    Args:
        content: Raw bytes of the uploaded file.

    Raises:
        UnsupportedDocumentError: If ``content`` is empty, or longer than the
            ``blob_max_upload_mb`` setting (interpreted as MiB).
    """
    s = get_settings()
    if not content:
        raise UnsupportedDocumentError("Uploaded file is empty.")
    max_bytes = s.blob_max_upload_mb * 1024 * 1024
    if len(content) > max_bytes:
        raise UnsupportedDocumentError(f"Uploaded file exceeds {s.blob_max_upload_mb} MB limit.")


def _row_to_response(row: object) -> DocumentResponse:
    """Convert a ``JobRow`` into the public ``DocumentResponse`` schema.

    Args:
        row: The ORM row to serialise. Typed as ``object`` because ``JobRow``
            is imported lazily inside the function.

    Returns:
        The response model; ``data`` is ``None`` when the row has no result.

    Raises:
        TypeError: If ``row`` is not a ``JobRow``.
    """
    # Local import to avoid a circular import at module load time.
    from ocr_sprint.db.models import JobRow

    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would let a wrong type fail later and less clearly.
    if not isinstance(row, JobRow):
        raise TypeError(f"expected JobRow, got {type(row).__name__}")

    status_enum = DocumentStatus(row.status)
    result_obj: ExtractionResult | None = None
    if row.result is not None:
        result_obj = ExtractionResult.model_validate(row.result)
    return DocumentResponse(
        job_id=row.job_id,
        status=status_enum,
        confidence=row.confidence,
        data=result_obj,
        review_flags=list(row.review_flags or []),
        error=row.error,
        approved=bool(row.approved),
        reviewed_by=row.reviewed_by,
        reviewed_at=row.reviewed_at,
    )


# ---------- POST ----------


@router.post("", response_model=DocumentResponse)
async def create_document(
    file: Annotated[UploadFile, File(...)],
    session: Annotated[Session, Depends(get_session)],
    response: Response,
    sync: Annotated[
        bool | None,
        Query(description="Run pipeline inline (skip queue). Defaults to !queue_enabled."),
    ] = None,
) -> DocumentResponse:
    """Accept an upload, persist it, and either process it inline or enqueue it.

    Args:
        file: The uploaded document.
        session: Request-scoped DB session used to create the job row.
        response: Used to set the ``202`` status code on the async path.
        sync: ``True`` runs the pipeline inline, ``False`` enqueues a task,
            ``None`` (default) resolves to ``not queue_enabled``.

    Returns:
        The job as a ``DocumentResponse``. The status code is the default 200
        on the inline path and 202 on the async path.

    Raises:
        UnsupportedDocumentError: If the upload is empty or too large, or if
            the inline pipeline raises ``ValueError``.
    """
    # When the queue is disabled (default for local dev), running the async
    # path would try to dial Redis and fail with a 500. Auto-fall-back to the
    # inline pipeline unless the caller explicitly asked for async.
    if sync is None:
        sync = not get_settings().queue_enabled

    job_id = uuid4()
    filename = file.filename or ""
    log = _logger.bind(job_id=str(job_id), filename=filename)

    content = await file.read()
    _enforce_size(content)

    # Classify the payload *before* writing it to blob storage: if detection
    # raises, nothing has been persisted yet, so there is no blob to orphan.
    source_kind = detect_source_kind(content)

    storage = get_blob_storage()
    blob_key = storage.put(content, original_filename=file.filename)

    JobRepository(session).create(
        job_id=job_id,
        filename=filename,
        source_kind=source_kind,
        blob_key=blob_key,
    )
    # Commit the `pending` row immediately so it is observable regardless
    # of what happens next. Both code paths below open their own session
    # for state transitions; that way an exception in `_run_inline` cannot
    # roll back the create() (which would orphan the blob on disk).
    session.commit()

    log.info("documents.received", size=len(content), blob_key=blob_key, sync=sync)

    if sync:
        # Status code stays at the default 200; the body's `status` field
        # tells the client whether the job needs review.
        return await _run_inline(job_id, content)

    # Async path — enqueue and return 202. The Celery worker will pick up
    # the row using its own session.
    # Imported lazily — presumably so the worker module is only loaded when
    # the async path is actually taken; confirm before hoisting.
    from ocr_sprint.worker.tasks import process_document_task

    # NOTE(review): if enqueueing raises, the row committed above stays
    # `pending`. Confirm whether it should be marked failed here.
    process_document_task.delay(str(job_id))

    # Re-read through a fresh session so the body reflects the committed row.
    with session_scope() as poll:
        row = JobRepository(poll).get_or_raise(job_id)
        body = _row_to_response(row)
    response.status_code = status.HTTP_202_ACCEPTED
    return body


async def _run_inline(job_id: UUID, content: bytes) -> DocumentResponse:
    """Synchronous pipeline execution.

    Each state transition opens its own short session so the request-scoped
    session's rollback-on-exception behaviour cannot wipe out the
    ``mark_failed`` write or strand the blob on disk.

    Args:
        job_id: Id of the already-committed job row.
        content: Raw document bytes handed to ``run_pipeline``.

    Returns:
        The job, re-read after ``mark_completed``, as a ``DocumentResponse``.

    Raises:
        UnsupportedDocumentError: If the pipeline raises ``ValueError``.
        Exception: Any other pipeline error is re-raised unchanged after the
            job has been marked failed.
    """
    with session_scope() as s:
        JobRepository(s).mark_processing(job_id)

    started = time.perf_counter()
    try:
        # NOTE(review): this is a direct (non-awaited) call inside an
        # ``async def``, so the event loop is blocked while it runs. Confirm
        # whether ``run_pipeline`` is safe to move to a worker thread.
        output = run_pipeline(content)
    except ValueError as exc:
        with session_scope() as s:
            JobRepository(s).mark_failed(job_id, error=str(exc))
        raise UnsupportedDocumentError(str(exc)) from exc
    except Exception as exc:
        # Record the failure, then let the original exception propagate.
        with session_scope() as s:
            JobRepository(s).mark_failed(job_id, error=str(exc))
        raise

    flags = [f.value for f in output.result.review_flags]
    # Only successful runs reach this point, so failures are not observed.
    JOB_PROCESSING_SECONDS.observe(time.perf_counter() - started)

    with session_scope() as s:
        repo = JobRepository(s)
        repo.mark_completed(
            job_id,
            status=output.status,
            confidence=output.confidence,
            result=output.result.model_dump(mode="json"),
            review_flags=flags,
        )
        row = repo.get_or_raise(job_id)
        return _row_to_response(row)


# ---------- GET ----------


@router.get(
    "/{job_id}",
    response_model=DocumentResponse,
)
async def get_document(
    job_id: UUID,
    session: Annotated[Session, Depends(get_session)],
) -> DocumentResponse:
    """Return the current state of a job.

    Raises:
        HTTPException: 404 if no job with ``job_id`` exists.
    """
    repo = JobRepository(session)
    try:
        row = repo.get_or_raise(job_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
    return _row_to_response(row)


# ---------- Phase 6 — HITL ----------


def _correction_row_to_response(row: object) -> CorrectionEventResponse:
    """Convert a ``JobCorrectionRow`` into a ``CorrectionEventResponse``.

    Raises:
        TypeError: If ``row`` is not a ``JobCorrectionRow``.
    """
    # Local import to avoid a cyclic import at module load time.
    from ocr_sprint.db.models import JobCorrectionRow

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not isinstance(row, JobCorrectionRow):
        raise TypeError(f"expected JobCorrectionRow, got {type(row).__name__}")

    return CorrectionEventResponse(
        id=row.id,
        job_id=row.job_id,
        field_path=row.field_path,
        old_value=row.old_value,
        new_value=row.new_value,
        corrected_by=row.corrected_by,
        reason=row.reason,
        corrected_at=row.corrected_at,
    )


@router.patch(
    "/{job_id}",
    response_model=DocumentResponse,
)
async def patch_document(
    job_id: UUID,
    body: CorrectionRequest,
    session: Annotated[Session, Depends(get_session)],
    x_user_id: Annotated[
        str | None,
        Header(description="Free-form reviewer identifier recorded on the audit row."),
    ] = None,
) -> DocumentResponse:
    """Apply one or more field-level corrections and record an audit trail.

    The whole batch is applied atomically — if any path is invalid the
    request fails with 400 and no side effects are written. Returns the
    updated document so the client doesn't need a follow-up GET.

    Raises:
        HTTPException: 404 if the job does not exist, 400 for an invalid
            field path, 409 if the job is already approved or not completed.
    """
    repo = JobRepository(session)
    try:
        repo.apply_corrections(
            job_id,
            corrections=[(c.path, c.value, c.reason) for c in body.corrections],
            corrected_by=x_user_id,
        )
        row = repo.get_or_raise(job_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
    except InvalidFieldPathError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
    except (JobAlreadyApprovedError, JobNotCompletedError) as exc:
        # Both mean the job is in a state that does not accept corrections.
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc

    _logger.info(
        "documents.patched",
        job_id=str(job_id),
        count=len(body.corrections),
        corrected_by=x_user_id or "",
    )
    return _row_to_response(row)


@router.get(
    "/{job_id}/history",
    response_model=list[CorrectionEventResponse],
)
async def get_history(
    job_id: UUID,
    session: Annotated[Session, Depends(get_session)],
) -> list[CorrectionEventResponse]:
    """Return the corrections recorded for a job.

    Raises:
        HTTPException: 404 if no job with ``job_id`` exists.
    """
    repo = JobRepository(session)
    try:
        rows = repo.list_corrections(job_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
    return [_correction_row_to_response(r) for r in rows]


@router.post(
    "/{job_id}/approve",
    response_model=ApprovalResponse,
)
async def approve_document(
    job_id: UUID,
    session: Annotated[Session, Depends(get_session)],
    x_user_id: Annotated[
        str | None,
        Header(description="Free-form reviewer identifier recorded on the job."),
    ] = None,
) -> ApprovalResponse:
    """Lock a job's final version.

    Idempotent: re-approving returns the existing row without overwriting
    ``reviewed_at``.

    Raises:
        HTTPException: 404 if the job does not exist, 409 if it has not
            completed yet.
    """
    repo = JobRepository(session)
    try:
        row = repo.approve(job_id, reviewed_by=x_user_id)
    except JobNotFoundError as exc:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=str(exc)) from exc
    except JobNotCompletedError as exc:
        raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(exc)) from exc

    _logger.info("documents.approved", job_id=str(job_id), reviewed_by=row.reviewed_by or "")
    return ApprovalResponse(
        job_id=row.job_id,
        approved=bool(row.approved),
        reviewed_by=row.reviewed_by,
        reviewed_at=row.reviewed_at,
    )