From 8cc799320154d63ed51e96ba29b5379b40cbb636 Mon Sep 17 00:00:00 2001 From: johnpccd Date: Sun, 25 May 2025 22:28:41 +0200 Subject: [PATCH] fix errors because of pagination --- src/db/providers/firestore_provider.py | 54 +++++++++++++++++-- .../firestore_api_key_repository.py | 49 ++++++++++++++--- src/db/repositories/firestore_repository.py | 23 ++++++-- .../repositories/firestore_team_repository.py | 10 ++-- .../repositories/firestore_user_repository.py | 54 ++++++++++++++++++- src/services/auth_service.py | 17 +++--- src/services/team_service.py | 19 ++++--- src/services/user_service.py | 16 +++--- 8 files changed, 206 insertions(+), 36 deletions(-) diff --git a/src/db/providers/firestore_provider.py b/src/db/providers/firestore_provider.py index 29999e3..94f41dd 100644 --- a/src/db/providers/firestore_provider.py +++ b/src/db/providers/firestore_provider.py @@ -168,12 +168,14 @@ class FirestoreProvider: logger.error(f"Error getting document from {collection_name}: {e}") raise - async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]: + async def list_documents(self, collection_name: str, skip: int = 0, limit: int = None) -> List[Dict[str, Any]]: """ - List all documents in a collection + List documents in a collection with optional pagination Args: collection_name: Collection name + skip: Number of documents to skip (default: 0) + limit: Maximum number of documents to return (default: None for all) Returns: List of documents @@ -187,8 +189,15 @@ class FirestoreProvider: # Debug log to understand the client state logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}") + # Build query with pagination + query = collection_ref + if skip > 0: + query = query.offset(skip) + if limit is not None: + query = query.limit(limit) + # Properly get the stream of documents - docs = collection_ref.stream() + docs = query.stream() results = [] for doc in docs: data = doc.to_dict() @@ -198,7 +207,13 @@ class FirestoreProvider: except Exception as stream_error: logger.error(f"Error streaming documents: {stream_error}") # Fallback method - try listing documents differently - docs = list(collection_ref.get()) + query = collection_ref + if skip > 0: + query = query.offset(skip) + if limit is not None: + query = query.limit(limit) + + docs = list(query.get()) results = [] for doc in docs: data = doc.to_dict() @@ -210,6 +225,37 @@ class FirestoreProvider: # Return empty list instead of raising to avoid API failures return [] + async def count_documents(self, collection_name: str) -> int: + """ + Count total number of documents in a collection + + Args: + collection_name: Collection name + + Returns: + Total number of documents + """ + try: + collection_ref = self.get_collection(collection_name) + + # Use aggregation query to count documents efficiently + # Note: This requires Firestore to support count aggregation + try: + from google.cloud.firestore_v1.aggregation import AggregationQuery + query = collection_ref.select([]) # Select no fields for efficiency + aggregation_query = AggregationQuery(query) + result = aggregation_query.count().get() + return result[0].value + except (ImportError, AttributeError): + # Fallback: count by getting all documents (less efficient) + logger.warning(f"Using fallback count method for {collection_name}") + docs = list(collection_ref.stream()) + return len(docs) + except Exception as e: + logger.error(f"Error counting documents in {collection_name}: {e}") + # Return 0 instead of raising to avoid API failures + return 0 + async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool: """ Update a document diff --git a/src/db/repositories/firestore_api_key_repository.py b/src/db/repositories/firestore_api_key_repository.py index 4b5057a..ca193d0 100644 --- a/src/db/repositories/firestore_api_key_repository.py +++ b/src/db/repositories/firestore_api_key_repository.py @@ -1,5 +1,6 @@ import logging from datetime import datetime +from typing import List, Optional from bson import ObjectId from src.db.repositories.firestore_repository import FirestoreRepository from src.models.api_key import ApiKeyModel @@ -12,7 +13,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]): def __init__(self): super().__init__("api_keys", ApiKeyModel) - async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel: + async def get_by_key_hash(self, key_hash: str) -> Optional[ApiKeyModel]: """ Get API key by hash @@ -34,7 +35,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]): logger.error(f"Error getting API key by hash: {e}") raise - async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]: + async def get_by_user_id(self, user_id: str) -> List[ApiKeyModel]: """ Get API keys by user ID @@ -53,17 +54,53 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]): logger.error(f"Error getting API keys by user ID: {e}") raise - async def get_by_user(self, user_id: ObjectId) -> list[ApiKeyModel]: + async def get_by_user(self, user_id: ObjectId, skip: int = 0, limit: int = None) -> List[ApiKeyModel]: """ - Get API keys by user (alias for get_by_user_id with ObjectId) + Get API keys by user with pagination + + Args: + user_id: User ID as ObjectId + skip: Number of records to skip for pagination (default: 0) + limit: Maximum number of records to return (default: None for all) + + Returns: + List of API keys + """ + try: + # For now, we'll get all API keys and filter in memory + # In a production system, this should use Firestore queries for efficiency + api_keys = await self.get_all() + filtered_keys = [api_key for api_key in api_keys if api_key.user_id == user_id] + + # Apply pagination + if skip > 0: + filtered_keys = filtered_keys[skip:] + if limit is not None: + filtered_keys = filtered_keys[:limit] + + return filtered_keys + except Exception as e: + logger.error(f"Error getting API keys by user with pagination: {e}") + raise + + async def count_by_user(self, user_id: ObjectId) -> int: + """ + Count API keys by user ID Args: user_id: User ID as ObjectId Returns: - List of API keys + Number of API keys for the user """ - return await self.get_by_user_id(str(user_id)) + try: + # For now, we'll get all API keys and filter in memory + # In a production system, this should use Firestore count queries + api_keys = await self.get_all() + return len([api_key for api_key in api_keys if api_key.user_id == user_id]) + except Exception as e: + logger.error(f"Error counting API keys by user: {e}") + raise async def update_last_used(self, api_key_id: ObjectId) -> bool: """ diff --git a/src/db/repositories/firestore_repository.py b/src/db/repositories/firestore_repository.py index cc1c366..958316a 100644 --- a/src/db/repositories/firestore_repository.py +++ b/src/db/repositories/firestore_repository.py @@ -60,15 +60,19 @@ class FirestoreRepository(Generic[T]): logger.error(f"Error getting {self.collection_name} document by ID: {e}") raise - async def get_all(self) -> List[T]: + async def get_all(self, skip: int = 0, limit: int = None) -> List[T]: """ - Get all documents from the collection + Get all documents from the collection with optional pagination + + Args: + skip: Number of documents to skip (default: 0) + limit: Maximum number of documents to return (default: None for all) Returns: List of model instances """ try: - docs = await self.provider.list_documents(self.collection_name) + docs = await self.provider.list_documents(self.collection_name, skip=skip, limit=limit) # Transform data to handle legacy format issues transformed_docs = [] @@ -159,4 +163,17 @@ class FirestoreRepository(Generic[T]): return await self.provider.delete_document(self.collection_name, str(doc_id)) except Exception as e: logger.error(f"Error deleting {self.collection_name} document: {e}") + raise + + async def count(self) -> int: + """ + Get total count of documents in the collection + + Returns: + Total number of documents + """ + try: + return await self.provider.count_documents(self.collection_name) + except Exception as e: + logger.error(f"Error counting {self.collection_name} documents: {e}") raise \ No newline at end of file diff --git a/src/db/repositories/firestore_team_repository.py b/src/db/repositories/firestore_team_repository.py index fd3eb58..483071b 100644 --- a/src/db/repositories/firestore_team_repository.py +++ b/src/db/repositories/firestore_team_repository.py @@ -24,14 +24,18 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]): """ return await super().get_by_id(team_id) - async def get_all(self) -> List[TeamModel]: + async def get_all(self, skip: int = 0, limit: int = None) -> List[TeamModel]: """ - Get all teams + Get all teams with pagination + + Args: + skip: Number of records to skip for pagination (default: 0) + limit: Maximum number of records to return (default: None for all) Returns: List of teams """ - return await super().get_all() + return await super().get_all(skip=skip, limit=limit) async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]: """ diff --git a/src/db/repositories/firestore_user_repository.py b/src/db/repositories/firestore_user_repository.py index 3b6922a..5947e96 100644 --- a/src/db/repositories/firestore_user_repository.py +++ b/src/db/repositories/firestore_user_repository.py @@ -1,4 +1,6 @@ import logging +from typing import List, Optional +from bson import ObjectId from src.db.repositories.firestore_repository import FirestoreRepository from src.models.user import UserModel @@ -10,7 +12,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]): def __init__(self): super().__init__("users", UserModel) - async def get_by_email(self, email: str) -> UserModel: + async def get_by_email(self, email: str) -> Optional[UserModel]: """ Get user by email @@ -32,7 +34,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]): logger.error(f"Error getting user by email: {e}") raise - async def get_by_team_id(self, team_id: str) -> list[UserModel]: + async def get_by_team_id(self, team_id: str) -> List[UserModel]: """ Get users by team ID @@ -51,5 +53,53 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]): logger.error(f"Error getting users by team ID: {e}") raise + async def get_by_team(self, team_id: ObjectId, skip: int = 0, limit: int = None) -> List[UserModel]: + """ + Get users by team ID with pagination + + Args: + team_id: Team ID as ObjectId + skip: Number of records to skip for pagination (default: 0) + limit: Maximum number of records to return (default: None for all) + + Returns: + List of users + """ + try: + # For now, we'll get all users and filter in memory + # In a production system, this should use Firestore queries for efficiency + users = await self.get_all() + filtered_users = [user for user in users if user.team_id == team_id] + + # Apply pagination + if skip > 0: + filtered_users = filtered_users[skip:] + if limit is not None: + filtered_users = filtered_users[:limit] + + return filtered_users + except Exception as e: + logger.error(f"Error getting users by team with pagination: {e}") + raise + + async def count_by_team(self, team_id: ObjectId) -> int: + """ + Count users by team ID + + Args: + team_id: Team ID as ObjectId + + Returns: + Number of users in the team + """ + try: + # For now, we'll get all users and filter in memory + # In a production system, this should use Firestore count queries + users = await self.get_all() + return len([user for user in users if user.team_id == team_id]) + except Exception as e: + logger.error(f"Error counting users by team: {e}") + raise + # Create a singleton repository firestore_user_repository = FirestoreUserRepository() \ No newline at end of file diff --git a/src/services/auth_service.py b/src/services/auth_service.py index b8a465b..fc402f2 100644 --- a/src/services/auth_service.py +++ b/src/services/auth_service.py @@ -170,18 +170,23 @@ class AuthService: is_active=created_key.is_active ) - async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse: + async def list_user_api_keys(self, user: UserModel, skip: int = 0, limit: int = 50) -> ApiKeyListResponse: """ - List API keys for a specific user + List API keys for a specific user with pagination Args: user: The user to list API keys for + skip: Number of records to skip for pagination (default: 0) + limit: Maximum number of records to return (default: 50) Returns: - ApiKeyListResponse: List of API keys for the user + ApiKeyListResponse: Paginated list of API keys for the user """ - # Get API keys for user - keys = await api_key_repository.get_by_user(user.id) + # Get API keys for user with pagination + keys = await api_key_repository.get_by_user(user.id, skip=skip, limit=limit) + + # Get total count for pagination + total_count = await api_key_repository.count_by_user(user.id) # Convert to response models response_keys = [] @@ -198,7 +203,7 @@ class AuthService: is_active=key.is_active )) - return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys)) + return ApiKeyListResponse(api_keys=response_keys, total=total_count) async def revoke_api_key(self, key_id: str, user: UserModel) -> bool: """ diff --git a/src/services/team_service.py b/src/services/team_service.py index 9e301aa..6276fcf 100644 --- a/src/services/team_service.py +++ b/src/services/team_service.py @@ -38,15 +38,22 @@ class TeamService: updated_at=created_team.updated_at ) - async def list_teams(self) -> TeamListResponse: + async def list_teams(self, skip: int = 0, limit: int = 50) -> TeamListResponse: """ - List all teams + List all teams with pagination + + Args: + skip: Number of records to skip for pagination (default: 0) + limit: Maximum number of records to return (default: 50) Returns: - TeamListResponse: List of all teams + TeamListResponse: Paginated list of teams """ - # Get all teams - teams = await team_repository.get_all() + # Get teams with pagination + teams = await team_repository.get_all(skip=skip, limit=limit) + + # Get total count for pagination + total_count = await team_repository.count() # Convert to response models response_teams = [] @@ -59,7 +66,7 @@ class TeamService: updated_at=team.updated_at )) - return TeamListResponse(teams=response_teams, total=len(response_teams)) + return TeamListResponse(teams=response_teams, total=total_count) async def get_team(self, team_id: str) -> TeamResponse: """ diff --git a/src/services/user_service.py b/src/services/user_service.py index fb19b34..31e8620 100644 --- a/src/services/user_service.py +++ b/src/services/user_service.py @@ -150,15 +150,17 @@ class UserService: updated_at=created_user.updated_at ) - async def list_users(self, team_id: Optional[str] = None) -> UserListResponse: + async def list_users(self, skip: int = 0, limit: int = 50, team_id: Optional[str] = None) -> UserListResponse: """ - List users, optionally filtered by team + List users with pagination, optionally filtered by team Args: + skip: Number of records to skip for pagination (default: 0) + limit: Maximum number of records to return (default: 50) team_id: Optional team ID to filter by Returns: - UserListResponse: List of users + UserListResponse: Paginated list of users Raises: ValueError: If team_id is invalid @@ -167,11 +169,13 @@ class UserService: if team_id: try: filter_team_id = ObjectId(team_id) - users = await user_repository.get_by_team(filter_team_id) + users = await user_repository.get_by_team(filter_team_id, skip=skip, limit=limit) + total_count = await user_repository.count_by_team(filter_team_id) except Exception: raise ValueError("Invalid team ID") else: - users = await user_repository.get_all() + users = await user_repository.get_all(skip=skip, limit=limit) + total_count = await user_repository.count() # Convert to response response_users = [] @@ -187,7 +191,7 @@ class UserService: updated_at=user.updated_at )) - return UserListResponse(users=response_users, total=len(response_users)) + return UserListResponse(users=response_users, total=total_count) async def get_user(self, user_id: str) -> UserResponse: """