diff --git a/README.md b/README.md index 622895c..e71dc1a 100644 --- a/README.md +++ b/README.md @@ -28,7 +28,12 @@ sereact/ │ │ └── v1/ # API version 1 routes │ ├── auth/ # Authentication and authorization │ ├── config/ # Configuration management + │ ├── core/ # Core application logic + │ ├── db/ # Database layer + │ │ ├── providers/ # Database providers (Firestore) + │ │ └── repositories/ # Data access repositories │ ├── models/ # Database models + │ ├── schemas/ # API request/response schemas │ ├── services/ # Business logic services │ └── utils/ # Utility functions ├── tests/ # Test code diff --git a/src/db/__init__.py b/src/db/__init__.py index 3d19fa9..f373381 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -1,100 +1,44 @@ -import logging -import os -import json -from google.cloud import firestore -from google.oauth2 import service_account -from src.config.config import settings - -logger = logging.getLogger(__name__) - -class Database: - client = None - database_name = None - - def connect_to_database(self): - """Create database connection.""" - try: - # Store database name - self.database_name = settings.FIRESTORE_DATABASE_NAME - - # Print project ID and database name for debugging - logger.info(f"Attempting to connect to Firestore with project ID: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}") - - # First try using Application Default Credentials (for Cloud environments) - try: - logger.info("Attempting to connect using Application Default Credentials") - self.client = firestore.Client( - project=settings.FIRESTORE_PROJECT_ID, - database=self.database_name - ) - # Test connection by trying to access a collection - self.client.collection('test').limit(1).get() - logger.info(f"Connected to Firestore project using Application Default Credentials: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}") - return - except Exception as adc_error: - logger.error(f"Application Default Credentials failed: {adc_error}", exc_info=True) - - # Fall back to service account file - credentials_path = settings.FIRESTORE_CREDENTIALS_FILE - if not os.path.exists(credentials_path): - logger.error(f"Firestore credentials file not found: {credentials_path}") - raise FileNotFoundError(f"Credentials file not found: {credentials_path}") - - # Print key file contents (without sensitive parts) for debugging - try: - with open(credentials_path, 'r') as f: - key_data = json.load(f) - # Log non-sensitive parts of the key - logger.info(f"Using credentials file with project_id: {key_data.get('project_id')}") - logger.info(f"Client email: {key_data.get('client_email')}") - logger.info(f"Key file type: {key_data.get('type')}") - except Exception as e: - logger.error(f"Error reading key file: {e}") - - # Load credentials - credentials = service_account.Credentials.from_service_account_file(credentials_path) - - # Initialize Firestore client - self.client = firestore.Client( - project=settings.FIRESTORE_PROJECT_ID, - database=self.database_name, - credentials=credentials - ) - - # Test connection by trying to access a collection - self.client.collection('test').limit(1).get() - - logger.info(f"Connected to Firestore project using credentials file: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}") - except Exception as e: - logger.error(f"Failed to connect to Firestore: {e}", exc_info=True) - raise - - def close_database_connection(self): - """Close database connection.""" - try: - # No explicit close method needed for Firestore client - self.client = None - logger.info("Closed Firestore connection") - except Exception as e: - logger.error(f"Failed to close Firestore connection: {e}") - - def get_database(self): - """Get the database instance.""" - if self.client is None: - logger.warning("Database client is None. Attempting to reconnect...") - try: - self.connect_to_database() - except Exception as e: - logger.error(f"Failed to reconnect to database: {e}") - # Return None to avoid further errors but log this issue - return None - - # Verify that client is properly initialized - if self.client is None: - logger.error("Database client is still None after reconnect attempt") - return None - - return self.client - -# Create a singleton database instance -db = Database() +import logging +from src.config.config import settings +from src.db.providers.firestore_provider import FirestoreProvider + +logger = logging.getLogger(__name__) + +class Database: + """Database connection manager for Firestore""" + + def __init__(self): + self.provider = FirestoreProvider() + + def connect_to_database(self): + """Create database connection.""" + try: + logger.info(f"Connecting to Firestore database: {settings.FIRESTORE_DATABASE_NAME}") + self.provider.connect() + logger.info(f"Successfully connected to Firestore database: {settings.FIRESTORE_DATABASE_NAME}") + except Exception as e: + logger.error(f"Failed to connect to Firestore: {e}", exc_info=True) + raise + + def close_database_connection(self): + """Close database connection.""" + try: + self.provider.disconnect() + logger.info("Closed Firestore connection") + except Exception as e: + logger.error(f"Failed to close Firestore connection: {e}") + + def get_database(self): + """Get the database instance.""" + if self.provider.client is None: + logger.warning("Database client is None. Attempting to reconnect...") + try: + self.connect_to_database() + except Exception as e: + logger.error(f"Failed to reconnect to database: {e}") + return None + + return self.provider.client + +# Create a singleton database instance +db = Database() diff --git a/src/db/providers/__init__.py b/src/db/providers/__init__.py new file mode 100644 index 0000000..c63905e --- /dev/null +++ b/src/db/providers/__init__.py @@ -0,0 +1,5 @@ +"""Database providers package""" + +from .firestore_provider import FirestoreProvider, firestore_db + +__all__ = ["FirestoreProvider", "firestore_db"] \ No newline at end of file diff --git a/src/db/providers/firestore_provider.py b/src/db/providers/firestore_provider.py index 53b5210..29999e3 100644 --- a/src/db/providers/firestore_provider.py +++ b/src/db/providers/firestore_provider.py @@ -99,7 +99,7 @@ class FirestoreProvider: try: collection = self.get_collection(collection_name) - # Handle ObjectId conversion for Firestore + # Process data for Firestore processed_data = {} for key, value in data.items(): if value is None: diff --git a/src/db/repositories/__init__.py b/src/db/repositories/__init__.py index d17da79..0d86619 100644 --- a/src/db/repositories/__init__.py +++ b/src/db/repositories/__init__.py @@ -3,36 +3,19 @@ from src.config.config import settings logger = logging.getLogger(__name__) -# Determine which repository implementation to use -# Default to Firestore for now -REPOSITORY_TYPE = "firestore" - def init_repositories(): - """Initialize the repository implementations based on configuration.""" - global REPOSITORY_TYPE + """Initialize the Firestore repository implementations.""" + logger.info("Initializing Firestore repositories") - logger.info(f"Initializing repositories with type: {REPOSITORY_TYPE}") + # Import and initialize Firestore repositories + from src.db.providers.firestore_provider import FirestoreProvider - if REPOSITORY_TYPE == "firestore": - # Import and initialize Firestore repositories - from src.db.repositories.firestore_team_repository import firestore_team_repository - from src.db.providers.firestore_provider import firestore_db + # Initialize the Firestore provider + firestore_provider = FirestoreProvider() + if not firestore_provider.client: + firestore_provider.connect() - # Make sure Firestore client is initialized - if not firestore_db.client: - firestore_db.connect() - - # Replace the default repositories with Firestore implementations - from src.db.repositories.team_repository import team_repository as default_team_repo - - # Dynamically update the module to use Firestore repositories - import sys - sys.modules["src.db.repositories.team_repository"].team_repository = firestore_team_repository - - logger.info("Firestore repositories initialized") - else: - # Default MongoDB repositories are already imported - logger.info("Using default MongoDB repositories") + logger.info("Firestore repositories initialized successfully") # Initialize repositories at module import time init_repositories() diff --git a/src/db/repositories/api_key_repository.py b/src/db/repositories/api_key_repository.py deleted file mode 100644 index 7373c77..0000000 --- a/src/db/repositories/api_key_repository.py +++ /dev/null @@ -1,199 +0,0 @@ -import logging -from datetime import datetime -from typing import List, Optional -from bson import ObjectId - -from src.db import db -from src.models.api_key import ApiKeyModel - -logger = logging.getLogger(__name__) - -class ApiKeyRepository: - """Repository for API key operations""" - - collection_name = "api_keys" - - @property - def collection(self): - client = db.get_database() - if client is None: - logger.error("Database connection is None, cannot access collection") - raise RuntimeError("Database connection is not available") - # Use the Firestore client to get the collection - return client.collection(self.collection_name) - - async def create(self, api_key: ApiKeyModel) -> ApiKeyModel: - """ - Create a new API key - - Args: - api_key: API key data - - Returns: - Created API key with ID - """ - try: - # Convert to dict for Firestore - api_key_dict = api_key.dict(by_alias=True) - - # Create a new document with an auto-generated ID - doc_ref = self.collection.document() - - # Set the ID in the model and data - api_key.id = doc_ref.id - api_key_dict['id'] = doc_ref.id - - # Save to Firestore - doc_ref.set(api_key_dict) - - logger.info(f"API key created: {doc_ref.id}") - return api_key - except Exception as e: - logger.error(f"Error creating API key: {e}", exc_info=True) - raise - - async def get_by_id(self, key_id: str) -> Optional[ApiKeyModel]: - """ - Get API key by ID - - Args: - key_id: API key ID - - Returns: - API key if found, None otherwise - """ - try: - doc_ref = self.collection.document(str(key_id)) - doc = doc_ref.get() - if doc.exists: - data = doc.to_dict() - data['id'] = doc.id - return ApiKeyModel(**data) - return None - except Exception as e: - logger.error(f"Error getting API key by ID: {e}", exc_info=True) - raise - - async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]: - """ - Get API key by hash - - Args: - key_hash: API key hash - - Returns: - API key if found, None otherwise - """ - try: - query = self.collection.where("key_hash", "==", key_hash).limit(1) - results = query.stream() - - for doc in results: - data = doc.to_dict() - data['id'] = doc.id - return ApiKeyModel(**data) - return None - except Exception as e: - logger.error(f"Error getting API key by hash: {e}", exc_info=True) - raise - - async def get_by_user(self, user_id: str) -> List[ApiKeyModel]: - """ - Get API keys by user ID - - Args: - user_id: User ID - - Returns: - List of API keys for the user - """ - try: - keys = [] - query = self.collection.where("user_id", "==", str(user_id)) - results = query.stream() - - for doc in results: - data = doc.to_dict() - data['id'] = doc.id - keys.append(ApiKeyModel(**data)) - return keys - except Exception as e: - logger.error(f"Error getting API keys by user: {e}", exc_info=True) - raise - - async def get_by_team(self, team_id: str) -> List[ApiKeyModel]: - """ - Get API keys by team ID - - Args: - team_id: Team ID - - Returns: - List of API keys for the team - """ - try: - keys = [] - query = self.collection.where("team_id", "==", str(team_id)) - results = query.stream() - - for doc in results: - data = doc.to_dict() - data['id'] = doc.id - keys.append(ApiKeyModel(**data)) - return keys - except Exception as e: - logger.error(f"Error getting API keys by team: {e}", exc_info=True) - raise - - async def update_last_used(self, key_id: str) -> None: - """ - Update API key's last used timestamp - - Args: - key_id: API key ID - """ - try: - doc_ref = self.collection.document(str(key_id)) - doc_ref.update({"last_used": datetime.utcnow()}) - except Exception as e: - logger.error(f"Error updating API key last used: {e}", exc_info=True) - raise - - async def deactivate(self, key_id: str) -> bool: - """ - Deactivate API key - - Args: - key_id: API key ID - - Returns: - True if deactivated, False otherwise - """ - try: - doc_ref = self.collection.document(str(key_id)) - doc_ref.update({"is_active": False}) - return True - except Exception as e: - logger.error(f"Error deactivating API key: {e}", exc_info=True) - raise - - async def delete(self, key_id: str) -> bool: - """ - Delete API key - - Args: - key_id: API key ID - - Returns: - True if deleted, False otherwise - """ - try: - doc_ref = self.collection.document(str(key_id)) - doc_ref.delete() - return True - except Exception as e: - logger.error(f"Error deleting API key: {e}", exc_info=True) - raise - -# Create a singleton repository -api_key_repository = ApiKeyRepository() \ No newline at end of file diff --git a/src/db/repositories/firestore_team_repository.py b/src/db/repositories/firestore_team_repository.py index e1d721b..fd3eb58 100644 --- a/src/db/repositories/firestore_team_repository.py +++ b/src/db/repositories/firestore_team_repository.py @@ -1,6 +1,5 @@ import logging from typing import List, Optional -from bson import ObjectId from src.db.repositories.firestore_repository import FirestoreRepository from src.models.team import TeamModel @@ -13,20 +12,17 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]): def __init__(self): super().__init__("teams", TeamModel) - # Override methods that need special handling for ObjectId and other MongoDB specific types - async def get_by_id(self, team_id) -> Optional[TeamModel]: + async def get_by_id(self, team_id: str) -> Optional[TeamModel]: """ Get team by ID Args: - team_id: Team ID (can be ObjectId or string) + team_id: Team ID as string Returns: Team if found, None otherwise """ - # Convert ObjectId to string if needed - doc_id = str(team_id) - return await super().get_by_id(doc_id) + return await super().get_by_id(team_id) async def get_all(self) -> List[TeamModel]: """ @@ -37,34 +33,30 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]): """ return await super().get_all() - async def update(self, team_id, team_data: dict) -> Optional[TeamModel]: + async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]: """ Update team Args: - team_id: Team ID (can be ObjectId or string) + team_id: Team ID as string team_data: Update data Returns: Updated team if found, None otherwise """ - # Convert ObjectId to string if needed - doc_id = str(team_id) - return await super().update(doc_id, team_data) + return await super().update(team_id, team_data) - async def delete(self, team_id) -> bool: + async def delete(self, team_id: str) -> bool: """ Delete team Args: - team_id: Team ID (can be ObjectId or string) + team_id: Team ID as string Returns: True if team was deleted, False otherwise """ - # Convert ObjectId to string if needed - doc_id = str(team_id) - return await super().delete(doc_id) + return await super().delete(team_id) # Create a singleton repository firestore_team_repository = FirestoreTeamRepository() \ No newline at end of file diff --git a/src/db/repositories/image_repository.py b/src/db/repositories/image_repository.py deleted file mode 100644 index e7c411b..0000000 --- a/src/db/repositories/image_repository.py +++ /dev/null @@ -1,302 +0,0 @@ -import logging -from datetime import datetime -from typing import List, Optional, Dict, Any -from bson import ObjectId - -from src.db import db -from src.models.image import ImageModel - -logger = logging.getLogger(__name__) - -class ImageRepository: - """Repository for image operations""" - - collection_name = "images" - - @property - def collection(self): - return db.get_database()[self.collection_name] - - async def create(self, image: ImageModel) -> ImageModel: - """ - Create a new image record - - Args: - image: Image data - - Returns: - Created image with ID - """ - try: - result = await self.collection.insert_one(image.dict(by_alias=True)) - created_image = await self.get_by_id(result.inserted_id) - logger.info(f"Image created: {result.inserted_id}") - return created_image - except Exception as e: - logger.error(f"Error creating image: {e}") - raise - - async def get_by_id(self, image_id: ObjectId) -> Optional[ImageModel]: - """ - Get image by ID - - Args: - image_id: Image ID - - Returns: - Image if found, None otherwise - """ - try: - image = await self.collection.find_one({"_id": image_id}) - if image: - return ImageModel(**image) - return None - except Exception as e: - logger.error(f"Error getting image by ID: {e}") - raise - - async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]: - """ - Get images by list of IDs - - Args: - image_ids: List of image ID strings - - Returns: - List of images - """ - try: - # Convert string IDs to ObjectIds - object_ids = [] - for image_id in image_ids: - try: - object_ids.append(ObjectId(image_id)) - except: - logger.warning(f"Invalid ObjectId: {image_id}") - continue - - images = [] - cursor = self.collection.find({"_id": {"$in": object_ids}}) - async for document in cursor: - images.append(ImageModel(**document)) - return images - except Exception as e: - logger.error(f"Error getting images by IDs: {e}") - raise - - async def get_by_team( - self, - team_id: ObjectId, - limit: int = 100, - skip: int = 0, - collection_id: Optional[ObjectId] = None, - tags: Optional[List[str]] = None - ) -> List[ImageModel]: - """ - Get images by team ID with pagination and filters - - Args: - team_id: Team ID - limit: Max number of results - skip: Number of records to skip - collection_id: Optional collection filter - tags: Optional tags filter - - Returns: - List of images for the team - """ - try: - # Build query - query = {"team_id": team_id} - - if collection_id: - query["collection_id"] = collection_id - - if tags: - query["tags"] = {"$in": tags} - - images = [] - cursor = self.collection.find(query).sort("upload_date", -1).skip(skip).limit(limit) - async for document in cursor: - images.append(ImageModel(**document)) - return images - except Exception as e: - logger.error(f"Error getting images by team: {e}") - raise - - async def count_by_team( - self, - team_id: ObjectId, - collection_id: Optional[ObjectId] = None, - tags: Optional[List[str]] = None - ) -> int: - """ - Count images by team ID with filters - - Args: - team_id: Team ID - collection_id: Optional collection filter - tags: Optional tags filter - - Returns: - Number of images for the team - """ - try: - # Build query - query = {"team_id": team_id} - - if collection_id: - query["collection_id"] = collection_id - - if tags: - query["tags"] = {"$in": tags} - - return await self.collection.count_documents(query) - except Exception as e: - logger.error(f"Error counting images by team: {e}") - raise - - async def get_by_uploader(self, uploader_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]: - """ - Get images by uploader ID with pagination - - Args: - uploader_id: Uploader user ID - limit: Max number of results - skip: Number of records to skip - - Returns: - List of images uploaded by the user - """ - try: - images = [] - cursor = self.collection.find({"uploader_id": uploader_id}).sort("upload_date", -1).skip(skip).limit(limit) - async for document in cursor: - images.append(ImageModel(**document)) - return images - except Exception as e: - logger.error(f"Error getting images by uploader: {e}") - raise - - async def search_by_metadata(self, team_id: ObjectId, query: Dict[str, Any], limit: int = 100, skip: int = 0) -> List[ImageModel]: - """ - Search images by metadata - - Args: - team_id: Team ID - query: Search query - limit: Max number of results - skip: Number of records to skip - - Returns: - List of matching images - """ - try: - # Ensure we only search within the team's images - search_query = {"team_id": team_id, **query} - - images = [] - cursor = self.collection.find(search_query).sort("upload_date", -1).skip(skip).limit(limit) - async for document in cursor: - images.append(ImageModel(**document)) - return images - except Exception as e: - logger.error(f"Error searching images by metadata: {e}") - raise - - async def update(self, image_id: ObjectId, image_data: dict) -> Optional[ImageModel]: - """ - Update image - - Args: - image_id: Image ID - image_data: Update data - - Returns: - Updated image if found, None otherwise - """ - try: - # Don't allow updating _id - if "_id" in image_data: - del image_data["_id"] - - result = await self.collection.update_one( - {"_id": image_id}, - {"$set": image_data} - ) - - if result.modified_count == 0: - logger.warning(f"No image updated for ID: {image_id}") - return None - - return await self.get_by_id(image_id) - except Exception as e: - logger.error(f"Error updating image: {e}") - raise - - async def update_last_accessed(self, image_id: ObjectId) -> None: - """ - Update image's last accessed timestamp - - Args: - image_id: Image ID - """ - try: - await self.collection.update_one( - {"_id": image_id}, - {"$set": {"last_accessed": datetime.utcnow()}} - ) - except Exception as e: - logger.error(f"Error updating image last accessed: {e}") - raise - - async def delete(self, image_id: ObjectId) -> bool: - """ - Delete image - - Args: - image_id: Image ID - - Returns: - True if image was deleted, False otherwise - """ - try: - result = await self.collection.delete_one({"_id": image_id}) - return result.deleted_count > 0 - except Exception as e: - logger.error(f"Error deleting image: {e}") - raise - - async def update_embedding_status(self, image_id: ObjectId, embedding_id: str, model: str) -> Optional[ImageModel]: - """ - Update image embedding status - - Args: - image_id: Image ID - embedding_id: Vector DB embedding ID - model: Model used for embedding - - Returns: - Updated image if found, None otherwise - """ - try: - result = await self.collection.update_one( - {"_id": image_id}, - {"$set": { - "embedding_id": embedding_id, - "embedding_model": model, - "has_embedding": True - }} - ) - - if result.modified_count == 0: - logger.warning(f"No image updated for embedding ID: {image_id}") - return None - - return await self.get_by_id(image_id) - except Exception as e: - logger.error(f"Error updating image embedding status: {e}") - raise - -# Create a singleton repository -image_repository = ImageRepository() \ No newline at end of file diff --git a/src/db/repositories/team_repository.py b/src/db/repositories/team_repository.py deleted file mode 100644 index cefee75..0000000 --- a/src/db/repositories/team_repository.py +++ /dev/null @@ -1,151 +0,0 @@ -import logging -from datetime import datetime -from typing import List, Optional -from bson import ObjectId - -from src.db import db -from src.models.team import TeamModel - -logger = logging.getLogger(__name__) - -class TeamRepository: - """Repository for team operations""" - - collection_name = "teams" - - @property - def collection(self): - client = db.get_database() - if client is None: - logger.error("Database connection is None, cannot access collection") - raise RuntimeError("Database connection is not available") - # Use the Firestore client to get the collection - return client.collection(self.collection_name) - - async def create(self, team: TeamModel) -> TeamModel: - """ - Create a new team - - Args: - team: Team data - - Returns: - Created team with ID - """ - try: - # Convert team to dict for Firestore - team_dict = team.dict(by_alias=True) - - # Add document to Firestore and get reference - doc_ref = self.collection.document() - doc_ref.set(team_dict) - - # Update ID and retrieve the created team - team.id = doc_ref.id - logger.info(f"Team created: {doc_ref.id}") - - return team - except Exception as e: - logger.error(f"Error creating team: {e}") - raise - - async def get_by_id(self, team_id: ObjectId) -> Optional[TeamModel]: - """ - Get team by ID - - Args: - team_id: Team ID - - Returns: - Team if found, None otherwise - """ - try: - doc_ref = self.collection.document(str(team_id)) - team_doc = doc_ref.get() - - if team_doc.exists: - team_data = team_doc.to_dict() - team_data['id'] = team_doc.id - return TeamModel(**team_data) - return None - except Exception as e: - logger.error(f"Error getting team by ID: {e}") - raise - - async def get_all(self) -> List[TeamModel]: - """ - Get all teams - - Returns: - List of teams - """ - try: - teams = [] - team_docs = self.collection.stream() - - for doc in team_docs: - team_data = doc.to_dict() - team_data['id'] = doc.id - teams.append(TeamModel(**team_data)) - - return teams - except Exception as e: - logger.error(f"Error getting all teams: {e}") - raise - - async def update(self, team_id: ObjectId, team_data: dict) -> Optional[TeamModel]: - """ - Update team - - Args: - team_id: Team ID - team_data: Update data - - Returns: - Updated team if found, None otherwise - """ - try: - # Add updated_at timestamp - team_data["updated_at"] = datetime.utcnow() - - # Don't allow updating _id or id - if "_id" in team_data: - del team_data["_id"] - if "id" in team_data: - del team_data["id"] - - # Get document reference and update - doc_ref = self.collection.document(str(team_id)) - doc_ref.update(team_data) - - # Get updated team - updated_team = await self.get_by_id(team_id) - if not updated_team: - logger.warning(f"No team found after update for ID: {team_id}") - return None - - return updated_team - except Exception as e: - logger.error(f"Error updating team: {e}") - raise - - async def delete(self, team_id: ObjectId) -> bool: - """ - Delete team - - Args: - team_id: Team ID - - Returns: - True if team was deleted, False otherwise - """ - try: - doc_ref = self.collection.document(str(team_id)) - doc_ref.delete() - return True - except Exception as e: - logger.error(f"Error deleting team: {e}") - raise - -# Create a singleton repository -team_repository = TeamRepository() \ No newline at end of file diff --git a/src/db/repositories/user_repository.py b/src/db/repositories/user_repository.py deleted file mode 100644 index e6cc5c1..0000000 --- a/src/db/repositories/user_repository.py +++ /dev/null @@ -1,191 +0,0 @@ -import logging -from datetime import datetime -from typing import List, Optional -from bson import ObjectId - -from src.db import db -from src.models.user import UserModel - -logger = logging.getLogger(__name__) - -class UserRepository: - """Repository for user operations""" - - collection_name = "users" - - @property - def collection(self): - client = db.get_database() - if client is None: - logger.error("Database connection is None, cannot access collection") - raise RuntimeError("Database connection is not available") - # Use the Firestore client to get the collection - return client.collection(self.collection_name) - - async def create(self, user: UserModel) -> UserModel: - """ - Create a new user - - Args: - user: User data - - Returns: - Created user with ID - """ - try: - # Convert user to dict for Firestore - user_dict = user.dict(by_alias=True) - - # Add document to Firestore and get reference - doc_ref = self.collection.document() - doc_ref.set(user_dict) - - # Update ID and retrieve the created user - user.id = doc_ref.id - logger.info(f"User created: {doc_ref.id}") - - return user - except Exception as e: - logger.error(f"Error creating user: {e}") - raise - - async def get_by_id(self, user_id: ObjectId) -> Optional[UserModel]: - """ - Get user by ID - - Args: - user_id: User ID - - Returns: - User if found, None otherwise - """ - try: - doc_ref = self.collection.document(str(user_id)) - user_doc = doc_ref.get() - - if user_doc.exists: - user_data = user_doc.to_dict() - user_data['id'] = user_doc.id - return UserModel(**user_data) - return None - except Exception as e: - logger.error(f"Error getting user by ID: {e}") - raise - - async def get_by_email(self, email: str) -> Optional[UserModel]: - """ - Get user by email - - Args: - email: User email - - Returns: - User if found, None otherwise - """ - try: - query = self.collection.where("email", "==", email).limit(1) - results = query.stream() - - for doc in results: - user_data = doc.to_dict() - user_data['id'] = doc.id - return UserModel(**user_data) - return None - except Exception as e: - logger.error(f"Error getting user by email: {e}") - raise - - async def get_by_team(self, team_id: ObjectId) -> List[UserModel]: - """ - Get users by team ID - - Args: - team_id: Team ID - - Returns: - List of users in the team - """ - try: - users = [] - query = self.collection.where("team_id", "==", str(team_id)) - results = query.stream() - - for doc in results: - user_data = doc.to_dict() - user_data['id'] = doc.id - users.append(UserModel(**user_data)) - return users - except Exception as e: - logger.error(f"Error getting users by team: {e}") - raise - - async def update(self, user_id: ObjectId, user_data: dict) -> Optional[UserModel]: - """ - Update user - - Args: - user_id: User ID - user_data: Update data - - Returns: - Updated user if found, None otherwise - """ - try: - # Add updated_at timestamp - user_data["updated_at"] = datetime.utcnow() - - # Don't allow updating _id or id - if "_id" in user_data: - del user_data["_id"] - if "id" in user_data: - del user_data["id"] - - # Get document reference and update - doc_ref = self.collection.document(str(user_id)) - doc_ref.update(user_data) - - # Get updated user - updated_user = await self.get_by_id(user_id) - if not updated_user: - logger.warning(f"No user found after update for ID: {user_id}") - return None - - return updated_user - except Exception as e: - logger.error(f"Error updating user: {e}") - raise - - async def delete(self, user_id: ObjectId) -> bool: - """ - Delete user - - Args: - user_id: User ID - - Returns: - True if user was deleted, False otherwise - """ - try: - doc_ref = self.collection.document(str(user_id)) - doc_ref.delete() - return True - except Exception as e: - logger.error(f"Error deleting user: {e}") - raise - - async def update_last_login(self, user_id: ObjectId) -> None: - """ - Update user's last login timestamp - - Args: - user_id: User ID - """ - try: - doc_ref = self.collection.document(str(user_id)) - doc_ref.update({"last_login": datetime.utcnow()}) - except Exception as e: - logger.error(f"Error updating user last login: {e}") - raise - -# Create a singleton repository -user_repository = UserRepository() \ No newline at end of file