fix errors because of pagination
This commit is contained in:
parent
80d8a74b12
commit
8cc7993201
@ -168,12 +168,14 @@ class FirestoreProvider:
|
||||
logger.error(f"Error getting document from {collection_name}: {e}")
|
||||
raise
|
||||
|
||||
async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]:
|
||||
async def list_documents(self, collection_name: str, skip: int = 0, limit: int = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
List all documents in a collection
|
||||
List documents in a collection with optional pagination
|
||||
|
||||
Args:
|
||||
collection_name: Collection name
|
||||
skip: Number of documents to skip (default: 0)
|
||||
limit: Maximum number of documents to return (default: None for all)
|
||||
|
||||
Returns:
|
||||
List of documents
|
||||
@ -187,8 +189,15 @@ class FirestoreProvider:
|
||||
# Debug log to understand the client state
|
||||
logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}")
|
||||
|
||||
# Build query with pagination
|
||||
query = collection_ref
|
||||
if skip > 0:
|
||||
query = query.offset(skip)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
# Properly get the stream of documents
|
||||
docs = collection_ref.stream()
|
||||
docs = query.stream()
|
||||
results = []
|
||||
for doc in docs:
|
||||
data = doc.to_dict()
|
||||
@ -198,7 +207,13 @@ class FirestoreProvider:
|
||||
except Exception as stream_error:
|
||||
logger.error(f"Error streaming documents: {stream_error}")
|
||||
# Fallback method - try listing documents differently
|
||||
docs = list(collection_ref.get())
|
||||
query = collection_ref
|
||||
if skip > 0:
|
||||
query = query.offset(skip)
|
||||
if limit is not None:
|
||||
query = query.limit(limit)
|
||||
|
||||
docs = list(query.get())
|
||||
results = []
|
||||
for doc in docs:
|
||||
data = doc.to_dict()
|
||||
@ -210,6 +225,37 @@ class FirestoreProvider:
|
||||
# Return empty list instead of raising to avoid API failures
|
||||
return []
|
||||
|
||||
async def count_documents(self, collection_name: str) -> int:
|
||||
"""
|
||||
Count total number of documents in a collection
|
||||
|
||||
Args:
|
||||
collection_name: Collection name
|
||||
|
||||
Returns:
|
||||
Total number of documents
|
||||
"""
|
||||
try:
|
||||
collection_ref = self.get_collection(collection_name)
|
||||
|
||||
# Use aggregation query to count documents efficiently
|
||||
# Note: This requires Firestore to support count aggregation
|
||||
try:
|
||||
from google.cloud.firestore_v1.aggregation import AggregationQuery
|
||||
query = collection_ref.select([]) # Select no fields for efficiency
|
||||
aggregation_query = AggregationQuery(query)
|
||||
result = aggregation_query.count().get()
|
||||
return result[0].value
|
||||
except (ImportError, AttributeError):
|
||||
# Fallback: count by getting all documents (less efficient)
|
||||
logger.warning(f"Using fallback count method for {collection_name}")
|
||||
docs = list(collection_ref.stream())
|
||||
return len(docs)
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting documents in {collection_name}: {e}")
|
||||
# Return 0 instead of raising to avoid API failures
|
||||
return 0
|
||||
|
||||
async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Update a document
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||
from src.models.api_key import ApiKeyModel
|
||||
@ -12,7 +13,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
||||
def __init__(self):
|
||||
super().__init__("api_keys", ApiKeyModel)
|
||||
|
||||
async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel:
|
||||
async def get_by_key_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
|
||||
"""
|
||||
Get API key by hash
|
||||
|
||||
@ -34,7 +35,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
||||
logger.error(f"Error getting API key by hash: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]:
|
||||
async def get_by_user_id(self, user_id: str) -> List[ApiKeyModel]:
|
||||
"""
|
||||
Get API keys by user ID
|
||||
|
||||
@ -53,17 +54,53 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
||||
logger.error(f"Error getting API keys by user ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_user(self, user_id: ObjectId) -> list[ApiKeyModel]:
|
||||
async def get_by_user(self, user_id: ObjectId, skip: int = 0, limit: int = None) -> List[ApiKeyModel]:
|
||||
"""
|
||||
Get API keys by user (alias for get_by_user_id with ObjectId)
|
||||
Get API keys by user with pagination
|
||||
|
||||
Args:
|
||||
user_id: User ID as ObjectId
|
||||
skip: Number of records to skip for pagination (default: 0)
|
||||
limit: Maximum number of records to return (default: None for all)
|
||||
|
||||
Returns:
|
||||
List of API keys
|
||||
"""
|
||||
try:
|
||||
# For now, we'll get all API keys and filter in memory
|
||||
# In a production system, this should use Firestore queries for efficiency
|
||||
api_keys = await self.get_all()
|
||||
filtered_keys = [api_key for api_key in api_keys if api_key.user_id == user_id]
|
||||
|
||||
# Apply pagination
|
||||
if skip > 0:
|
||||
filtered_keys = filtered_keys[skip:]
|
||||
if limit is not None:
|
||||
filtered_keys = filtered_keys[:limit]
|
||||
|
||||
return filtered_keys
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API keys by user with pagination: {e}")
|
||||
raise
|
||||
|
||||
async def count_by_user(self, user_id: ObjectId) -> int:
|
||||
"""
|
||||
Count API keys by user ID
|
||||
|
||||
Args:
|
||||
user_id: User ID as ObjectId
|
||||
|
||||
Returns:
|
||||
List of API keys
|
||||
Number of API keys for the user
|
||||
"""
|
||||
return await self.get_by_user_id(str(user_id))
|
||||
try:
|
||||
# For now, we'll get all API keys and filter in memory
|
||||
# In a production system, this should use Firestore count queries
|
||||
api_keys = await self.get_all()
|
||||
return len([api_key for api_key in api_keys if api_key.user_id == user_id])
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting API keys by user: {e}")
|
||||
raise
|
||||
|
||||
async def update_last_used(self, api_key_id: ObjectId) -> bool:
|
||||
"""
|
||||
|
||||
@ -60,15 +60,19 @@ class FirestoreRepository(Generic[T]):
|
||||
logger.error(f"Error getting {self.collection_name} document by ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_all(self) -> List[T]:
|
||||
async def get_all(self, skip: int = 0, limit: int = None) -> List[T]:
|
||||
"""
|
||||
Get all documents from the collection
|
||||
Get all documents from the collection with optional pagination
|
||||
|
||||
Args:
|
||||
skip: Number of documents to skip (default: 0)
|
||||
limit: Maximum number of documents to return (default: None for all)
|
||||
|
||||
Returns:
|
||||
List of model instances
|
||||
"""
|
||||
try:
|
||||
docs = await self.provider.list_documents(self.collection_name)
|
||||
docs = await self.provider.list_documents(self.collection_name, skip=skip, limit=limit)
|
||||
|
||||
# Transform data to handle legacy format issues
|
||||
transformed_docs = []
|
||||
@ -160,3 +164,16 @@ class FirestoreRepository(Generic[T]):
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting {self.collection_name} document: {e}")
|
||||
raise
|
||||
|
||||
async def count(self) -> int:
|
||||
"""
|
||||
Get total count of documents in the collection
|
||||
|
||||
Returns:
|
||||
Total number of documents
|
||||
"""
|
||||
try:
|
||||
return await self.provider.count_documents(self.collection_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting {self.collection_name} documents: {e}")
|
||||
raise
|
||||
@ -24,14 +24,18 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
||||
"""
|
||||
return await super().get_by_id(team_id)
|
||||
|
||||
async def get_all(self) -> List[TeamModel]:
|
||||
async def get_all(self, skip: int = 0, limit: int = None) -> List[TeamModel]:
|
||||
"""
|
||||
Get all teams
|
||||
Get all teams with pagination
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip for pagination (default: 0)
|
||||
limit: Maximum number of records to return (default: None for all)
|
||||
|
||||
Returns:
|
||||
List of teams
|
||||
"""
|
||||
return await super().get_all()
|
||||
return await super().get_all(skip=skip, limit=limit)
|
||||
|
||||
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||
from src.models.user import UserModel
|
||||
|
||||
@ -10,7 +12,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
||||
def __init__(self):
|
||||
super().__init__("users", UserModel)
|
||||
|
||||
async def get_by_email(self, email: str) -> UserModel:
|
||||
async def get_by_email(self, email: str) -> Optional[UserModel]:
|
||||
"""
|
||||
Get user by email
|
||||
|
||||
@ -32,7 +34,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
||||
logger.error(f"Error getting user by email: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_team_id(self, team_id: str) -> list[UserModel]:
|
||||
async def get_by_team_id(self, team_id: str) -> List[UserModel]:
|
||||
"""
|
||||
Get users by team ID
|
||||
|
||||
@ -51,5 +53,53 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
||||
logger.error(f"Error getting users by team ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_team(self, team_id: ObjectId, skip: int = 0, limit: int = None) -> List[UserModel]:
|
||||
"""
|
||||
Get users by team ID with pagination
|
||||
|
||||
Args:
|
||||
team_id: Team ID as ObjectId
|
||||
skip: Number of records to skip for pagination (default: 0)
|
||||
limit: Maximum number of records to return (default: None for all)
|
||||
|
||||
Returns:
|
||||
List of users
|
||||
"""
|
||||
try:
|
||||
# For now, we'll get all users and filter in memory
|
||||
# In a production system, this should use Firestore queries for efficiency
|
||||
users = await self.get_all()
|
||||
filtered_users = [user for user in users if user.team_id == team_id]
|
||||
|
||||
# Apply pagination
|
||||
if skip > 0:
|
||||
filtered_users = filtered_users[skip:]
|
||||
if limit is not None:
|
||||
filtered_users = filtered_users[:limit]
|
||||
|
||||
return filtered_users
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting users by team with pagination: {e}")
|
||||
raise
|
||||
|
||||
async def count_by_team(self, team_id: ObjectId) -> int:
|
||||
"""
|
||||
Count users by team ID
|
||||
|
||||
Args:
|
||||
team_id: Team ID as ObjectId
|
||||
|
||||
Returns:
|
||||
Number of users in the team
|
||||
"""
|
||||
try:
|
||||
# For now, we'll get all users and filter in memory
|
||||
# In a production system, this should use Firestore count queries
|
||||
users = await self.get_all()
|
||||
return len([user for user in users if user.team_id == team_id])
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting users by team: {e}")
|
||||
raise
|
||||
|
||||
# Create a singleton repository
|
||||
firestore_user_repository = FirestoreUserRepository()
|
||||
@ -170,18 +170,23 @@ class AuthService:
|
||||
is_active=created_key.is_active
|
||||
)
|
||||
|
||||
async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse:
|
||||
async def list_user_api_keys(self, user: UserModel, skip: int = 0, limit: int = 50) -> ApiKeyListResponse:
|
||||
"""
|
||||
List API keys for a specific user
|
||||
List API keys for a specific user with pagination
|
||||
|
||||
Args:
|
||||
user: The user to list API keys for
|
||||
skip: Number of records to skip for pagination (default: 0)
|
||||
limit: Maximum number of records to return (default: 50)
|
||||
|
||||
Returns:
|
||||
ApiKeyListResponse: List of API keys for the user
|
||||
ApiKeyListResponse: Paginated list of API keys for the user
|
||||
"""
|
||||
# Get API keys for user
|
||||
keys = await api_key_repository.get_by_user(user.id)
|
||||
# Get API keys for user with pagination
|
||||
keys = await api_key_repository.get_by_user(user.id, skip=skip, limit=limit)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = await api_key_repository.count_by_user(user.id)
|
||||
|
||||
# Convert to response models
|
||||
response_keys = []
|
||||
@ -198,7 +203,7 @@ class AuthService:
|
||||
is_active=key.is_active
|
||||
))
|
||||
|
||||
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
||||
return ApiKeyListResponse(api_keys=response_keys, total=total_count)
|
||||
|
||||
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
|
||||
"""
|
||||
|
||||
@ -38,15 +38,22 @@ class TeamService:
|
||||
updated_at=created_team.updated_at
|
||||
)
|
||||
|
||||
async def list_teams(self) -> TeamListResponse:
|
||||
async def list_teams(self, skip: int = 0, limit: int = 50) -> TeamListResponse:
|
||||
"""
|
||||
List all teams
|
||||
List all teams with pagination
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip for pagination (default: 0)
|
||||
limit: Maximum number of records to return (default: 50)
|
||||
|
||||
Returns:
|
||||
TeamListResponse: List of all teams
|
||||
TeamListResponse: Paginated list of teams
|
||||
"""
|
||||
# Get all teams
|
||||
teams = await team_repository.get_all()
|
||||
# Get teams with pagination
|
||||
teams = await team_repository.get_all(skip=skip, limit=limit)
|
||||
|
||||
# Get total count for pagination
|
||||
total_count = await team_repository.count()
|
||||
|
||||
# Convert to response models
|
||||
response_teams = []
|
||||
@ -59,7 +66,7 @@ class TeamService:
|
||||
updated_at=team.updated_at
|
||||
))
|
||||
|
||||
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
||||
return TeamListResponse(teams=response_teams, total=total_count)
|
||||
|
||||
async def get_team(self, team_id: str) -> TeamResponse:
|
||||
"""
|
||||
|
||||
@ -150,15 +150,17 @@ class UserService:
|
||||
updated_at=created_user.updated_at
|
||||
)
|
||||
|
||||
async def list_users(self, team_id: Optional[str] = None) -> UserListResponse:
|
||||
async def list_users(self, skip: int = 0, limit: int = 50, team_id: Optional[str] = None) -> UserListResponse:
|
||||
"""
|
||||
List users, optionally filtered by team
|
||||
List users with pagination, optionally filtered by team
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip for pagination (default: 0)
|
||||
limit: Maximum number of records to return (default: 50)
|
||||
team_id: Optional team ID to filter by
|
||||
|
||||
Returns:
|
||||
UserListResponse: List of users
|
||||
UserListResponse: Paginated list of users
|
||||
|
||||
Raises:
|
||||
ValueError: If team_id is invalid
|
||||
@ -167,11 +169,13 @@ class UserService:
|
||||
if team_id:
|
||||
try:
|
||||
filter_team_id = ObjectId(team_id)
|
||||
users = await user_repository.get_by_team(filter_team_id)
|
||||
users = await user_repository.get_by_team(filter_team_id, skip=skip, limit=limit)
|
||||
total_count = await user_repository.count_by_team(filter_team_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid team ID")
|
||||
else:
|
||||
users = await user_repository.get_all()
|
||||
users = await user_repository.get_all(skip=skip, limit=limit)
|
||||
total_count = await user_repository.count()
|
||||
|
||||
# Convert to response
|
||||
response_users = []
|
||||
@ -187,7 +191,7 @@ class UserService:
|
||||
updated_at=user.updated_at
|
||||
))
|
||||
|
||||
return UserListResponse(users=response_users, total=len(response_users))
|
||||
return UserListResponse(users=response_users, total=total_count)
|
||||
|
||||
async def get_user(self, user_id: str) -> UserResponse:
|
||||
"""
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user