fix errors because of pagination

This commit is contained in:
johnpccd 2025-05-25 22:28:41 +02:00
parent 80d8a74b12
commit 8cc7993201
8 changed files with 206 additions and 36 deletions

View File

@ -168,12 +168,14 @@ class FirestoreProvider:
logger.error(f"Error getting document from {collection_name}: {e}") logger.error(f"Error getting document from {collection_name}: {e}")
raise raise
async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]: async def list_documents(self, collection_name: str, skip: int = 0, limit: int = None) -> List[Dict[str, Any]]:
""" """
List all documents in a collection List documents in a collection with optional pagination
Args: Args:
collection_name: Collection name collection_name: Collection name
skip: Number of documents to skip (default: 0)
limit: Maximum number of documents to return (default: None for all)
Returns: Returns:
List of documents List of documents
@ -187,8 +189,15 @@ class FirestoreProvider:
# Debug log to understand the client state # Debug log to understand the client state
logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}") logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}")
# Build query with pagination
query = collection_ref
if skip > 0:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
# Properly get the stream of documents # Properly get the stream of documents
docs = collection_ref.stream() docs = query.stream()
results = [] results = []
for doc in docs: for doc in docs:
data = doc.to_dict() data = doc.to_dict()
@ -198,7 +207,13 @@ class FirestoreProvider:
except Exception as stream_error: except Exception as stream_error:
logger.error(f"Error streaming documents: {stream_error}") logger.error(f"Error streaming documents: {stream_error}")
# Fallback method - try listing documents differently # Fallback method - try listing documents differently
docs = list(collection_ref.get()) query = collection_ref
if skip > 0:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
docs = list(query.get())
results = [] results = []
for doc in docs: for doc in docs:
data = doc.to_dict() data = doc.to_dict()
@ -210,6 +225,37 @@ class FirestoreProvider:
# Return empty list instead of raising to avoid API failures # Return empty list instead of raising to avoid API failures
return [] return []
async def count_documents(self, collection_name: str) -> int:
"""
Count total number of documents in a collection
Args:
collection_name: Collection name
Returns:
Total number of documents
"""
try:
collection_ref = self.get_collection(collection_name)
# Use aggregation query to count documents efficiently
# Note: This requires Firestore to support count aggregation
try:
from google.cloud.firestore_v1.aggregation import AggregationQuery
query = collection_ref.select([]) # Select no fields for efficiency
aggregation_query = AggregationQuery(query)
result = aggregation_query.count().get()
return result[0].value
except (ImportError, AttributeError):
# Fallback: count by getting all documents (less efficient)
logger.warning(f"Using fallback count method for {collection_name}")
docs = list(collection_ref.stream())
return len(docs)
except Exception as e:
logger.error(f"Error counting documents in {collection_name}: {e}")
# Return 0 instead of raising to avoid API failures
return 0
async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool: async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
""" """
Update a document Update a document

View File

@ -1,5 +1,6 @@
import logging import logging
from datetime import datetime from datetime import datetime
from typing import List, Optional
from bson import ObjectId from bson import ObjectId
from src.db.repositories.firestore_repository import FirestoreRepository from src.db.repositories.firestore_repository import FirestoreRepository
from src.models.api_key import ApiKeyModel from src.models.api_key import ApiKeyModel
@ -12,7 +13,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
def __init__(self): def __init__(self):
super().__init__("api_keys", ApiKeyModel) super().__init__("api_keys", ApiKeyModel)
async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel: async def get_by_key_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
""" """
Get API key by hash Get API key by hash
@ -34,7 +35,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
logger.error(f"Error getting API key by hash: {e}") logger.error(f"Error getting API key by hash: {e}")
raise raise
async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]: async def get_by_user_id(self, user_id: str) -> List[ApiKeyModel]:
""" """
Get API keys by user ID Get API keys by user ID
@ -53,17 +54,53 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
logger.error(f"Error getting API keys by user ID: {e}") logger.error(f"Error getting API keys by user ID: {e}")
raise raise
async def get_by_user(self, user_id: ObjectId) -> list[ApiKeyModel]: async def get_by_user(self, user_id: ObjectId, skip: int = 0, limit: int = None) -> List[ApiKeyModel]:
""" """
Get API keys by user (alias for get_by_user_id with ObjectId) Get API keys by user with pagination
Args:
user_id: User ID as ObjectId
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: None for all)
Returns:
List of API keys
"""
try:
# For now, we'll get all API keys and filter in memory
# In a production system, this should use Firestore queries for efficiency
api_keys = await self.get_all()
filtered_keys = [api_key for api_key in api_keys if api_key.user_id == user_id]
# Apply pagination
if skip > 0:
filtered_keys = filtered_keys[skip:]
if limit is not None:
filtered_keys = filtered_keys[:limit]
return filtered_keys
except Exception as e:
logger.error(f"Error getting API keys by user with pagination: {e}")
raise
async def count_by_user(self, user_id: ObjectId) -> int:
"""
Count API keys by user ID
Args: Args:
user_id: User ID as ObjectId user_id: User ID as ObjectId
Returns: Returns:
List of API keys Number of API keys for the user
""" """
return await self.get_by_user_id(str(user_id)) try:
# For now, we'll get all API keys and filter in memory
# In a production system, this should use Firestore count queries
api_keys = await self.get_all()
return len([api_key for api_key in api_keys if api_key.user_id == user_id])
except Exception as e:
logger.error(f"Error counting API keys by user: {e}")
raise
async def update_last_used(self, api_key_id: ObjectId) -> bool: async def update_last_used(self, api_key_id: ObjectId) -> bool:
""" """

View File

@ -60,15 +60,19 @@ class FirestoreRepository(Generic[T]):
logger.error(f"Error getting {self.collection_name} document by ID: {e}") logger.error(f"Error getting {self.collection_name} document by ID: {e}")
raise raise
async def get_all(self) -> List[T]: async def get_all(self, skip: int = 0, limit: int = None) -> List[T]:
""" """
Get all documents from the collection Get all documents from the collection with optional pagination
Args:
skip: Number of documents to skip (default: 0)
limit: Maximum number of documents to return (default: None for all)
Returns: Returns:
List of model instances List of model instances
""" """
try: try:
docs = await self.provider.list_documents(self.collection_name) docs = await self.provider.list_documents(self.collection_name, skip=skip, limit=limit)
# Transform data to handle legacy format issues # Transform data to handle legacy format issues
transformed_docs = [] transformed_docs = []
@ -160,3 +164,16 @@ class FirestoreRepository(Generic[T]):
except Exception as e: except Exception as e:
logger.error(f"Error deleting {self.collection_name} document: {e}") logger.error(f"Error deleting {self.collection_name} document: {e}")
raise raise
async def count(self) -> int:
"""
Get total count of documents in the collection
Returns:
Total number of documents
"""
try:
return await self.provider.count_documents(self.collection_name)
except Exception as e:
logger.error(f"Error counting {self.collection_name} documents: {e}")
raise

View File

@ -24,14 +24,18 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
""" """
return await super().get_by_id(team_id) return await super().get_by_id(team_id)
async def get_all(self) -> List[TeamModel]: async def get_all(self, skip: int = 0, limit: int = None) -> List[TeamModel]:
""" """
Get all teams Get all teams with pagination
Args:
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: None for all)
Returns: Returns:
List of teams List of teams
""" """
return await super().get_all() return await super().get_all(skip=skip, limit=limit)
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]: async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
""" """

View File

@ -1,4 +1,6 @@
import logging import logging
from typing import List, Optional
from bson import ObjectId
from src.db.repositories.firestore_repository import FirestoreRepository from src.db.repositories.firestore_repository import FirestoreRepository
from src.models.user import UserModel from src.models.user import UserModel
@ -10,7 +12,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
def __init__(self): def __init__(self):
super().__init__("users", UserModel) super().__init__("users", UserModel)
async def get_by_email(self, email: str) -> UserModel: async def get_by_email(self, email: str) -> Optional[UserModel]:
""" """
Get user by email Get user by email
@ -32,7 +34,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
logger.error(f"Error getting user by email: {e}") logger.error(f"Error getting user by email: {e}")
raise raise
async def get_by_team_id(self, team_id: str) -> list[UserModel]: async def get_by_team_id(self, team_id: str) -> List[UserModel]:
""" """
Get users by team ID Get users by team ID
@ -51,5 +53,53 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
logger.error(f"Error getting users by team ID: {e}") logger.error(f"Error getting users by team ID: {e}")
raise raise
async def get_by_team(self, team_id: ObjectId, skip: int = 0, limit: int = None) -> List[UserModel]:
"""
Get users by team ID with pagination
Args:
team_id: Team ID as ObjectId
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: None for all)
Returns:
List of users
"""
try:
# For now, we'll get all users and filter in memory
# In a production system, this should use Firestore queries for efficiency
users = await self.get_all()
filtered_users = [user for user in users if user.team_id == team_id]
# Apply pagination
if skip > 0:
filtered_users = filtered_users[skip:]
if limit is not None:
filtered_users = filtered_users[:limit]
return filtered_users
except Exception as e:
logger.error(f"Error getting users by team with pagination: {e}")
raise
async def count_by_team(self, team_id: ObjectId) -> int:
"""
Count users by team ID
Args:
team_id: Team ID as ObjectId
Returns:
Number of users in the team
"""
try:
# For now, we'll get all users and filter in memory
# In a production system, this should use Firestore count queries
users = await self.get_all()
return len([user for user in users if user.team_id == team_id])
except Exception as e:
logger.error(f"Error counting users by team: {e}")
raise
# Create a singleton repository # Create a singleton repository
firestore_user_repository = FirestoreUserRepository() firestore_user_repository = FirestoreUserRepository()

View File

@ -170,18 +170,23 @@ class AuthService:
is_active=created_key.is_active is_active=created_key.is_active
) )
async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse: async def list_user_api_keys(self, user: UserModel, skip: int = 0, limit: int = 50) -> ApiKeyListResponse:
""" """
List API keys for a specific user List API keys for a specific user with pagination
Args: Args:
user: The user to list API keys for user: The user to list API keys for
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: 50)
Returns: Returns:
ApiKeyListResponse: List of API keys for the user ApiKeyListResponse: Paginated list of API keys for the user
""" """
# Get API keys for user # Get API keys for user with pagination
keys = await api_key_repository.get_by_user(user.id) keys = await api_key_repository.get_by_user(user.id, skip=skip, limit=limit)
# Get total count for pagination
total_count = await api_key_repository.count_by_user(user.id)
# Convert to response models # Convert to response models
response_keys = [] response_keys = []
@ -198,7 +203,7 @@ class AuthService:
is_active=key.is_active is_active=key.is_active
)) ))
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys)) return ApiKeyListResponse(api_keys=response_keys, total=total_count)
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool: async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
""" """

View File

@ -38,15 +38,22 @@ class TeamService:
updated_at=created_team.updated_at updated_at=created_team.updated_at
) )
async def list_teams(self) -> TeamListResponse: async def list_teams(self, skip: int = 0, limit: int = 50) -> TeamListResponse:
""" """
List all teams List all teams with pagination
Args:
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: 50)
Returns: Returns:
TeamListResponse: List of all teams TeamListResponse: Paginated list of teams
""" """
# Get all teams # Get teams with pagination
teams = await team_repository.get_all() teams = await team_repository.get_all(skip=skip, limit=limit)
# Get total count for pagination
total_count = await team_repository.count()
# Convert to response models # Convert to response models
response_teams = [] response_teams = []
@ -59,7 +66,7 @@ class TeamService:
updated_at=team.updated_at updated_at=team.updated_at
)) ))
return TeamListResponse(teams=response_teams, total=len(response_teams)) return TeamListResponse(teams=response_teams, total=total_count)
async def get_team(self, team_id: str) -> TeamResponse: async def get_team(self, team_id: str) -> TeamResponse:
""" """

View File

@ -150,15 +150,17 @@ class UserService:
updated_at=created_user.updated_at updated_at=created_user.updated_at
) )
async def list_users(self, team_id: Optional[str] = None) -> UserListResponse: async def list_users(self, skip: int = 0, limit: int = 50, team_id: Optional[str] = None) -> UserListResponse:
""" """
List users, optionally filtered by team List users with pagination, optionally filtered by team
Args: Args:
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: 50)
team_id: Optional team ID to filter by team_id: Optional team ID to filter by
Returns: Returns:
UserListResponse: List of users UserListResponse: Paginated list of users
Raises: Raises:
ValueError: If team_id is invalid ValueError: If team_id is invalid
@ -167,11 +169,13 @@ class UserService:
if team_id: if team_id:
try: try:
filter_team_id = ObjectId(team_id) filter_team_id = ObjectId(team_id)
users = await user_repository.get_by_team(filter_team_id) users = await user_repository.get_by_team(filter_team_id, skip=skip, limit=limit)
total_count = await user_repository.count_by_team(filter_team_id)
except Exception: except Exception:
raise ValueError("Invalid team ID") raise ValueError("Invalid team ID")
else: else:
users = await user_repository.get_all() users = await user_repository.get_all(skip=skip, limit=limit)
total_count = await user_repository.count()
# Convert to response # Convert to response
response_users = [] response_users = []
@ -187,7 +191,7 @@ class UserService:
updated_at=user.updated_at updated_at=user.updated_at
)) ))
return UserListResponse(users=response_users, total=len(response_users)) return UserListResponse(users=response_users, total=total_count)
async def get_user(self, user_id: str) -> UserResponse: async def get_user(self, user_id: str) -> UserResponse:
""" """