refactor dependency injection

This commit is contained in:
johnpccd 2025-05-25 19:29:22 +02:00
parent 71bafe0938
commit 43c2bcce83
6 changed files with 173 additions and 45 deletions

View File

@ -4,7 +4,7 @@ from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, Header, Request, Query from fastapi import APIRouter, Depends, HTTPException, Header, Request, Query
from bson import ObjectId from bson import ObjectId
from src.services.auth_service import AuthService from src.dependencies import AuthServiceDep
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
from src.schemas.team import TeamCreate from src.schemas.team import TeamCreate
from src.schemas.user import UserCreate from src.schemas.user import UserCreate
@ -18,15 +18,13 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["Authentication"], prefix="/auth") router = APIRouter(tags=["Authentication"], prefix="/auth")
# Initialize service
auth_service = AuthService()
@router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201) @router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201)
async def create_api_key( async def create_api_key(
key_data: ApiKeyCreate, key_data: ApiKeyCreate,
request: Request, request: Request,
user_id: str = Query(..., description="User ID for the API key"), user_id: str = Query(..., description="User ID for the API key"),
team_id: str = Query(..., description="Team ID for the API key") team_id: str = Query(..., description="Team ID for the API key"),
auth_service: AuthServiceDep = Depends()
): ):
""" """
Create a new API key for a specific user and team Create a new API key for a specific user and team
@ -38,6 +36,7 @@ async def create_api_key(
key_data: API key creation data including name and description key_data: API key creation data including name and description
user_id: The user ID to create the key for user_id: The user ID to create the key for
team_id: The team ID the user belongs to team_id: The team ID the user belongs to
auth_service: Injected authentication service
Returns: Returns:
ApiKeyWithValueResponse: The created API key with the raw key value ApiKeyWithValueResponse: The created API key with the raw key value
@ -76,7 +75,8 @@ async def create_api_key_for_user(
user_id: str, user_id: str,
key_data: ApiKeyCreate, key_data: ApiKeyCreate,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
auth_service: AuthServiceDep = Depends()
): ):
""" """
Create a new API key for a specific user (admin only) Create a new API key for a specific user (admin only)
@ -88,6 +88,7 @@ async def create_api_key_for_user(
user_id: The target user ID to create the key for user_id: The target user ID to create the key for
key_data: API key creation data including name and description key_data: API key creation data including name and description
current_user: The authenticated admin user current_user: The authenticated admin user
auth_service: Injected authentication service
Returns: Returns:
ApiKeyWithValueResponse: The created API key with the raw key value ApiKeyWithValueResponse: The created API key with the raw key value
@ -129,7 +130,8 @@ async def create_api_key_for_user(
@router.get("/api-keys", response_model=ApiKeyListResponse) @router.get("/api-keys", response_model=ApiKeyListResponse)
async def list_api_keys( async def list_api_keys(
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
auth_service: AuthServiceDep = Depends()
): ):
""" """
List API keys for the current authenticated user List API keys for the current authenticated user
@ -138,6 +140,7 @@ async def list_api_keys(
Args: Args:
current_user: The authenticated user current_user: The authenticated user
auth_service: Injected authentication service
Returns: Returns:
ApiKeyListResponse: List of API keys with metadata ApiKeyListResponse: List of API keys with metadata
@ -163,7 +166,8 @@ async def list_api_keys(
async def revoke_api_key( async def revoke_api_key(
key_id: str, key_id: str,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
auth_service: AuthServiceDep = Depends()
): ):
""" """
Revoke (deactivate) an API key Revoke (deactivate) an API key
@ -173,6 +177,7 @@ async def revoke_api_key(
Args: Args:
key_id: The ID of the API key to revoke key_id: The ID of the API key to revoke
current_user: The authenticated user current_user: The authenticated user
auth_service: Injected authentication service
Returns: Returns:
None (204 No Content) None (204 No Content)
@ -209,7 +214,8 @@ async def revoke_api_key(
@router.get("/verify", status_code=200) @router.get("/verify", status_code=200)
async def verify_authentication( async def verify_authentication(
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
auth_service: AuthServiceDep = Depends()
): ):
""" """
Verify the current authentication status Verify the current authentication status
@ -219,6 +225,7 @@ async def verify_authentication(
Args: Args:
current_user: The authenticated user current_user: The authenticated user
auth_service: Injected authentication service
Returns: Returns:
dict: Authentication verification response with user details dict: Authentication verification response with user details

View File

@ -6,7 +6,7 @@ from bson import ObjectId
import io import io
from src.auth.security import get_current_user from src.auth.security import get_current_user
from src.services.image_service import ImageService from src.dependencies import ImageServiceDep
from src.models.user import UserModel from src.models.user import UserModel
from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
from src.utils.logging import log_request from src.utils.logging import log_request
@ -15,16 +15,14 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["Images"], prefix="/images") router = APIRouter(tags=["Images"], prefix="/images")
# Initialize service
image_service = ImageService()
@router.post("", response_model=ImageResponse, status_code=201) @router.post("", response_model=ImageResponse, status_code=201)
async def upload_image( async def upload_image(
request: Request, request: Request,
file: UploadFile = File(..., description="Image file to upload"), file: UploadFile = File(..., description="Image file to upload"),
description: Optional[str] = Query(None, description="Optional description for the image"), description: Optional[str] = Query(None, description="Optional description for the image"),
collection_id: Optional[str] = Query(None, description="Optional collection ID to associate with the image"), collection_id: Optional[str] = Query(None, description="Optional collection ID to associate with the image"),
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
image_service: ImageServiceDep = Depends()
): ):
""" """
Upload a new image Upload a new image
@ -38,6 +36,7 @@ async def upload_image(
description: Optional description for the image description: Optional description for the image
collection_id: Optional collection ID to organize the image collection_id: Optional collection ID to organize the image
current_user: The authenticated user uploading the image current_user: The authenticated user uploading the image
image_service: Injected image service
Returns: Returns:
ImageResponse: The uploaded image metadata and processing status ImageResponse: The uploaded image metadata and processing status
@ -79,7 +78,8 @@ async def list_images(
skip: int = Query(0, ge=0, description="Number of records to skip for pagination"), skip: int = Query(0, ge=0, description="Number of records to skip for pagination"),
limit: int = Query(50, ge=1, le=100, description="Maximum number of records to return (1-100)"), limit: int = Query(50, ge=1, le=100, description="Maximum number of records to return (1-100)"),
collection_id: Optional[str] = Query(None, description="Filter by collection ID"), collection_id: Optional[str] = Query(None, description="Filter by collection ID"),
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
image_service: ImageServiceDep = Depends()
): ):
""" """
List images for the current user's team or all images if admin List images for the current user's team or all images if admin
@ -92,6 +92,7 @@ async def list_images(
limit: Maximum number of records to return, 1-100 (default: 50) limit: Maximum number of records to return, 1-100 (default: 50)
collection_id: Optional filter by collection ID collection_id: Optional filter by collection ID
current_user: The authenticated user current_user: The authenticated user
image_service: Injected image service
Returns: Returns:
ImageListResponse: Paginated list of images with metadata ImageListResponse: Paginated list of images with metadata
@ -128,7 +129,8 @@ async def list_images(
async def get_image( async def get_image(
image_id: str, image_id: str,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
image_service: ImageServiceDep = Depends()
): ):
""" """
Get image metadata by ID Get image metadata by ID
@ -139,6 +141,7 @@ async def get_image(
Args: Args:
image_id: The image ID to retrieve image_id: The image ID to retrieve
current_user: The authenticated user current_user: The authenticated user
image_service: Injected image service
Returns: Returns:
ImageResponse: Complete image metadata ImageResponse: Complete image metadata
@ -181,7 +184,8 @@ async def get_image(
async def download_image( async def download_image(
image_id: str, image_id: str,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
image_service: ImageServiceDep = Depends()
): ):
""" """
Download image file Download image file
@ -192,6 +196,7 @@ async def download_image(
Args: Args:
image_id: The image ID to download image_id: The image ID to download
current_user: The authenticated user current_user: The authenticated user
image_service: Injected image service
Returns: Returns:
StreamingResponse: The image file as a download StreamingResponse: The image file as a download
@ -242,7 +247,8 @@ async def update_image(
image_id: str, image_id: str,
image_data: ImageUpdate, image_data: ImageUpdate,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
image_service: ImageServiceDep = Depends()
): ):
""" """
Update image metadata Update image metadata
@ -254,6 +260,7 @@ async def update_image(
image_id: The image ID to update image_id: The image ID to update
image_data: The image update data image_data: The image update data
current_user: The authenticated user current_user: The authenticated user
image_service: Injected image service
Returns: Returns:
ImageResponse: Updated image metadata ImageResponse: Updated image metadata
@ -297,7 +304,8 @@ async def update_image(
async def delete_image( async def delete_image(
image_id: str, image_id: str,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
image_service: ImageServiceDep = Depends()
): ):
""" """
Delete an image Delete an image
@ -308,6 +316,7 @@ async def delete_image(
Args: Args:
image_id: The image ID to delete image_id: The image ID to delete
current_user: The authenticated user current_user: The authenticated user
image_service: Injected image service
Returns: Returns:
None (204 No Content) None (204 No Content)

View File

@ -3,7 +3,7 @@ from typing import Optional, List, Dict, Any
from fastapi import APIRouter, Depends, Query, Request, HTTPException from fastapi import APIRouter, Depends, Query, Request, HTTPException
from src.auth.security import get_current_user from src.auth.security import get_current_user
from src.services.search_service import SearchService from src.dependencies import SearchServiceDep
from src.models.user import UserModel from src.models.user import UserModel
from src.schemas.search import SearchResponse, SearchRequest from src.schemas.search import SearchResponse, SearchRequest
from src.utils.logging import log_request from src.utils.logging import log_request
@ -12,9 +12,6 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["Search"], prefix="/search") router = APIRouter(tags=["Search"], prefix="/search")
# Initialize service
search_service = SearchService()
@router.get("", response_model=SearchResponse) @router.get("", response_model=SearchResponse)
async def search_images( async def search_images(
request: Request, request: Request,
@ -22,7 +19,8 @@ async def search_images(
limit: int = Query(10, ge=1, le=50, description="Number of results to return (1-50)"), limit: int = Query(10, ge=1, le=50, description="Number of results to return (1-50)"),
similarity_threshold: float = Query(0.65, ge=0.0, le=1.0, description="Similarity threshold (0.0-1.0)"), similarity_threshold: float = Query(0.65, ge=0.0, le=1.0, description="Similarity threshold (0.0-1.0)"),
collection_id: Optional[str] = Query(None, description="Filter results by collection ID"), collection_id: Optional[str] = Query(None, description="Filter results by collection ID"),
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
search_service: SearchServiceDep = Depends()
): ):
""" """
Search for images using semantic similarity Search for images using semantic similarity
@ -37,6 +35,7 @@ async def search_images(
similarity_threshold: Minimum similarity score (0.0-1.0, default: 0.65) similarity_threshold: Minimum similarity score (0.0-1.0, default: 0.65)
collection_id: Optional filter to search within a specific collection collection_id: Optional filter to search within a specific collection
current_user: The authenticated user performing the search current_user: The authenticated user performing the search
search_service: Injected search service
Returns: Returns:
SearchResponse: List of matching images with similarity scores SearchResponse: List of matching images with similarity scores
@ -84,7 +83,8 @@ async def search_images(
async def search_images_advanced( async def search_images_advanced(
search_request: SearchRequest, search_request: SearchRequest,
request: Request, request: Request,
current_user: UserModel = Depends(get_current_user) current_user: UserModel = Depends(get_current_user),
search_service: SearchServiceDep = Depends()
): ):
""" """
Advanced search for images with extended options Advanced search for images with extended options
@ -96,6 +96,7 @@ async def search_images_advanced(
Args: Args:
search_request: Advanced search request with detailed parameters search_request: Advanced search request with detailed parameters
current_user: The authenticated user performing the search current_user: The authenticated user performing the search
search_service: Injected search service
Returns: Returns:
SearchResponse: List of matching images with similarity scores and metadata SearchResponse: List of matching images with similarity scores and metadata

View File

@ -2,7 +2,7 @@ import logging
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from bson import ObjectId from bson import ObjectId
from src.services.team_service import TeamService from src.dependencies import TeamServiceDep
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
from src.utils.logging import log_request from src.utils.logging import log_request
@ -10,13 +10,11 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["Teams"], prefix="/teams") router = APIRouter(tags=["Teams"], prefix="/teams")
# Initialize service
team_service = TeamService()
@router.post("", response_model=TeamResponse, status_code=201) @router.post("", response_model=TeamResponse, status_code=201)
async def create_team( async def create_team(
team_data: TeamCreate, team_data: TeamCreate,
request: Request request: Request,
team_service: TeamServiceDep = Depends()
): ):
""" """
Create a new team Create a new team
@ -26,6 +24,7 @@ async def create_team(
Args: Args:
team_data: Team creation data including name and description team_data: Team creation data including name and description
team_service: Injected team service
Returns: Returns:
TeamResponse: The created team information TeamResponse: The created team information
@ -50,13 +49,19 @@ async def create_team(
raise HTTPException(status_code=500, detail="Internal server error") raise HTTPException(status_code=500, detail="Internal server error")
@router.get("", response_model=TeamListResponse) @router.get("", response_model=TeamListResponse)
async def list_teams(request: Request): async def list_teams(
request: Request,
team_service: TeamServiceDep = Depends()
):
""" """
List all teams List all teams
Retrieves a complete list of all teams in the system with their Retrieves a complete list of all teams in the system with their
basic information and member counts. basic information and member counts.
Args:
team_service: Injected team service
Returns: Returns:
TeamListResponse: List of all teams with total count TeamListResponse: List of all teams with total count
@ -78,7 +83,8 @@ async def list_teams(request: Request):
@router.get("/{team_id}", response_model=TeamResponse) @router.get("/{team_id}", response_model=TeamResponse)
async def get_team( async def get_team(
team_id: str, team_id: str,
request: Request request: Request,
team_service: TeamServiceDep = Depends()
): ):
""" """
Get a team by ID Get a team by ID
@ -88,6 +94,7 @@ async def get_team(
Args: Args:
team_id: The team ID to retrieve team_id: The team ID to retrieve
team_service: Injected team service
Returns: Returns:
TeamResponse: Complete team information TeamResponse: Complete team information
@ -119,7 +126,8 @@ async def get_team(
async def update_team( async def update_team(
team_id: str, team_id: str,
team_data: TeamUpdate, team_data: TeamUpdate,
request: Request request: Request,
team_service: TeamServiceDep = Depends()
): ):
""" """
Update a team Update a team
@ -130,6 +138,7 @@ async def update_team(
Args: Args:
team_id: The team ID to update team_id: The team ID to update
team_data: The team update data team_data: The team update data
team_service: Injected team service
Returns: Returns:
TeamResponse: Updated team information TeamResponse: Updated team information
@ -160,7 +169,8 @@ async def update_team(
@router.delete("/{team_id}", status_code=204) @router.delete("/{team_id}", status_code=204)
async def delete_team( async def delete_team(
team_id: str, team_id: str,
request: Request request: Request,
team_service: TeamServiceDep = Depends()
): ):
""" """
Delete a team Delete a team
@ -170,6 +180,7 @@ async def delete_team(
Args: Args:
team_id: The team ID to delete team_id: The team ID to delete
team_service: Injected team service
Returns: Returns:
None (204 No Content) None (204 No Content)

View File

@ -2,7 +2,7 @@ import logging
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Query from fastapi import APIRouter, Depends, HTTPException, Request, Query
from src.services.user_service import UserService from src.dependencies import UserServiceDep
from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
from src.utils.logging import log_request from src.utils.logging import log_request
@ -10,13 +10,11 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["Users"], prefix="/users") router = APIRouter(tags=["Users"], prefix="/users")
# Initialize service
user_service = UserService()
@router.get("/me", response_model=UserResponse) @router.get("/me", response_model=UserResponse)
async def read_users_me( async def read_users_me(
request: Request, request: Request,
user_id: str = Query(..., description="User ID to retrieve information for") user_id: str = Query(..., description="User ID to retrieve information for"),
user_service: UserServiceDep = Depends()
): ):
""" """
Get user information by user ID Get user information by user ID
@ -26,6 +24,7 @@ async def read_users_me(
Args: Args:
user_id: The user ID to retrieve information for user_id: The user ID to retrieve information for
user_service: Injected user service
Returns: Returns:
UserResponse: Complete user information including profile data UserResponse: Complete user information including profile data
@ -57,7 +56,8 @@ async def read_users_me(
async def update_current_user( async def update_current_user(
user_data: UserUpdate, user_data: UserUpdate,
request: Request, request: Request,
user_id: str = Query(..., description="User ID to update") user_id: str = Query(..., description="User ID to update"),
user_service: UserServiceDep = Depends()
): ):
""" """
Update user information by user ID Update user information by user ID
@ -68,6 +68,7 @@ async def update_current_user(
Args: Args:
user_data: The user update data containing fields to modify user_data: The user update data containing fields to modify
user_id: The user ID to update user_id: The user ID to update
user_service: Injected user service
Returns: Returns:
UserResponse: Updated user information UserResponse: Updated user information
@ -98,7 +99,8 @@ async def update_current_user(
@router.post("", response_model=UserResponse, status_code=201) @router.post("", response_model=UserResponse, status_code=201)
async def create_user( async def create_user(
user_data: UserCreate, user_data: UserCreate,
request: Request request: Request,
user_service: UserServiceDep = Depends()
): ):
""" """
Create a new user Create a new user
@ -108,6 +110,7 @@ async def create_user(
Args: Args:
user_data: User creation data including name, email, and team assignment user_data: User creation data including name, email, and team assignment
user_service: Injected user service
Returns: Returns:
UserResponse: The created user information UserResponse: The created user information
@ -138,7 +141,8 @@ async def create_user(
@router.get("", response_model=UserListResponse) @router.get("", response_model=UserListResponse)
async def list_users( async def list_users(
request: Request, request: Request,
team_id: Optional[str] = Query(None, description="Filter users by team ID") team_id: Optional[str] = Query(None, description="Filter users by team ID"),
user_service: UserServiceDep = Depends()
): ):
""" """
List users with optional team filtering List users with optional team filtering
@ -148,6 +152,7 @@ async def list_users(
Args: Args:
team_id: Optional team ID to filter users by team_id: Optional team ID to filter users by
user_service: Injected user service
Returns: Returns:
UserListResponse: List of users with total count UserListResponse: List of users with total count
@ -174,7 +179,8 @@ async def list_users(
@router.get("/{user_id}", response_model=UserResponse) @router.get("/{user_id}", response_model=UserResponse)
async def get_user( async def get_user(
user_id: str, user_id: str,
request: Request request: Request,
user_service: UserServiceDep = Depends()
): ):
""" """
Get user by ID Get user by ID
@ -183,6 +189,7 @@ async def get_user(
Args: Args:
user_id: The user ID to retrieve user_id: The user ID to retrieve
user_service: Injected user service
Returns: Returns:
UserResponse: Complete user information UserResponse: Complete user information
@ -214,7 +221,8 @@ async def get_user(
async def update_user( async def update_user(
user_id: str, user_id: str,
user_data: UserUpdate, user_data: UserUpdate,
request: Request request: Request,
user_service: UserServiceDep = Depends()
): ):
""" """
Update user by ID Update user by ID
@ -225,6 +233,7 @@ async def update_user(
Args: Args:
user_id: The user ID to update user_id: The user ID to update
user_data: The user update data user_data: The user update data
user_service: Injected user service
Returns: Returns:
UserResponse: Updated user information UserResponse: Updated user information
@ -255,7 +264,8 @@ async def update_user(
@router.delete("/{user_id}", status_code=204) @router.delete("/{user_id}", status_code=204)
async def delete_user( async def delete_user(
user_id: str, user_id: str,
request: Request request: Request,
user_service: UserServiceDep = Depends()
): ):
""" """
Delete user by ID Delete user by ID
@ -264,6 +274,7 @@ async def delete_user(
Args: Args:
user_id: The user ID to delete user_id: The user ID to delete
user_service: Injected user service
Returns: Returns:
None (204 No Content) None (204 No Content)

89
src/dependencies.py Normal file
View File

@ -0,0 +1,89 @@
"""
Dependency injection module for services.
This module provides dependency injection for all services used across the API.
It follows the dependency injection pattern to improve testability and maintainability.
"""
from functools import lru_cache
from typing import Annotated
from fastapi import Depends
from src.services.auth_service import AuthService
from src.services.image_service import ImageService
from src.services.search_service import SearchService
from src.services.team_service import TeamService
from src.services.user_service import UserService
@lru_cache()
def get_auth_service() -> AuthService:
"""
Get AuthService instance.
Uses LRU cache to ensure singleton behavior for the service instance.
Returns:
AuthService: The authentication service instance
"""
return AuthService()
@lru_cache()
def get_image_service() -> ImageService:
"""
Get ImageService instance.
Uses LRU cache to ensure singleton behavior for the service instance.
Returns:
ImageService: The image service instance
"""
return ImageService()
@lru_cache()
def get_search_service() -> SearchService:
"""
Get SearchService instance.
Uses LRU cache to ensure singleton behavior for the service instance.
Returns:
SearchService: The search service instance
"""
return SearchService()
@lru_cache()
def get_team_service() -> TeamService:
"""
Get TeamService instance.
Uses LRU cache to ensure singleton behavior for the service instance.
Returns:
TeamService: The team service instance
"""
return TeamService()
@lru_cache()
def get_user_service() -> UserService:
"""
Get UserService instance.
Uses LRU cache to ensure singleton behavior for the service instance.
Returns:
UserService: The user service instance
"""
return UserService()
# Type aliases for dependency injection
AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
ImageServiceDep = Annotated[ImageService, Depends(get_image_service)]
SearchServiceDep = Annotated[SearchService, Depends(get_search_service)]
TeamServiceDep = Annotated[TeamService, Depends(get_team_service)]
UserServiceDep = Annotated[UserService, Depends(get_user_service)]