fix errors because of pagination

This commit is contained in:
johnpccd 2025-05-25 22:28:41 +02:00
parent 80d8a74b12
commit 8cc7993201
8 changed files with 206 additions and 36 deletions

View File

@ -168,12 +168,14 @@ class FirestoreProvider:
logger.error(f"Error getting document from {collection_name}: {e}")
raise
async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]:
async def list_documents(self, collection_name: str, skip: int = 0, limit: int = None) -> List[Dict[str, Any]]:
"""
List all documents in a collection
List documents in a collection with optional pagination
Args:
collection_name: Collection name
skip: Number of documents to skip (default: 0)
limit: Maximum number of documents to return (default: None for all)
Returns:
List of documents
@ -187,8 +189,15 @@ class FirestoreProvider:
# Debug log to understand the client state
logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}")
# Build query with pagination
query = collection_ref
if skip > 0:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
# Properly get the stream of documents
docs = collection_ref.stream()
docs = query.stream()
results = []
for doc in docs:
data = doc.to_dict()
@ -198,7 +207,13 @@ class FirestoreProvider:
except Exception as stream_error:
logger.error(f"Error streaming documents: {stream_error}")
# Fallback method - try listing documents differently
docs = list(collection_ref.get())
query = collection_ref
if skip > 0:
query = query.offset(skip)
if limit is not None:
query = query.limit(limit)
docs = list(query.get())
results = []
for doc in docs:
data = doc.to_dict()
@ -210,6 +225,37 @@ class FirestoreProvider:
# Return empty list instead of raising to avoid API failures
return []
async def count_documents(self, collection_name: str) -> int:
"""
Count total number of documents in a collection
Args:
collection_name: Collection name
Returns:
Total number of documents
"""
try:
collection_ref = self.get_collection(collection_name)
# Use aggregation query to count documents efficiently
# Note: This requires Firestore to support count aggregation
try:
from google.cloud.firestore_v1.aggregation import AggregationQuery
query = collection_ref.select([]) # Select no fields for efficiency
aggregation_query = AggregationQuery(query)
result = aggregation_query.count().get()
return result[0].value
except (ImportError, AttributeError):
# Fallback: count by getting all documents (less efficient)
logger.warning(f"Using fallback count method for {collection_name}")
docs = list(collection_ref.stream())
return len(docs)
except Exception as e:
logger.error(f"Error counting documents in {collection_name}: {e}")
# Return 0 instead of raising to avoid API failures
return 0
async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
"""
Update a document

View File

@ -1,5 +1,6 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId
from src.db.repositories.firestore_repository import FirestoreRepository
from src.models.api_key import ApiKeyModel
@ -12,7 +13,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
def __init__(self):
super().__init__("api_keys", ApiKeyModel)
async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel:
async def get_by_key_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
"""
Get API key by hash
@ -34,7 +35,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
logger.error(f"Error getting API key by hash: {e}")
raise
async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]:
async def get_by_user_id(self, user_id: str) -> List[ApiKeyModel]:
"""
Get API keys by user ID
@ -53,17 +54,53 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
logger.error(f"Error getting API keys by user ID: {e}")
raise
async def get_by_user(self, user_id: ObjectId) -> list[ApiKeyModel]:
async def get_by_user(self, user_id: ObjectId, skip: int = 0, limit: int = None) -> List[ApiKeyModel]:
"""
Get API keys by user (alias for get_by_user_id with ObjectId)
Get API keys by user with pagination
Args:
user_id: User ID as ObjectId
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: None for all)
Returns:
List of API keys
"""
try:
# For now, we'll get all API keys and filter in memory
# In a production system, this should use Firestore queries for efficiency
api_keys = await self.get_all()
filtered_keys = [api_key for api_key in api_keys if api_key.user_id == user_id]
# Apply pagination
if skip > 0:
filtered_keys = filtered_keys[skip:]
if limit is not None:
filtered_keys = filtered_keys[:limit]
return filtered_keys
except Exception as e:
logger.error(f"Error getting API keys by user with pagination: {e}")
raise
async def count_by_user(self, user_id: ObjectId) -> int:
"""
Count API keys by user ID
Args:
user_id: User ID as ObjectId
Returns:
List of API keys
Number of API keys for the user
"""
return await self.get_by_user_id(str(user_id))
try:
# For now, we'll get all API keys and filter in memory
# In a production system, this should use Firestore count queries
api_keys = await self.get_all()
return len([api_key for api_key in api_keys if api_key.user_id == user_id])
except Exception as e:
logger.error(f"Error counting API keys by user: {e}")
raise
async def update_last_used(self, api_key_id: ObjectId) -> bool:
"""

View File

@ -60,15 +60,19 @@ class FirestoreRepository(Generic[T]):
logger.error(f"Error getting {self.collection_name} document by ID: {e}")
raise
async def get_all(self) -> List[T]:
async def get_all(self, skip: int = 0, limit: int = None) -> List[T]:
"""
Get all documents from the collection
Get all documents from the collection with optional pagination
Args:
skip: Number of documents to skip (default: 0)
limit: Maximum number of documents to return (default: None for all)
Returns:
List of model instances
"""
try:
docs = await self.provider.list_documents(self.collection_name)
docs = await self.provider.list_documents(self.collection_name, skip=skip, limit=limit)
# Transform data to handle legacy format issues
transformed_docs = []
@ -160,3 +164,16 @@ class FirestoreRepository(Generic[T]):
except Exception as e:
logger.error(f"Error deleting {self.collection_name} document: {e}")
raise
async def count(self) -> int:
"""
Get total count of documents in the collection
Returns:
Total number of documents
"""
try:
return await self.provider.count_documents(self.collection_name)
except Exception as e:
logger.error(f"Error counting {self.collection_name} documents: {e}")
raise

View File

@ -24,14 +24,18 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
"""
return await super().get_by_id(team_id)
async def get_all(self) -> List[TeamModel]:
async def get_all(self, skip: int = 0, limit: int = None) -> List[TeamModel]:
"""
Get all teams
Get all teams with pagination
Args:
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: None for all)
Returns:
List of teams
"""
return await super().get_all()
return await super().get_all(skip=skip, limit=limit)
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
"""

View File

@ -1,4 +1,6 @@
import logging
from typing import List, Optional
from bson import ObjectId
from src.db.repositories.firestore_repository import FirestoreRepository
from src.models.user import UserModel
@ -10,7 +12,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
def __init__(self):
super().__init__("users", UserModel)
async def get_by_email(self, email: str) -> UserModel:
async def get_by_email(self, email: str) -> Optional[UserModel]:
"""
Get user by email
@ -32,7 +34,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
logger.error(f"Error getting user by email: {e}")
raise
async def get_by_team_id(self, team_id: str) -> list[UserModel]:
async def get_by_team_id(self, team_id: str) -> List[UserModel]:
"""
Get users by team ID
@ -51,5 +53,53 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
logger.error(f"Error getting users by team ID: {e}")
raise
async def get_by_team(self, team_id: ObjectId, skip: int = 0, limit: int = None) -> List[UserModel]:
"""
Get users by team ID with pagination
Args:
team_id: Team ID as ObjectId
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: None for all)
Returns:
List of users
"""
try:
# For now, we'll get all users and filter in memory
# In a production system, this should use Firestore queries for efficiency
users = await self.get_all()
filtered_users = [user for user in users if user.team_id == team_id]
# Apply pagination
if skip > 0:
filtered_users = filtered_users[skip:]
if limit is not None:
filtered_users = filtered_users[:limit]
return filtered_users
except Exception as e:
logger.error(f"Error getting users by team with pagination: {e}")
raise
async def count_by_team(self, team_id: ObjectId) -> int:
"""
Count users by team ID
Args:
team_id: Team ID as ObjectId
Returns:
Number of users in the team
"""
try:
# For now, we'll get all users and filter in memory
# In a production system, this should use Firestore count queries
users = await self.get_all()
return len([user for user in users if user.team_id == team_id])
except Exception as e:
logger.error(f"Error counting users by team: {e}")
raise
# Create a singleton repository
firestore_user_repository = FirestoreUserRepository()

View File

@ -170,18 +170,23 @@ class AuthService:
is_active=created_key.is_active
)
async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse:
async def list_user_api_keys(self, user: UserModel, skip: int = 0, limit: int = 50) -> ApiKeyListResponse:
"""
List API keys for a specific user
List API keys for a specific user with pagination
Args:
user: The user to list API keys for
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: 50)
Returns:
ApiKeyListResponse: List of API keys for the user
ApiKeyListResponse: Paginated list of API keys for the user
"""
# Get API keys for user
keys = await api_key_repository.get_by_user(user.id)
# Get API keys for user with pagination
keys = await api_key_repository.get_by_user(user.id, skip=skip, limit=limit)
# Get total count for pagination
total_count = await api_key_repository.count_by_user(user.id)
# Convert to response models
response_keys = []
@ -198,7 +203,7 @@ class AuthService:
is_active=key.is_active
))
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
return ApiKeyListResponse(api_keys=response_keys, total=total_count)
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
"""

View File

@ -38,15 +38,22 @@ class TeamService:
updated_at=created_team.updated_at
)
async def list_teams(self) -> TeamListResponse:
async def list_teams(self, skip: int = 0, limit: int = 50) -> TeamListResponse:
"""
List all teams
List all teams with pagination
Args:
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: 50)
Returns:
TeamListResponse: List of all teams
TeamListResponse: Paginated list of teams
"""
# Get all teams
teams = await team_repository.get_all()
# Get teams with pagination
teams = await team_repository.get_all(skip=skip, limit=limit)
# Get total count for pagination
total_count = await team_repository.count()
# Convert to response models
response_teams = []
@ -59,7 +66,7 @@ class TeamService:
updated_at=team.updated_at
))
return TeamListResponse(teams=response_teams, total=len(response_teams))
return TeamListResponse(teams=response_teams, total=total_count)
async def get_team(self, team_id: str) -> TeamResponse:
"""

View File

@ -150,15 +150,17 @@ class UserService:
updated_at=created_user.updated_at
)
async def list_users(self, team_id: Optional[str] = None) -> UserListResponse:
async def list_users(self, skip: int = 0, limit: int = 50, team_id: Optional[str] = None) -> UserListResponse:
"""
List users, optionally filtered by team
List users with pagination, optionally filtered by team
Args:
skip: Number of records to skip for pagination (default: 0)
limit: Maximum number of records to return (default: 50)
team_id: Optional team ID to filter by
Returns:
UserListResponse: List of users
UserListResponse: Paginated list of users
Raises:
ValueError: If team_id is invalid
@ -167,11 +169,13 @@ class UserService:
if team_id:
try:
filter_team_id = ObjectId(team_id)
users = await user_repository.get_by_team(filter_team_id)
users = await user_repository.get_by_team(filter_team_id, skip=skip, limit=limit)
total_count = await user_repository.count_by_team(filter_team_id)
except Exception:
raise ValueError("Invalid team ID")
else:
users = await user_repository.get_all()
users = await user_repository.get_all(skip=skip, limit=limit)
total_count = await user_repository.count()
# Convert to response
response_users = []
@ -187,7 +191,7 @@ class UserService:
updated_at=user.updated_at
))
return UserListResponse(users=response_users, total=len(response_users))
return UserListResponse(users=response_users, total=total_count)
async def get_user(self, user_id: str) -> UserResponse:
"""