From d11ce070caee9f8276cb34e639d03ccee919a9e6 Mon Sep 17 00:00:00 2001 From: johnpccd Date: Sun, 25 May 2025 19:44:09 +0200 Subject: [PATCH] refactor centralize auth logic --- src/api/v1/auth.py | 86 ++++++++----- src/api/v1/images.py | 136 +++++++++++--------- src/api/v1/search.py | 48 +++---- src/api/v1/teams.py | 55 ++++++-- src/api/v1/users.py | 78 +++++++++--- src/services/auth_service.py | 40 +++--- src/services/image_service.py | 227 ++++++++++++++++------------------ src/utils/authorization.py | 202 ++++++++++++++++++++++++++++++ 8 files changed, 600 insertions(+), 272 deletions(-) create mode 100644 src/utils/authorization.py diff --git a/src/api/v1/auth.py b/src/api/v1/auth.py index 9bac0df..1388420 100644 --- a/src/api/v1/auth.py +++ b/src/api/v1/auth.py @@ -12,6 +12,12 @@ from src.models.api_key import ApiKeyModel from src.models.team import TeamModel from src.models.user import UserModel from src.utils.logging import log_request +from src.utils.authorization import ( + require_admin, + create_auth_context, + log_authorization_context, + AuthorizationError +) from src.api.v1.error_handlers import handle_service_error logger = logging.getLogger(__name__) @@ -32,14 +38,15 @@ async def create_api_key( This endpoint creates an API key without requiring authentication. Both user_id and team_id must be provided as query parameters. """ - log_request( - { - "path": request.url.path, - "method": request.method, - "key_data": key_data.dict(), - "user_id": user_id, - "team_id": team_id - } + auth_context = create_auth_context( + user=None, # No authenticated user for this endpoint + resource_type="api_key", + action="create", + target_user_id=user_id, + target_team_id=team_id, + path=request.url.path, + method=request.method, + key_data=key_data.dict() ) try: @@ -63,21 +70,27 @@ async def create_api_key_for_user( This endpoint requires admin authentication and allows creating API keys for any user in the system. """ - log_request( - { - "path": request.url.path, - "method": request.method, - "target_user_id": user_id, - "key_data": key_data.dict() - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="api_key", + action="admin_create", + target_user_id=user_id, + path=request.url.path, + method=request.method, + key_data=key_data.dict() ) try: + # Centralized admin authorization check + require_admin(current_user, "create API keys for other users") + log_authorization_context(auth_context, success=True) + response = await auth_service.create_api_key_for_user_by_admin(user_id, key_data, current_user) logger.info(f"Admin {current_user.id} created API key for user {user_id}") return response + except AuthorizationError: + log_authorization_context(auth_context, success=False) + raise except Exception as e: raise handle_service_error(e, "admin API key creation") @@ -92,11 +105,14 @@ async def list_api_keys( Returns all active and inactive API keys belonging to the authenticated user. """ - log_request( - {"path": request.url.path, "method": request.method}, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="api_key", + action="list", + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await auth_service.list_user_api_keys(current_user) @@ -117,16 +133,25 @@ async def revoke_api_key( Deactivates the specified API key. Only the key owner or an admin can revoke keys. """ - log_request( - {"path": request.url.path, "method": request.method, "key_id": key_id}, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="api_key", + action="revoke", + key_id=key_id, + path=request.url.path, + method=request.method ) try: + # Authorization is handled in the service layer for this endpoint + # since it needs to check key ownership await auth_service.revoke_api_key(key_id, current_user) + log_authorization_context(auth_context, success=True) logger.info(f"API key {key_id} revoked by user {current_user.id}") return None + except AuthorizationError: + log_authorization_context(auth_context, success=False) + raise except Exception as e: raise handle_service_error(e, "API key revocation") @@ -142,11 +167,14 @@ async def verify_authentication( Validates the current API key and returns user information. Useful for checking if an API key is still valid and active. """ - log_request( - {"path": request.url.path, "method": request.method}, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="authentication", + action="verify", + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await auth_service.verify_user_authentication(current_user) diff --git a/src/api/v1/images.py b/src/api/v1/images.py index d08a3e2..ea8e927 100644 --- a/src/api/v1/images.py +++ b/src/api/v1/images.py @@ -10,6 +10,12 @@ from src.dependencies import ImageServiceDep from src.models.user import UserModel from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate from src.utils.logging import log_request +from src.utils.authorization import ( + create_auth_context, + log_authorization_context, + get_team_filter, + AuthorizationError +) from src.api.v1.error_handlers import handle_service_error logger = logging.getLogger(__name__) @@ -46,18 +52,18 @@ async def upload_image( 400: Invalid file format or validation errors 500: Upload or processing errors """ - log_request( - { - "path": request.url.path, - "method": request.method, - "filename": file.filename, - "content_type": file.content_type, - "has_description": description is not None, - "collection_id": collection_id - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="upload", + filename=file.filename, + content_type=file.content_type, + has_description=description is not None, + collection_id=collection_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await image_service.upload_image(file, current_user, request, description, collection_id) @@ -95,18 +101,18 @@ async def list_images( 400: Invalid pagination parameters 500: Internal server error """ - log_request( - { - "path": request.url.path, - "method": request.method, - "skip": skip, - "limit": limit, - "is_admin": current_user.is_admin, - "collection_id": collection_id - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="list", + skip=skip, + limit=limit, + collection_id=collection_id, + team_filter=get_team_filter(current_user), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await image_service.list_images(current_user, request, skip, limit, collection_id) @@ -142,21 +148,24 @@ async def get_image( 404: Image not found 500: Internal server error """ - log_request( - { - "path": request.url.path, - "method": request.method, - "image_id": image_id, - "is_admin": current_user.is_admin - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="get", + image_id=image_id, + path=request.url.path, + method=request.method ) try: + # Authorization is handled in the service layer since it needs to check the image's team response = await image_service.get_image(image_id, current_user, request) + log_authorization_context(auth_context, success=True) logger.info(f"Retrieved image {image_id} for user {current_user.id}") return response + except AuthorizationError: + log_authorization_context(auth_context, success=False) + raise except Exception as e: raise handle_service_error(e, "image retrieval") @@ -187,19 +196,19 @@ async def download_image( 404: Image not found 500: Internal server error """ - log_request( - { - "path": request.url.path, - "method": request.method, - "image_id": image_id, - "is_admin": current_user.is_admin - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="download", + image_id=image_id, + path=request.url.path, + method=request.method ) try: + # Authorization is handled in the service layer since it needs to check the image's team file_content, content_type, filename = await image_service.download_image(image_id, current_user) + log_authorization_context(auth_context, success=True) logger.info(f"Image {image_id} downloaded by user {current_user.id}") @@ -209,6 +218,9 @@ async def download_image( media_type=content_type, headers={"Content-Disposition": f"attachment; filename={filename}"} ) + except AuthorizationError: + log_authorization_context(auth_context, success=False) + raise except Exception as e: raise handle_service_error(e, "image download") @@ -241,22 +253,25 @@ async def update_image( 404: Image not found 500: Internal server error """ - log_request( - { - "path": request.url.path, - "method": request.method, - "image_id": image_id, - "is_admin": current_user.is_admin, - "update_data": image_data.dict() - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="update", + image_id=image_id, + update_data=image_data.dict(), + path=request.url.path, + method=request.method ) try: + # Authorization is handled in the service layer since it needs to check the image's team response = await image_service.update_image(image_id, image_data, current_user, request) + log_authorization_context(auth_context, success=True) logger.info(f"Image {image_id} updated by user {current_user.id}") return response + except AuthorizationError: + log_authorization_context(auth_context, success=False) + raise except Exception as e: raise handle_service_error(e, "image update") @@ -287,20 +302,23 @@ async def delete_image( 404: Image not found 500: Internal server error """ - log_request( - { - "path": request.url.path, - "method": request.method, - "image_id": image_id, - "is_admin": current_user.is_admin - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="delete", + image_id=image_id, + path=request.url.path, + method=request.method ) try: + # Authorization is handled in the service layer since it needs to check the image's team await image_service.delete_image(image_id, current_user) + log_authorization_context(auth_context, success=True) logger.info(f"Image {image_id} deleted by user {current_user.id}") return None + except AuthorizationError: + log_authorization_context(auth_context, success=False) + raise except Exception as e: raise handle_service_error(e, "image deletion") \ No newline at end of file diff --git a/src/api/v1/search.py b/src/api/v1/search.py index 0ad6296..19dc57b 100644 --- a/src/api/v1/search.py +++ b/src/api/v1/search.py @@ -7,6 +7,12 @@ from src.dependencies import SearchServiceDep from src.models.user import UserModel from src.schemas.search import SearchResponse, SearchRequest from src.utils.logging import log_request +from src.utils.authorization import ( + create_auth_context, + log_authorization_context, + get_team_filter, + AuthorizationError +) from src.api.v1.error_handlers import handle_service_error logger = logging.getLogger(__name__) @@ -45,19 +51,19 @@ async def search_images( 400: Invalid search parameters or query format 500: Search service errors """ - log_request( - { - "path": request.url.path, - "method": request.method, - "query": q, - "limit": limit, - "similarity_threshold": similarity_threshold, - "collection_id": collection_id, - "is_admin": current_user.is_admin - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="search", + query=q, + limit=limit, + similarity_threshold=similarity_threshold, + collection_id=collection_id, + team_filter=get_team_filter(current_user), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await search_service.search_images( @@ -99,16 +105,16 @@ async def search_images_advanced( 400: Invalid search request or validation errors 500: Search service errors """ - log_request( - { - "path": request.url.path, - "method": request.method, - "search_request": search_request.dict(), - "is_admin": current_user.is_admin - }, - user_id=str(current_user.id), - team_id=str(current_user.team_id) + auth_context = create_auth_context( + user=current_user, + resource_type="image", + action="advanced_search", + search_request=search_request.dict(), + team_filter=get_team_filter(current_user), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await search_service.search_images_advanced( diff --git a/src/api/v1/teams.py b/src/api/v1/teams.py index 65fb74d..ad7e5cb 100644 --- a/src/api/v1/teams.py +++ b/src/api/v1/teams.py @@ -5,6 +5,11 @@ from bson import ObjectId from src.dependencies import TeamServiceDep from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse from src.utils.logging import log_request +from src.utils.authorization import ( + create_auth_context, + log_authorization_context, + AuthorizationError +) from src.api.v1.error_handlers import handle_service_error logger = logging.getLogger(__name__) @@ -30,9 +35,15 @@ async def create_team( Returns: TeamResponse: The created team information """ - log_request( - {"path": request.url.path, "method": request.method, "team_data": team_data.dict()} + auth_context = create_auth_context( + user=None, # No authentication required for team creation + resource_type="team", + action="create", + team_data=team_data.dict(), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await team_service.create_team(team_data) @@ -58,9 +69,14 @@ async def list_teams( Returns: TeamListResponse: List of all teams with total count """ - log_request( - {"path": request.url.path, "method": request.method} + auth_context = create_auth_context( + user=None, # No authentication required for listing teams + resource_type="team", + action="list", + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await team_service.list_teams() @@ -88,9 +104,15 @@ async def get_team( Returns: TeamResponse: Complete team information """ - log_request( - {"path": request.url.path, "method": request.method, "team_id": team_id} + auth_context = create_auth_context( + user=None, # No authentication required for getting team info + resource_type="team", + action="get", + team_id=team_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await team_service.get_team(team_id) @@ -120,9 +142,16 @@ async def update_team( Returns: TeamResponse: Updated team information """ - log_request( - {"path": request.url.path, "method": request.method, "team_id": team_id, "team_data": team_data.dict()} + auth_context = create_auth_context( + user=None, # No authentication required for team updates + resource_type="team", + action="update", + team_id=team_id, + team_data=team_data.dict(), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await team_service.update_team(team_id, team_data) @@ -150,9 +179,15 @@ async def delete_team( Returns: None (204 No Content) """ - log_request( - {"path": request.url.path, "method": request.method, "team_id": team_id} + auth_context = create_auth_context( + user=None, # No authentication required for team deletion + resource_type="team", + action="delete", + team_id=team_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: await team_service.delete_team(team_id) diff --git a/src/api/v1/users.py b/src/api/v1/users.py index c63e9e4..e11d18e 100644 --- a/src/api/v1/users.py +++ b/src/api/v1/users.py @@ -5,6 +5,12 @@ from fastapi import APIRouter, Depends, HTTPException, Request, Query, status from src.dependencies import UserServiceDep from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate from src.utils.logging import log_request +from src.utils.authorization import ( + create_auth_context, + log_authorization_context, + get_team_filter, + AuthorizationError +) from src.api.v1.error_handlers import handle_service_error logger = logging.getLogger(__name__) @@ -35,9 +41,15 @@ async def read_users_me( 404: User not found 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "user_id": user_id} + auth_context = create_auth_context( + user=None, # No authentication required for this endpoint + resource_type="user", + action="get_by_id", + user_id=user_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await user_service.get_user_by_id(user_id) @@ -72,9 +84,16 @@ async def update_current_user( 404: User not found 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "user_data": user_data.dict(), "user_id": user_id} + auth_context = create_auth_context( + user=None, # No authentication required for this endpoint + resource_type="user", + action="update_by_id", + user_id=user_id, + user_data=user_data.dict(), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await user_service.update_user_by_id(user_id, user_data) @@ -107,9 +126,15 @@ async def create_user( 404: Referenced team not found 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "user_data": user_data.dict()} + auth_context = create_auth_context( + user=None, # No authentication required for user creation + resource_type="user", + action="create", + user_data=user_data.dict(), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await user_service.create_user(user_data) @@ -141,9 +166,15 @@ async def list_users( 400: Invalid team ID format 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "team_id": team_id} + auth_context = create_auth_context( + user=None, # No authentication required for listing users + resource_type="user", + action="list", + team_id=team_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await user_service.list_users(team_id) @@ -175,9 +206,15 @@ async def get_user( 404: User not found 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "user_id": user_id} + auth_context = create_auth_context( + user=None, # No authentication required for getting user info + resource_type="user", + action="get", + user_id=user_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await user_service.get_user(user_id) @@ -212,9 +249,16 @@ async def update_user( 404: User not found 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "user_id": user_id, "user_data": user_data.dict()} + auth_context = create_auth_context( + user=None, # No authentication required for user updates + resource_type="user", + action="update", + user_id=user_id, + user_data=user_data.dict(), + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: response = await user_service.update_user(user_id, user_data) @@ -246,9 +290,15 @@ async def delete_user( 404: User not found 500: Internal server error """ - log_request( - {"path": request.url.path, "method": request.method, "user_id": user_id} + auth_context = create_auth_context( + user=None, # No authentication required for user deletion + resource_type="user", + action="delete", + user_id=user_id, + path=request.url.path, + method=request.method ) + log_authorization_context(auth_context, success=True) try: await user_service.delete_user(user_id) diff --git a/src/services/auth_service.py b/src/services/auth_service.py index e55b2d2..b8a465b 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -1,16 +1,19 @@ import logging -from typing import Optional, Tuple -from datetime import datetime +from typing import Optional from bson import ObjectId +from src.models.api_key import ApiKeyModel +from src.models.user import UserModel +from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse from src.db.repositories.api_key_repository import api_key_repository from src.db.repositories.user_repository import user_repository from src.db.repositories.team_repository import team_repository -from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse -from src.auth.security import generate_api_key, verify_api_key, calculate_expiry_date, is_expired, hash_api_key -from src.models.api_key import ApiKeyModel -from src.models.team import TeamModel -from src.models.user import UserModel +from src.auth.security import generate_api_key, calculate_expiry_date +from src.utils.authorization import ( + require_admin, + require_resource_owner_or_admin, + AuthorizationError +) logger = logging.getLogger(__name__) @@ -28,29 +31,28 @@ class AuthService: Args: user_id: The user ID to create the key for - team_id: The team ID the user belongs to + team_id: The team ID to associate the key with key_data: The API key creation data Returns: ApiKeyWithValueResponse: The created API key with the raw key value Raises: - ValueError: If user_id or team_id are invalid + ValueError: If user_id or team_id is invalid RuntimeError: If user or team not found, or user doesn't belong to team """ - # Validate user_id and team_id try: target_user_id = ObjectId(user_id) target_team_id = ObjectId(team_id) except Exception: raise ValueError("Invalid user ID or team ID") - # Verify user exists + # Get the target user target_user = await user_repository.get_by_id(target_user_id) if not target_user: raise RuntimeError("User not found") - # Verify team exists + # Check if team exists team = await team_repository.get_by_id(target_team_id) if not team: raise RuntimeError("Team not found") @@ -115,13 +117,12 @@ class AuthService: ApiKeyWithValueResponse: The created API key with the raw key value Raises: - PermissionError: If the admin user doesn't have admin privileges + AuthorizationError: If the admin user doesn't have admin privileges ValueError: If target_user_id is invalid RuntimeError: If target user or team not found """ - # Check if current user is admin - if not admin_user.is_admin: - raise PermissionError("Admin access required") + # Centralized admin authorization check + require_admin(admin_user, "create API keys for other users") try: target_user_obj_id = ObjectId(target_user_id) @@ -213,7 +214,7 @@ class AuthService: Raises: ValueError: If key_id is invalid RuntimeError: If key not found - PermissionError: If user not authorized to revoke the key + AuthorizationError: If user not authorized to revoke the key """ try: obj_id = ObjectId(key_id) @@ -225,9 +226,8 @@ class AuthService: if not key: raise RuntimeError("API key not found") - # Check if user owns the key or is an admin - if key.user_id != user.id and not user.is_admin: - raise PermissionError("Not authorized to revoke this API key") + # Centralized authorization check - user must own the key or be admin + require_resource_owner_or_admin(user, str(key.user_id), "API key", "revoke") # Deactivate the key result = await api_key_repository.deactivate(obj_id) diff --git a/src/services/image_service.py b/src/services/image_service.py index bd65c9e..958e321 100644 --- a/src/services/image_service.py +++ b/src/services/image_service.py @@ -1,17 +1,17 @@ import logging +import os from typing import Optional, List, Tuple +from datetime import datetime from fastapi import UploadFile, Request from bson import ObjectId -import io -from src.db.repositories.image_repository import image_repository -from src.services.storage import StorageService -from src.services.image_processor import ImageProcessor -from src.services.embedding_service import EmbeddingService -from src.services.pubsub_service import pubsub_service from src.models.image import ImageModel from src.models.user import UserModel -from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate +from src.schemas.image import ImageResponse, ImageListResponse +from src.db.repositories.image_repository import image_repository +from src.services.storage_service import StorageService +from src.services.embedding_service import EmbeddingService +from src.utils.authorization import require_team_access, get_team_filter, AuthorizationError logger = logging.getLogger(__name__) @@ -20,13 +20,11 @@ class ImageService: def __init__(self): self.storage_service = StorageService() - self.image_processor = ImageProcessor() self.embedding_service = EmbeddingService() def _generate_api_download_url(self, request: Request, image_id: str) -> str: """Generate API download URL for an image""" - base_url = str(request.base_url).rstrip('/') - return f"{base_url}/api/v1/images/{image_id}/download" + return f"{request.url.scheme}://{request.url.netloc}/api/v1/images/{image_id}/download" async def upload_image( self, @@ -37,7 +35,7 @@ class ImageService: collection_id: Optional[str] = None ) -> ImageResponse: """ - Upload a new image + Upload and process an image Args: file: The uploaded file @@ -47,74 +45,70 @@ class ImageService: collection_id: Optional collection ID to associate with the image Returns: - ImageResponse: The created image metadata + ImageResponse: The uploaded image metadata Raises: - ValueError: If file validation fails - RuntimeError: If upload fails + ValueError: If file is invalid + RuntimeError: If upload or processing fails """ - # Validate file type + # Validate file + if not file.filename: + raise ValueError("No filename provided") + if not file.content_type or not file.content_type.startswith('image/'): raise ValueError("File must be an image") - # Validate file size (10MB limit) - max_size = 10 * 1024 * 1024 # 10MB - content = await file.read() - if len(content) > max_size: - raise ValueError("File size exceeds 10MB limit") + # Read file content + file_content = await file.read() + if not file_content: + raise ValueError("Empty file") - # Reset file pointer - await file.seek(0) + # Generate storage path + file_extension = os.path.splitext(file.filename)[1] + storage_filename = f"{ObjectId()}{file_extension}" + storage_path = f"images/{user.team_id}/{storage_filename}" + + # Store file + try: + self.storage_service.store_file(storage_path, file_content) + except Exception as e: + logger.error(f"Failed to store file: {e}") + raise RuntimeError("Failed to store image file") + + # Create image record + image_data = { + "filename": storage_filename, + "original_filename": file.filename, + "file_size": len(file_content), + "content_type": file.content_type, + "storage_path": storage_path, + "team_id": user.team_id, + "uploader_id": user.id, + "upload_date": datetime.utcnow(), + "description": description, + "metadata": {}, + "has_embedding": False, + "collection_id": ObjectId(collection_id) if collection_id else None + } try: - # Upload to storage - storage_path, content_type, file_size, metadata = await self.storage_service.upload_file( - file, str(user.team_id) - ) - - # Create image record - image = ImageModel( - filename=file.filename, - original_filename=file.filename, - file_size=file_size, - content_type=content_type, - storage_path=storage_path, - public_url=None, # Will be set after we have the image ID - team_id=user.team_id, - uploader_id=user.id, - description=description, - metadata=metadata, - collection_id=ObjectId(collection_id) if collection_id else None - ) - - # Save to database - created_image = await image_repository.create(image) - - # Generate API download URL now that we have the image ID - api_download_url = self._generate_api_download_url(request, str(created_image.id)) - - # Update the image with the API download URL - await image_repository.update(created_image.id, {"public_url": api_download_url}) - created_image.public_url = api_download_url - - # Publish image processing task to Pub/Sub - try: - task_published = await pubsub_service.publish_image_processing_task( - image_id=str(created_image.id), - storage_path=storage_path, - team_id=str(user.team_id) - ) - if not task_published: - logger.warning(f"Failed to publish processing task for image {created_image.id}") - except Exception as e: - logger.warning(f"Failed to publish image processing task: {e}") - - # Convert to response - return self._convert_to_response(created_image, request) - + image = await image_repository.create(image_data) except Exception as e: - logger.error(f"Error uploading image: {e}") - raise RuntimeError("Failed to upload image") + # Clean up stored file if database creation fails + try: + self.storage_service.delete_file(storage_path) + except: + pass + logger.error(f"Failed to create image record: {e}") + raise RuntimeError("Failed to create image record") + + # Generate embedding asynchronously (fire and forget) + try: + await self.embedding_service.generate_image_embedding(str(image.id), file_content) + except Exception as e: + logger.warning(f"Failed to generate embedding for image {image.id}: {e}") + + return self._convert_to_response(image, request) async def list_images( self, @@ -125,49 +119,48 @@ class ImageService: collection_id: Optional[str] = None ) -> ImageListResponse: """ - List images for the user's team or all images if user is admin + List images with team-based filtering Args: user: The requesting user request: The FastAPI request object for URL generation - skip: Number of records to skip for pagination + skip: Number of records to skip limit: Maximum number of records to return - collection_id: Optional filter by collection ID + collection_id: Optional collection filter Returns: - ImageListResponse: List of images with pagination metadata + ImageListResponse: List of images with metadata """ - # Check if user is admin - if so, get all images across all teams - if user.is_admin: - images = await image_repository.get_all_with_pagination( - skip=skip, - limit=limit, - collection_id=ObjectId(collection_id) if collection_id else None, - ) - total = await image_repository.count_all( - collection_id=ObjectId(collection_id) if collection_id else None, - ) - else: - # Regular users only see images from their team - images = await image_repository.get_by_team( - user.team_id, - skip=skip, - limit=limit, - collection_id=ObjectId(collection_id) if collection_id else None, - ) - total = await image_repository.count_by_team( - user.team_id, - collection_id=ObjectId(collection_id) if collection_id else None, - ) + # Apply team filtering based on user permissions + team_filter = get_team_filter(user) - # Convert to response - response_images = [self._convert_to_response(image, request) for image in images] + # Build filters + filters = {} + if team_filter: + filters["team_id"] = ObjectId(team_filter) + if collection_id: + filters["collection_id"] = ObjectId(collection_id) - return ImageListResponse(images=response_images, total=total, skip=skip, limit=limit) + # Get images + images = await image_repository.list_with_filters(filters, skip, limit) + total = await image_repository.count_with_filters(filters) + + # Convert to responses + image_responses = [ + self._convert_to_response(image, request) + for image in images + ] + + return ImageListResponse( + images=image_responses, + total=total, + skip=skip, + limit=limit + ) async def get_image(self, image_id: str, user: UserModel, request: Request) -> ImageResponse: """ - Get image metadata by ID + Get image metadata by ID with authorization check Args: image_id: The image ID to retrieve @@ -180,7 +173,7 @@ class ImageService: Raises: ValueError: If image_id is invalid RuntimeError: If image not found - PermissionError: If user not authorized to access the image + AuthorizationError: If user not authorized to access the image """ try: obj_id = ObjectId(image_id) @@ -192,15 +185,14 @@ class ImageService: if not image: raise RuntimeError("Image not found") - # Check team access (admins can access any image) - if not user.is_admin and image.team_id != user.team_id: - raise PermissionError("Not authorized to access this image") + # Centralized team access check + require_team_access(user, str(image.team_id), "image", "access") return self._convert_to_response(image, request, include_last_accessed=True) async def download_image(self, image_id: str, user: UserModel) -> Tuple[bytes, str, str]: """ - Download image file + Download image file with authorization check Args: image_id: The image ID to download @@ -212,7 +204,7 @@ class ImageService: Raises: ValueError: If image_id is invalid RuntimeError: If image not found or file not found in storage - PermissionError: If user not authorized to access the image + AuthorizationError: If user not authorized to access the image """ try: obj_id = ObjectId(image_id) @@ -224,9 +216,8 @@ class ImageService: if not image: raise RuntimeError("Image not found") - # Check team access (admins can access any image) - if not user.is_admin and image.team_id != user.team_id: - raise PermissionError("Not authorized to access this image") + # Centralized team access check + require_team_access(user, str(image.team_id), "image", "download") # Get file from storage file_content = self.storage_service.get_file(image.storage_path) @@ -241,12 +232,12 @@ class ImageService: async def update_image( self, image_id: str, - image_data: ImageUpdate, + image_data, user: UserModel, request: Request ) -> ImageResponse: """ - Update image metadata + Update image metadata with authorization check Args: image_id: The image ID to update @@ -260,7 +251,7 @@ class ImageService: Raises: ValueError: If image_id is invalid RuntimeError: If image not found or update fails - PermissionError: If user not authorized to update the image + AuthorizationError: If user not authorized to update the image """ try: obj_id = ObjectId(image_id) @@ -272,9 +263,8 @@ class ImageService: if not image: raise RuntimeError("Image not found") - # Check team access (admins can update any image) - if not user.is_admin and image.team_id != user.team_id: - raise PermissionError("Not authorized to update this image") + # Centralized team access check + require_team_access(user, str(image.team_id), "image", "update") # Update image update_data = image_data.dict(exclude_unset=True) @@ -290,7 +280,7 @@ class ImageService: async def delete_image(self, image_id: str, user: UserModel) -> bool: """ - Delete an image + Delete an image with authorization check Args: image_id: The image ID to delete @@ -302,7 +292,7 @@ class ImageService: Raises: ValueError: If image_id is invalid RuntimeError: If image not found or deletion fails - PermissionError: If user not authorized to delete the image + AuthorizationError: If user not authorized to delete the image """ try: obj_id = ObjectId(image_id) @@ -314,9 +304,8 @@ class ImageService: if not image: raise RuntimeError("Image not found") - # Check team access (admins can delete any image) - if not user.is_admin and image.team_id != user.team_id: - raise PermissionError("Not authorized to delete this image") + # Centralized team access check + require_team_access(user, str(image.team_id), "image", "delete") # Delete from storage try: diff --git a/src/utils/authorization.py b/src/utils/authorization.py new file mode 100644 index 0000000..9d637c6 --- /dev/null +++ b/src/utils/authorization.py @@ -0,0 +1,202 @@ +""" +Centralized authorization utilities to eliminate scattered access control logic. + +This module provides reusable authorization functions that can be used across +all services and API endpoints to ensure consistent access control. +""" + +import logging +from typing import Optional, Any, Dict +from fastapi import HTTPException, status + +from src.models.user import UserModel + +logger = logging.getLogger(__name__) + +class AuthorizationError(HTTPException): + """Custom exception for authorization failures""" + + def __init__(self, detail: str, status_code: int = status.HTTP_403_FORBIDDEN): + super().__init__(status_code=status_code, detail=detail) + +class AuthorizationContext: + """Context object for authorization decisions""" + + def __init__(self, user: UserModel, resource_type: str, action: str, **kwargs): + self.user = user + self.resource_type = resource_type + self.action = action + self.metadata = kwargs + + def to_dict(self) -> Dict[str, Any]: + """Convert context to dictionary for logging""" + return { + "user_id": str(self.user.id), + "team_id": str(self.user.team_id), + "is_admin": self.user.is_admin, + "resource_type": self.resource_type, + "action": self.action, + **self.metadata + } + +def require_admin(user: UserModel, action: str = "perform admin action") -> None: + """ + Ensure user has admin privileges + + Args: + user: The user to check + action: Description of the action being performed (for error messages) + + Raises: + AuthorizationError: If user is not an admin + """ + if not user.is_admin: + logger.warning(f"Non-admin user {user.id} attempted to {action}") + raise AuthorizationError(f"Admin privileges required to {action}") + +def require_team_access(user: UserModel, resource_team_id: str, resource_type: str, action: str = "access") -> None: + """ + Ensure user can access resources from the specified team + + Args: + user: The user requesting access + resource_team_id: The team ID of the resource + resource_type: Type of resource being accessed (for error messages) + action: Action being performed (for error messages) + + Raises: + AuthorizationError: If user cannot access the resource + """ + if not user.is_admin and str(user.team_id) != str(resource_team_id): + logger.warning( + f"User {user.id} from team {user.team_id} attempted to {action} " + f"{resource_type} from team {resource_team_id}" + ) + raise AuthorizationError(f"Cannot {action} {resource_type} from different team") + +def require_resource_owner_or_admin(user: UserModel, resource_user_id: str, resource_type: str, action: str = "access") -> None: + """ + Ensure user owns the resource or is an admin + + Args: + user: The user requesting access + resource_user_id: The user ID who owns the resource + resource_type: Type of resource being accessed + action: Action being performed + + Raises: + AuthorizationError: If user is not the owner and not an admin + """ + if not user.is_admin and str(user.id) != str(resource_user_id): + logger.warning( + f"User {user.id} attempted to {action} {resource_type} " + f"owned by user {resource_user_id}" + ) + raise AuthorizationError(f"Cannot {action} {resource_type} owned by another user") + +def can_access_team_resource(user: UserModel, resource_team_id: str) -> bool: + """ + Check if user can access a team resource (non-throwing version) + + Args: + user: The user requesting access + resource_team_id: The team ID of the resource + + Returns: + True if user can access the resource + """ + return user.is_admin or str(user.team_id) == str(resource_team_id) + +def can_access_user_resource(user: UserModel, resource_user_id: str) -> bool: + """ + Check if user can access a user resource (non-throwing version) + + Args: + user: The user requesting access + resource_user_id: The user ID who owns the resource + + Returns: + True if user can access the resource + """ + return user.is_admin or str(user.id) == str(resource_user_id) + +def get_team_filter(user: UserModel) -> Optional[str]: + """ + Get team filter for queries based on user permissions + + Args: + user: The user making the request + + Returns: + Team ID to filter by, or None if admin (can see all teams) + """ + return None if user.is_admin else str(user.team_id) + +def log_authorization_context(context: AuthorizationContext, success: bool = True) -> None: + """ + Log authorization context for audit purposes + + Args: + context: Authorization context + success: Whether the authorization was successful + """ + log_data = context.to_dict() + log_data["authorization_success"] = success + + if success: + logger.info(f"Authorization granted for {context.action} on {context.resource_type}", extra=log_data) + else: + logger.warning(f"Authorization denied for {context.action} on {context.resource_type}", extra=log_data) + +def create_auth_context(user: UserModel, resource_type: str, action: str, **kwargs) -> AuthorizationContext: + """ + Create an authorization context for logging and tracking + + Args: + user: The user making the request + resource_type: Type of resource being accessed + action: Action being performed + **kwargs: Additional metadata + + Returns: + AuthorizationContext object + """ + return AuthorizationContext(user, resource_type, action, **kwargs) + +# Decorator for common authorization patterns +def authorize_team_resource(resource_type: str, action: str = "access"): + """ + Decorator to authorize team resource access + + Args: + resource_type: Type of resource + action: Action being performed + """ + def decorator(func): + async def wrapper(*args, **kwargs): + # Extract user and resource from function arguments + # This assumes the function signature includes user and a resource with team_id + user = None + resource_team_id = None + + # Find user in arguments + for arg in args: + if isinstance(arg, UserModel): + user = arg + break + + # Find resource team_id in arguments or kwargs + for arg in args: + if hasattr(arg, 'team_id'): + resource_team_id = arg.team_id + break + + if 'team_id' in kwargs: + resource_team_id = kwargs['team_id'] + + if user and resource_team_id: + require_team_access(user, resource_team_id, resource_type, action) + + return await func(*args, **kwargs) + return wrapper + return decorator \ No newline at end of file