fix errors because of pagination
This commit is contained in:
parent
80d8a74b12
commit
8cc7993201
@ -168,12 +168,14 @@ class FirestoreProvider:
|
|||||||
logger.error(f"Error getting document from {collection_name}: {e}")
|
logger.error(f"Error getting document from {collection_name}: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def list_documents(self, collection_name: str) -> List[Dict[str, Any]]:
|
async def list_documents(self, collection_name: str, skip: int = 0, limit: int = None) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
List all documents in a collection
|
List documents in a collection with optional pagination
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
collection_name: Collection name
|
collection_name: Collection name
|
||||||
|
skip: Number of documents to skip (default: 0)
|
||||||
|
limit: Maximum number of documents to return (default: None for all)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of documents
|
List of documents
|
||||||
@ -187,8 +189,15 @@ class FirestoreProvider:
|
|||||||
# Debug log to understand the client state
|
# Debug log to understand the client state
|
||||||
logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}")
|
logger.debug(f"Firestore client: {self.client}, Collection ref: {collection_ref}")
|
||||||
|
|
||||||
|
# Build query with pagination
|
||||||
|
query = collection_ref
|
||||||
|
if skip > 0:
|
||||||
|
query = query.offset(skip)
|
||||||
|
if limit is not None:
|
||||||
|
query = query.limit(limit)
|
||||||
|
|
||||||
# Properly get the stream of documents
|
# Properly get the stream of documents
|
||||||
docs = collection_ref.stream()
|
docs = query.stream()
|
||||||
results = []
|
results = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
data = doc.to_dict()
|
data = doc.to_dict()
|
||||||
@ -198,7 +207,13 @@ class FirestoreProvider:
|
|||||||
except Exception as stream_error:
|
except Exception as stream_error:
|
||||||
logger.error(f"Error streaming documents: {stream_error}")
|
logger.error(f"Error streaming documents: {stream_error}")
|
||||||
# Fallback method - try listing documents differently
|
# Fallback method - try listing documents differently
|
||||||
docs = list(collection_ref.get())
|
query = collection_ref
|
||||||
|
if skip > 0:
|
||||||
|
query = query.offset(skip)
|
||||||
|
if limit is not None:
|
||||||
|
query = query.limit(limit)
|
||||||
|
|
||||||
|
docs = list(query.get())
|
||||||
results = []
|
results = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
data = doc.to_dict()
|
data = doc.to_dict()
|
||||||
@ -210,6 +225,37 @@ class FirestoreProvider:
|
|||||||
# Return empty list instead of raising to avoid API failures
|
# Return empty list instead of raising to avoid API failures
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def count_documents(self, collection_name: str) -> int:
|
||||||
|
"""
|
||||||
|
Count total number of documents in a collection
|
||||||
|
|
||||||
|
Args:
|
||||||
|
collection_name: Collection name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of documents
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
collection_ref = self.get_collection(collection_name)
|
||||||
|
|
||||||
|
# Use aggregation query to count documents efficiently
|
||||||
|
# Note: This requires Firestore to support count aggregation
|
||||||
|
try:
|
||||||
|
from google.cloud.firestore_v1.aggregation import AggregationQuery
|
||||||
|
query = collection_ref.select([]) # Select no fields for efficiency
|
||||||
|
aggregation_query = AggregationQuery(query)
|
||||||
|
result = aggregation_query.count().get()
|
||||||
|
return result[0].value
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
# Fallback: count by getting all documents (less efficient)
|
||||||
|
logger.warning(f"Using fallback count method for {collection_name}")
|
||||||
|
docs = list(collection_ref.stream())
|
||||||
|
return len(docs)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error counting documents in {collection_name}: {e}")
|
||||||
|
# Return 0 instead of raising to avoid API failures
|
||||||
|
return 0
|
||||||
|
|
||||||
async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
|
async def update_document(self, collection_name: str, doc_id: str, data: Dict[str, Any]) -> bool:
|
||||||
"""
|
"""
|
||||||
Update a document
|
Update a document
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import List, Optional
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from src.db.repositories.firestore_repository import FirestoreRepository
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
from src.models.api_key import ApiKeyModel
|
from src.models.api_key import ApiKeyModel
|
||||||
@ -12,7 +13,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("api_keys", ApiKeyModel)
|
super().__init__("api_keys", ApiKeyModel)
|
||||||
|
|
||||||
async def get_by_key_hash(self, key_hash: str) -> ApiKeyModel:
|
async def get_by_key_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
|
||||||
"""
|
"""
|
||||||
Get API key by hash
|
Get API key by hash
|
||||||
|
|
||||||
@ -34,7 +35,7 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
|||||||
logger.error(f"Error getting API key by hash: {e}")
|
logger.error(f"Error getting API key by hash: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_by_user_id(self, user_id: str) -> list[ApiKeyModel]:
|
async def get_by_user_id(self, user_id: str) -> List[ApiKeyModel]:
|
||||||
"""
|
"""
|
||||||
Get API keys by user ID
|
Get API keys by user ID
|
||||||
|
|
||||||
@ -53,17 +54,53 @@ class FirestoreApiKeyRepository(FirestoreRepository[ApiKeyModel]):
|
|||||||
logger.error(f"Error getting API keys by user ID: {e}")
|
logger.error(f"Error getting API keys by user ID: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_by_user(self, user_id: ObjectId) -> list[ApiKeyModel]:
|
async def get_by_user(self, user_id: ObjectId, skip: int = 0, limit: int = None) -> List[ApiKeyModel]:
|
||||||
"""
|
"""
|
||||||
Get API keys by user (alias for get_by_user_id with ObjectId)
|
Get API keys by user with pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: User ID as ObjectId
|
||||||
|
skip: Number of records to skip for pagination (default: 0)
|
||||||
|
limit: Maximum number of records to return (default: None for all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of API keys
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# For now, we'll get all API keys and filter in memory
|
||||||
|
# In a production system, this should use Firestore queries for efficiency
|
||||||
|
api_keys = await self.get_all()
|
||||||
|
filtered_keys = [api_key for api_key in api_keys if api_key.user_id == user_id]
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
if skip > 0:
|
||||||
|
filtered_keys = filtered_keys[skip:]
|
||||||
|
if limit is not None:
|
||||||
|
filtered_keys = filtered_keys[:limit]
|
||||||
|
|
||||||
|
return filtered_keys
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting API keys by user with pagination: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def count_by_user(self, user_id: ObjectId) -> int:
|
||||||
|
"""
|
||||||
|
Count API keys by user ID
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id: User ID as ObjectId
|
user_id: User ID as ObjectId
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of API keys
|
Number of API keys for the user
|
||||||
"""
|
"""
|
||||||
return await self.get_by_user_id(str(user_id))
|
try:
|
||||||
|
# For now, we'll get all API keys and filter in memory
|
||||||
|
# In a production system, this should use Firestore count queries
|
||||||
|
api_keys = await self.get_all()
|
||||||
|
return len([api_key for api_key in api_keys if api_key.user_id == user_id])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error counting API keys by user: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def update_last_used(self, api_key_id: ObjectId) -> bool:
|
async def update_last_used(self, api_key_id: ObjectId) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -60,15 +60,19 @@ class FirestoreRepository(Generic[T]):
|
|||||||
logger.error(f"Error getting {self.collection_name} document by ID: {e}")
|
logger.error(f"Error getting {self.collection_name} document by ID: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_all(self) -> List[T]:
|
async def get_all(self, skip: int = 0, limit: int = None) -> List[T]:
|
||||||
"""
|
"""
|
||||||
Get all documents from the collection
|
Get all documents from the collection with optional pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip: Number of documents to skip (default: 0)
|
||||||
|
limit: Maximum number of documents to return (default: None for all)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of model instances
|
List of model instances
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
docs = await self.provider.list_documents(self.collection_name)
|
docs = await self.provider.list_documents(self.collection_name, skip=skip, limit=limit)
|
||||||
|
|
||||||
# Transform data to handle legacy format issues
|
# Transform data to handle legacy format issues
|
||||||
transformed_docs = []
|
transformed_docs = []
|
||||||
@ -159,4 +163,17 @@ class FirestoreRepository(Generic[T]):
|
|||||||
return await self.provider.delete_document(self.collection_name, str(doc_id))
|
return await self.provider.delete_document(self.collection_name, str(doc_id))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error deleting {self.collection_name} document: {e}")
|
logger.error(f"Error deleting {self.collection_name} document: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def count(self) -> int:
|
||||||
|
"""
|
||||||
|
Get total count of documents in the collection
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total number of documents
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await self.provider.count_documents(self.collection_name)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error counting {self.collection_name} documents: {e}")
|
||||||
raise
|
raise
|
||||||
@ -24,14 +24,18 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
|||||||
"""
|
"""
|
||||||
return await super().get_by_id(team_id)
|
return await super().get_by_id(team_id)
|
||||||
|
|
||||||
async def get_all(self) -> List[TeamModel]:
|
async def get_all(self, skip: int = 0, limit: int = None) -> List[TeamModel]:
|
||||||
"""
|
"""
|
||||||
Get all teams
|
Get all teams with pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip: Number of records to skip for pagination (default: 0)
|
||||||
|
limit: Maximum number of records to return (default: None for all)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of teams
|
List of teams
|
||||||
"""
|
"""
|
||||||
return await super().get_all()
|
return await super().get_all(skip=skip, limit=limit)
|
||||||
|
|
||||||
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
|
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
|
from typing import List, Optional
|
||||||
|
from bson import ObjectId
|
||||||
from src.db.repositories.firestore_repository import FirestoreRepository
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
from src.models.user import UserModel
|
from src.models.user import UserModel
|
||||||
|
|
||||||
@ -10,7 +12,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("users", UserModel)
|
super().__init__("users", UserModel)
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> UserModel:
|
async def get_by_email(self, email: str) -> Optional[UserModel]:
|
||||||
"""
|
"""
|
||||||
Get user by email
|
Get user by email
|
||||||
|
|
||||||
@ -32,7 +34,7 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
|||||||
logger.error(f"Error getting user by email: {e}")
|
logger.error(f"Error getting user by email: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_by_team_id(self, team_id: str) -> list[UserModel]:
|
async def get_by_team_id(self, team_id: str) -> List[UserModel]:
|
||||||
"""
|
"""
|
||||||
Get users by team ID
|
Get users by team ID
|
||||||
|
|
||||||
@ -51,5 +53,53 @@ class FirestoreUserRepository(FirestoreRepository[UserModel]):
|
|||||||
logger.error(f"Error getting users by team ID: {e}")
|
logger.error(f"Error getting users by team ID: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_by_team(self, team_id: ObjectId, skip: int = 0, limit: int = None) -> List[UserModel]:
|
||||||
|
"""
|
||||||
|
Get users by team ID with pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID as ObjectId
|
||||||
|
skip: Number of records to skip for pagination (default: 0)
|
||||||
|
limit: Maximum number of records to return (default: None for all)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of users
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# For now, we'll get all users and filter in memory
|
||||||
|
# In a production system, this should use Firestore queries for efficiency
|
||||||
|
users = await self.get_all()
|
||||||
|
filtered_users = [user for user in users if user.team_id == team_id]
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
if skip > 0:
|
||||||
|
filtered_users = filtered_users[skip:]
|
||||||
|
if limit is not None:
|
||||||
|
filtered_users = filtered_users[:limit]
|
||||||
|
|
||||||
|
return filtered_users
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting users by team with pagination: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def count_by_team(self, team_id: ObjectId) -> int:
|
||||||
|
"""
|
||||||
|
Count users by team ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Team ID as ObjectId
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Number of users in the team
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# For now, we'll get all users and filter in memory
|
||||||
|
# In a production system, this should use Firestore count queries
|
||||||
|
users = await self.get_all()
|
||||||
|
return len([user for user in users if user.team_id == team_id])
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error counting users by team: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
# Create a singleton repository
|
# Create a singleton repository
|
||||||
firestore_user_repository = FirestoreUserRepository()
|
firestore_user_repository = FirestoreUserRepository()
|
||||||
@ -170,18 +170,23 @@ class AuthService:
|
|||||||
is_active=created_key.is_active
|
is_active=created_key.is_active
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse:
|
async def list_user_api_keys(self, user: UserModel, skip: int = 0, limit: int = 50) -> ApiKeyListResponse:
|
||||||
"""
|
"""
|
||||||
List API keys for a specific user
|
List API keys for a specific user with pagination
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user: The user to list API keys for
|
user: The user to list API keys for
|
||||||
|
skip: Number of records to skip for pagination (default: 0)
|
||||||
|
limit: Maximum number of records to return (default: 50)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiKeyListResponse: List of API keys for the user
|
ApiKeyListResponse: Paginated list of API keys for the user
|
||||||
"""
|
"""
|
||||||
# Get API keys for user
|
# Get API keys for user with pagination
|
||||||
keys = await api_key_repository.get_by_user(user.id)
|
keys = await api_key_repository.get_by_user(user.id, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
# Get total count for pagination
|
||||||
|
total_count = await api_key_repository.count_by_user(user.id)
|
||||||
|
|
||||||
# Convert to response models
|
# Convert to response models
|
||||||
response_keys = []
|
response_keys = []
|
||||||
@ -198,7 +203,7 @@ class AuthService:
|
|||||||
is_active=key.is_active
|
is_active=key.is_active
|
||||||
))
|
))
|
||||||
|
|
||||||
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
return ApiKeyListResponse(api_keys=response_keys, total=total_count)
|
||||||
|
|
||||||
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
|
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -38,15 +38,22 @@ class TeamService:
|
|||||||
updated_at=created_team.updated_at
|
updated_at=created_team.updated_at
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_teams(self) -> TeamListResponse:
|
async def list_teams(self, skip: int = 0, limit: int = 50) -> TeamListResponse:
|
||||||
"""
|
"""
|
||||||
List all teams
|
List all teams with pagination
|
||||||
|
|
||||||
|
Args:
|
||||||
|
skip: Number of records to skip for pagination (default: 0)
|
||||||
|
limit: Maximum number of records to return (default: 50)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
TeamListResponse: List of all teams
|
TeamListResponse: Paginated list of teams
|
||||||
"""
|
"""
|
||||||
# Get all teams
|
# Get teams with pagination
|
||||||
teams = await team_repository.get_all()
|
teams = await team_repository.get_all(skip=skip, limit=limit)
|
||||||
|
|
||||||
|
# Get total count for pagination
|
||||||
|
total_count = await team_repository.count()
|
||||||
|
|
||||||
# Convert to response models
|
# Convert to response models
|
||||||
response_teams = []
|
response_teams = []
|
||||||
@ -59,7 +66,7 @@ class TeamService:
|
|||||||
updated_at=team.updated_at
|
updated_at=team.updated_at
|
||||||
))
|
))
|
||||||
|
|
||||||
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
return TeamListResponse(teams=response_teams, total=total_count)
|
||||||
|
|
||||||
async def get_team(self, team_id: str) -> TeamResponse:
|
async def get_team(self, team_id: str) -> TeamResponse:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -150,15 +150,17 @@ class UserService:
|
|||||||
updated_at=created_user.updated_at
|
updated_at=created_user.updated_at
|
||||||
)
|
)
|
||||||
|
|
||||||
async def list_users(self, team_id: Optional[str] = None) -> UserListResponse:
|
async def list_users(self, skip: int = 0, limit: int = 50, team_id: Optional[str] = None) -> UserListResponse:
|
||||||
"""
|
"""
|
||||||
List users, optionally filtered by team
|
List users with pagination, optionally filtered by team
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
skip: Number of records to skip for pagination (default: 0)
|
||||||
|
limit: Maximum number of records to return (default: 50)
|
||||||
team_id: Optional team ID to filter by
|
team_id: Optional team ID to filter by
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
UserListResponse: List of users
|
UserListResponse: Paginated list of users
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If team_id is invalid
|
ValueError: If team_id is invalid
|
||||||
@ -167,11 +169,13 @@ class UserService:
|
|||||||
if team_id:
|
if team_id:
|
||||||
try:
|
try:
|
||||||
filter_team_id = ObjectId(team_id)
|
filter_team_id = ObjectId(team_id)
|
||||||
users = await user_repository.get_by_team(filter_team_id)
|
users = await user_repository.get_by_team(filter_team_id, skip=skip, limit=limit)
|
||||||
|
total_count = await user_repository.count_by_team(filter_team_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Invalid team ID")
|
raise ValueError("Invalid team ID")
|
||||||
else:
|
else:
|
||||||
users = await user_repository.get_all()
|
users = await user_repository.get_all(skip=skip, limit=limit)
|
||||||
|
total_count = await user_repository.count()
|
||||||
|
|
||||||
# Convert to response
|
# Convert to response
|
||||||
response_users = []
|
response_users = []
|
||||||
@ -187,7 +191,7 @@ class UserService:
|
|||||||
updated_at=user.updated_at
|
updated_at=user.updated_at
|
||||||
))
|
))
|
||||||
|
|
||||||
return UserListResponse(users=response_users, total=len(response_users))
|
return UserListResponse(users=response_users, total=total_count)
|
||||||
|
|
||||||
async def get_user(self, user_id: str) -> UserResponse:
|
async def get_user(self, user_id: str) -> UserResponse:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user