This commit is contained in:
johnpccd 2025-05-24 12:48:46 +02:00
parent d8aec1e6b4
commit fcbf6e9f85
10 changed files with 73 additions and 987 deletions

View File

@ -28,7 +28,12 @@ sereact/
│ │ └── v1/ # API version 1 routes
│ ├── auth/ # Authentication and authorization
│ ├── config/ # Configuration management
│ ├── core/ # Core application logic
│ ├── db/ # Database layer
│ │ ├── providers/ # Database providers (Firestore)
│ │ └── repositories/ # Data access repositories
│ ├── models/ # Database models
│ ├── schemas/ # API request/response schemas
│ ├── services/ # Business logic services
│ └── utils/ # Utility functions
├── tests/ # Test code

View File

@ -1,100 +1,44 @@
import logging
import os
import json
from google.cloud import firestore
from google.oauth2 import service_account
from src.config.config import settings
logger = logging.getLogger(__name__)
class Database:
client = None
database_name = None
def connect_to_database(self):
"""Create database connection."""
try:
# Store database name
self.database_name = settings.FIRESTORE_DATABASE_NAME
# Print project ID and database name for debugging
logger.info(f"Attempting to connect to Firestore with project ID: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
# First try using Application Default Credentials (for Cloud environments)
try:
logger.info("Attempting to connect using Application Default Credentials")
self.client = firestore.Client(
project=settings.FIRESTORE_PROJECT_ID,
database=self.database_name
)
# Test connection by trying to access a collection
self.client.collection('test').limit(1).get()
logger.info(f"Connected to Firestore project using Application Default Credentials: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
return
except Exception as adc_error:
logger.error(f"Application Default Credentials failed: {adc_error}", exc_info=True)
# Fall back to service account file
credentials_path = settings.FIRESTORE_CREDENTIALS_FILE
if not os.path.exists(credentials_path):
logger.error(f"Firestore credentials file not found: {credentials_path}")
raise FileNotFoundError(f"Credentials file not found: {credentials_path}")
# Print key file contents (without sensitive parts) for debugging
try:
with open(credentials_path, 'r') as f:
key_data = json.load(f)
# Log non-sensitive parts of the key
logger.info(f"Using credentials file with project_id: {key_data.get('project_id')}")
logger.info(f"Client email: {key_data.get('client_email')}")
logger.info(f"Key file type: {key_data.get('type')}")
except Exception as e:
logger.error(f"Error reading key file: {e}")
# Load credentials
credentials = service_account.Credentials.from_service_account_file(credentials_path)
# Initialize Firestore client
self.client = firestore.Client(
project=settings.FIRESTORE_PROJECT_ID,
database=self.database_name,
credentials=credentials
)
# Test connection by trying to access a collection
self.client.collection('test').limit(1).get()
logger.info(f"Connected to Firestore project using credentials file: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
except Exception as e:
logger.error(f"Failed to connect to Firestore: {e}", exc_info=True)
raise
def close_database_connection(self):
"""Close database connection."""
try:
# No explicit close method needed for Firestore client
self.client = None
logger.info("Closed Firestore connection")
except Exception as e:
logger.error(f"Failed to close Firestore connection: {e}")
def get_database(self):
"""Get the database instance."""
if self.client is None:
logger.warning("Database client is None. Attempting to reconnect...")
try:
self.connect_to_database()
except Exception as e:
logger.error(f"Failed to reconnect to database: {e}")
# Return None to avoid further errors but log this issue
return None
# Verify that client is properly initialized
if self.client is None:
logger.error("Database client is still None after reconnect attempt")
return None
return self.client
# Create a singleton database instance
db = Database()
import logging
from src.config.config import settings
from src.db.providers.firestore_provider import FirestoreProvider
logger = logging.getLogger(__name__)
class Database:
"""Database connection manager for Firestore"""
def __init__(self):
self.provider = FirestoreProvider()
def connect_to_database(self):
"""Create database connection."""
try:
logger.info(f"Connecting to Firestore database: {settings.FIRESTORE_DATABASE_NAME}")
self.provider.connect()
logger.info(f"Successfully connected to Firestore database: {settings.FIRESTORE_DATABASE_NAME}")
except Exception as e:
logger.error(f"Failed to connect to Firestore: {e}", exc_info=True)
raise
def close_database_connection(self):
"""Close database connection."""
try:
self.provider.disconnect()
logger.info("Closed Firestore connection")
except Exception as e:
logger.error(f"Failed to close Firestore connection: {e}")
def get_database(self):
"""Get the database instance."""
if self.provider.client is None:
logger.warning("Database client is None. Attempting to reconnect...")
try:
self.connect_to_database()
except Exception as e:
logger.error(f"Failed to reconnect to database: {e}")
return None
return self.provider.client
# Create a singleton database instance
db = Database()

View File

@ -0,0 +1,5 @@
"""Database providers package"""
from .firestore_provider import FirestoreProvider, firestore_db
__all__ = ["FirestoreProvider", "firestore_db"]

View File

@ -99,7 +99,7 @@ class FirestoreProvider:
try:
collection = self.get_collection(collection_name)
# Handle ObjectId conversion for Firestore
# Process data for Firestore
processed_data = {}
for key, value in data.items():
if value is None:

View File

@ -3,36 +3,19 @@ from src.config.config import settings
logger = logging.getLogger(__name__)
# Determine which repository implementation to use
# Default to Firestore for now
REPOSITORY_TYPE = "firestore"
def init_repositories():
"""Initialize the repository implementations based on configuration."""
global REPOSITORY_TYPE
"""Initialize the Firestore repository implementations."""
logger.info("Initializing Firestore repositories")
logger.info(f"Initializing repositories with type: {REPOSITORY_TYPE}")
# Import and initialize Firestore repositories
from src.db.providers.firestore_provider import FirestoreProvider
if REPOSITORY_TYPE == "firestore":
# Import and initialize Firestore repositories
from src.db.repositories.firestore_team_repository import firestore_team_repository
from src.db.providers.firestore_provider import firestore_db
# Initialize the Firestore provider
firestore_provider = FirestoreProvider()
if not firestore_provider.client:
firestore_provider.connect()
# Make sure Firestore client is initialized
if not firestore_db.client:
firestore_db.connect()
# Replace the default repositories with Firestore implementations
from src.db.repositories.team_repository import team_repository as default_team_repo
# Dynamically update the module to use Firestore repositories
import sys
sys.modules["src.db.repositories.team_repository"].team_repository = firestore_team_repository
logger.info("Firestore repositories initialized")
else:
# Default MongoDB repositories are already imported
logger.info("Using default MongoDB repositories")
logger.info("Firestore repositories initialized successfully")
# Initialize repositories at module import time
init_repositories()

View File

@ -1,199 +0,0 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId
from src.db import db
from src.models.api_key import ApiKeyModel
logger = logging.getLogger(__name__)
class ApiKeyRepository:
"""Repository for API key operations"""
collection_name = "api_keys"
@property
def collection(self):
client = db.get_database()
if client is None:
logger.error("Database connection is None, cannot access collection")
raise RuntimeError("Database connection is not available")
# Use the Firestore client to get the collection
return client.collection(self.collection_name)
async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
"""
Create a new API key
Args:
api_key: API key data
Returns:
Created API key with ID
"""
try:
# Convert to dict for Firestore
api_key_dict = api_key.dict(by_alias=True)
# Create a new document with an auto-generated ID
doc_ref = self.collection.document()
# Set the ID in the model and data
api_key.id = doc_ref.id
api_key_dict['id'] = doc_ref.id
# Save to Firestore
doc_ref.set(api_key_dict)
logger.info(f"API key created: {doc_ref.id}")
return api_key
except Exception as e:
logger.error(f"Error creating API key: {e}", exc_info=True)
raise
async def get_by_id(self, key_id: str) -> Optional[ApiKeyModel]:
"""
Get API key by ID
Args:
key_id: API key ID
Returns:
API key if found, None otherwise
"""
try:
doc_ref = self.collection.document(str(key_id))
doc = doc_ref.get()
if doc.exists:
data = doc.to_dict()
data['id'] = doc.id
return ApiKeyModel(**data)
return None
except Exception as e:
logger.error(f"Error getting API key by ID: {e}", exc_info=True)
raise
async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
"""
Get API key by hash
Args:
key_hash: API key hash
Returns:
API key if found, None otherwise
"""
try:
query = self.collection.where("key_hash", "==", key_hash).limit(1)
results = query.stream()
for doc in results:
data = doc.to_dict()
data['id'] = doc.id
return ApiKeyModel(**data)
return None
except Exception as e:
logger.error(f"Error getting API key by hash: {e}", exc_info=True)
raise
async def get_by_user(self, user_id: str) -> List[ApiKeyModel]:
"""
Get API keys by user ID
Args:
user_id: User ID
Returns:
List of API keys for the user
"""
try:
keys = []
query = self.collection.where("user_id", "==", str(user_id))
results = query.stream()
for doc in results:
data = doc.to_dict()
data['id'] = doc.id
keys.append(ApiKeyModel(**data))
return keys
except Exception as e:
logger.error(f"Error getting API keys by user: {e}", exc_info=True)
raise
async def get_by_team(self, team_id: str) -> List[ApiKeyModel]:
"""
Get API keys by team ID
Args:
team_id: Team ID
Returns:
List of API keys for the team
"""
try:
keys = []
query = self.collection.where("team_id", "==", str(team_id))
results = query.stream()
for doc in results:
data = doc.to_dict()
data['id'] = doc.id
keys.append(ApiKeyModel(**data))
return keys
except Exception as e:
logger.error(f"Error getting API keys by team: {e}", exc_info=True)
raise
async def update_last_used(self, key_id: str) -> None:
"""
Update API key's last used timestamp
Args:
key_id: API key ID
"""
try:
doc_ref = self.collection.document(str(key_id))
doc_ref.update({"last_used": datetime.utcnow()})
except Exception as e:
logger.error(f"Error updating API key last used: {e}", exc_info=True)
raise
async def deactivate(self, key_id: str) -> bool:
"""
Deactivate API key
Args:
key_id: API key ID
Returns:
True if deactivated, False otherwise
"""
try:
doc_ref = self.collection.document(str(key_id))
doc_ref.update({"is_active": False})
return True
except Exception as e:
logger.error(f"Error deactivating API key: {e}", exc_info=True)
raise
async def delete(self, key_id: str) -> bool:
"""
Delete API key
Args:
key_id: API key ID
Returns:
True if deleted, False otherwise
"""
try:
doc_ref = self.collection.document(str(key_id))
doc_ref.delete()
return True
except Exception as e:
logger.error(f"Error deleting API key: {e}", exc_info=True)
raise
# Create a singleton repository
api_key_repository = ApiKeyRepository()

View File

@ -1,6 +1,5 @@
import logging
from typing import List, Optional
from bson import ObjectId
from src.db.repositories.firestore_repository import FirestoreRepository
from src.models.team import TeamModel
@ -13,20 +12,17 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
def __init__(self):
super().__init__("teams", TeamModel)
# Override methods that need special handling for ObjectId and other MongoDB specific types
async def get_by_id(self, team_id) -> Optional[TeamModel]:
async def get_by_id(self, team_id: str) -> Optional[TeamModel]:
"""
Get team by ID
Args:
team_id: Team ID (can be ObjectId or string)
team_id: Team ID as string
Returns:
Team if found, None otherwise
"""
# Convert ObjectId to string if needed
doc_id = str(team_id)
return await super().get_by_id(doc_id)
return await super().get_by_id(team_id)
async def get_all(self) -> List[TeamModel]:
"""
@ -37,34 +33,30 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
"""
return await super().get_all()
async def update(self, team_id, team_data: dict) -> Optional[TeamModel]:
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
"""
Update team
Args:
team_id: Team ID (can be ObjectId or string)
team_id: Team ID as string
team_data: Update data
Returns:
Updated team if found, None otherwise
"""
# Convert ObjectId to string if needed
doc_id = str(team_id)
return await super().update(doc_id, team_data)
return await super().update(team_id, team_data)
async def delete(self, team_id) -> bool:
async def delete(self, team_id: str) -> bool:
"""
Delete team
Args:
team_id: Team ID (can be ObjectId or string)
team_id: Team ID as string
Returns:
True if team was deleted, False otherwise
"""
# Convert ObjectId to string if needed
doc_id = str(team_id)
return await super().delete(doc_id)
return await super().delete(team_id)
# Create a singleton repository
firestore_team_repository = FirestoreTeamRepository()

View File

@ -1,302 +0,0 @@
import logging
from datetime import datetime
from typing import List, Optional, Dict, Any
from bson import ObjectId
from src.db import db
from src.models.image import ImageModel
logger = logging.getLogger(__name__)
class ImageRepository:
"""Repository for image operations"""
collection_name = "images"
@property
def collection(self):
return db.get_database()[self.collection_name]
async def create(self, image: ImageModel) -> ImageModel:
"""
Create a new image record
Args:
image: Image data
Returns:
Created image with ID
"""
try:
result = await self.collection.insert_one(image.dict(by_alias=True))
created_image = await self.get_by_id(result.inserted_id)
logger.info(f"Image created: {result.inserted_id}")
return created_image
except Exception as e:
logger.error(f"Error creating image: {e}")
raise
async def get_by_id(self, image_id: ObjectId) -> Optional[ImageModel]:
"""
Get image by ID
Args:
image_id: Image ID
Returns:
Image if found, None otherwise
"""
try:
image = await self.collection.find_one({"_id": image_id})
if image:
return ImageModel(**image)
return None
except Exception as e:
logger.error(f"Error getting image by ID: {e}")
raise
async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
"""
Get images by list of IDs
Args:
image_ids: List of image ID strings
Returns:
List of images
"""
try:
# Convert string IDs to ObjectIds
object_ids = []
for image_id in image_ids:
try:
object_ids.append(ObjectId(image_id))
except:
logger.warning(f"Invalid ObjectId: {image_id}")
continue
images = []
cursor = self.collection.find({"_id": {"$in": object_ids}})
async for document in cursor:
images.append(ImageModel(**document))
return images
except Exception as e:
logger.error(f"Error getting images by IDs: {e}")
raise
async def get_by_team(
self,
team_id: ObjectId,
limit: int = 100,
skip: int = 0,
collection_id: Optional[ObjectId] = None,
tags: Optional[List[str]] = None
) -> List[ImageModel]:
"""
Get images by team ID with pagination and filters
Args:
team_id: Team ID
limit: Max number of results
skip: Number of records to skip
collection_id: Optional collection filter
tags: Optional tags filter
Returns:
List of images for the team
"""
try:
# Build query
query = {"team_id": team_id}
if collection_id:
query["collection_id"] = collection_id
if tags:
query["tags"] = {"$in": tags}
images = []
cursor = self.collection.find(query).sort("upload_date", -1).skip(skip).limit(limit)
async for document in cursor:
images.append(ImageModel(**document))
return images
except Exception as e:
logger.error(f"Error getting images by team: {e}")
raise
async def count_by_team(
self,
team_id: ObjectId,
collection_id: Optional[ObjectId] = None,
tags: Optional[List[str]] = None
) -> int:
"""
Count images by team ID with filters
Args:
team_id: Team ID
collection_id: Optional collection filter
tags: Optional tags filter
Returns:
Number of images for the team
"""
try:
# Build query
query = {"team_id": team_id}
if collection_id:
query["collection_id"] = collection_id
if tags:
query["tags"] = {"$in": tags}
return await self.collection.count_documents(query)
except Exception as e:
logger.error(f"Error counting images by team: {e}")
raise
async def get_by_uploader(self, uploader_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
"""
Get images by uploader ID with pagination
Args:
uploader_id: Uploader user ID
limit: Max number of results
skip: Number of records to skip
Returns:
List of images uploaded by the user
"""
try:
images = []
cursor = self.collection.find({"uploader_id": uploader_id}).sort("upload_date", -1).skip(skip).limit(limit)
async for document in cursor:
images.append(ImageModel(**document))
return images
except Exception as e:
logger.error(f"Error getting images by uploader: {e}")
raise
async def search_by_metadata(self, team_id: ObjectId, query: Dict[str, Any], limit: int = 100, skip: int = 0) -> List[ImageModel]:
"""
Search images by metadata
Args:
team_id: Team ID
query: Search query
limit: Max number of results
skip: Number of records to skip
Returns:
List of matching images
"""
try:
# Ensure we only search within the team's images
search_query = {"team_id": team_id, **query}
images = []
cursor = self.collection.find(search_query).sort("upload_date", -1).skip(skip).limit(limit)
async for document in cursor:
images.append(ImageModel(**document))
return images
except Exception as e:
logger.error(f"Error searching images by metadata: {e}")
raise
async def update(self, image_id: ObjectId, image_data: dict) -> Optional[ImageModel]:
"""
Update image
Args:
image_id: Image ID
image_data: Update data
Returns:
Updated image if found, None otherwise
"""
try:
# Don't allow updating _id
if "_id" in image_data:
del image_data["_id"]
result = await self.collection.update_one(
{"_id": image_id},
{"$set": image_data}
)
if result.modified_count == 0:
logger.warning(f"No image updated for ID: {image_id}")
return None
return await self.get_by_id(image_id)
except Exception as e:
logger.error(f"Error updating image: {e}")
raise
async def update_last_accessed(self, image_id: ObjectId) -> None:
"""
Update image's last accessed timestamp
Args:
image_id: Image ID
"""
try:
await self.collection.update_one(
{"_id": image_id},
{"$set": {"last_accessed": datetime.utcnow()}}
)
except Exception as e:
logger.error(f"Error updating image last accessed: {e}")
raise
async def delete(self, image_id: ObjectId) -> bool:
"""
Delete image
Args:
image_id: Image ID
Returns:
True if image was deleted, False otherwise
"""
try:
result = await self.collection.delete_one({"_id": image_id})
return result.deleted_count > 0
except Exception as e:
logger.error(f"Error deleting image: {e}")
raise
async def update_embedding_status(self, image_id: ObjectId, embedding_id: str, model: str) -> Optional[ImageModel]:
"""
Update image embedding status
Args:
image_id: Image ID
embedding_id: Vector DB embedding ID
model: Model used for embedding
Returns:
Updated image if found, None otherwise
"""
try:
result = await self.collection.update_one(
{"_id": image_id},
{"$set": {
"embedding_id": embedding_id,
"embedding_model": model,
"has_embedding": True
}}
)
if result.modified_count == 0:
logger.warning(f"No image updated for embedding ID: {image_id}")
return None
return await self.get_by_id(image_id)
except Exception as e:
logger.error(f"Error updating image embedding status: {e}")
raise
# Create a singleton repository
image_repository = ImageRepository()

View File

@ -1,151 +0,0 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId
from src.db import db
from src.models.team import TeamModel
logger = logging.getLogger(__name__)
class TeamRepository:
"""Repository for team operations"""
collection_name = "teams"
@property
def collection(self):
client = db.get_database()
if client is None:
logger.error("Database connection is None, cannot access collection")
raise RuntimeError("Database connection is not available")
# Use the Firestore client to get the collection
return client.collection(self.collection_name)
async def create(self, team: TeamModel) -> TeamModel:
"""
Create a new team
Args:
team: Team data
Returns:
Created team with ID
"""
try:
# Convert team to dict for Firestore
team_dict = team.dict(by_alias=True)
# Add document to Firestore and get reference
doc_ref = self.collection.document()
doc_ref.set(team_dict)
# Update ID and retrieve the created team
team.id = doc_ref.id
logger.info(f"Team created: {doc_ref.id}")
return team
except Exception as e:
logger.error(f"Error creating team: {e}")
raise
async def get_by_id(self, team_id: ObjectId) -> Optional[TeamModel]:
"""
Get team by ID
Args:
team_id: Team ID
Returns:
Team if found, None otherwise
"""
try:
doc_ref = self.collection.document(str(team_id))
team_doc = doc_ref.get()
if team_doc.exists:
team_data = team_doc.to_dict()
team_data['id'] = team_doc.id
return TeamModel(**team_data)
return None
except Exception as e:
logger.error(f"Error getting team by ID: {e}")
raise
async def get_all(self) -> List[TeamModel]:
"""
Get all teams
Returns:
List of teams
"""
try:
teams = []
team_docs = self.collection.stream()
for doc in team_docs:
team_data = doc.to_dict()
team_data['id'] = doc.id
teams.append(TeamModel(**team_data))
return teams
except Exception as e:
logger.error(f"Error getting all teams: {e}")
raise
async def update(self, team_id: ObjectId, team_data: dict) -> Optional[TeamModel]:
"""
Update team
Args:
team_id: Team ID
team_data: Update data
Returns:
Updated team if found, None otherwise
"""
try:
# Add updated_at timestamp
team_data["updated_at"] = datetime.utcnow()
# Don't allow updating _id or id
if "_id" in team_data:
del team_data["_id"]
if "id" in team_data:
del team_data["id"]
# Get document reference and update
doc_ref = self.collection.document(str(team_id))
doc_ref.update(team_data)
# Get updated team
updated_team = await self.get_by_id(team_id)
if not updated_team:
logger.warning(f"No team found after update for ID: {team_id}")
return None
return updated_team
except Exception as e:
logger.error(f"Error updating team: {e}")
raise
async def delete(self, team_id: ObjectId) -> bool:
"""
Delete team
Args:
team_id: Team ID
Returns:
True if team was deleted, False otherwise
"""
try:
doc_ref = self.collection.document(str(team_id))
doc_ref.delete()
return True
except Exception as e:
logger.error(f"Error deleting team: {e}")
raise
# Create a singleton repository
team_repository = TeamRepository()

View File

@ -1,191 +0,0 @@
import logging
from datetime import datetime
from typing import List, Optional
from bson import ObjectId
from src.db import db
from src.models.user import UserModel
logger = logging.getLogger(__name__)
class UserRepository:
"""Repository for user operations"""
collection_name = "users"
@property
def collection(self):
client = db.get_database()
if client is None:
logger.error("Database connection is None, cannot access collection")
raise RuntimeError("Database connection is not available")
# Use the Firestore client to get the collection
return client.collection(self.collection_name)
async def create(self, user: UserModel) -> UserModel:
"""
Create a new user
Args:
user: User data
Returns:
Created user with ID
"""
try:
# Convert user to dict for Firestore
user_dict = user.dict(by_alias=True)
# Add document to Firestore and get reference
doc_ref = self.collection.document()
doc_ref.set(user_dict)
# Update ID and retrieve the created user
user.id = doc_ref.id
logger.info(f"User created: {doc_ref.id}")
return user
except Exception as e:
logger.error(f"Error creating user: {e}")
raise
async def get_by_id(self, user_id: ObjectId) -> Optional[UserModel]:
"""
Get user by ID
Args:
user_id: User ID
Returns:
User if found, None otherwise
"""
try:
doc_ref = self.collection.document(str(user_id))
user_doc = doc_ref.get()
if user_doc.exists:
user_data = user_doc.to_dict()
user_data['id'] = user_doc.id
return UserModel(**user_data)
return None
except Exception as e:
logger.error(f"Error getting user by ID: {e}")
raise
async def get_by_email(self, email: str) -> Optional[UserModel]:
"""
Get user by email
Args:
email: User email
Returns:
User if found, None otherwise
"""
try:
query = self.collection.where("email", "==", email).limit(1)
results = query.stream()
for doc in results:
user_data = doc.to_dict()
user_data['id'] = doc.id
return UserModel(**user_data)
return None
except Exception as e:
logger.error(f"Error getting user by email: {e}")
raise
async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
"""
Get users by team ID
Args:
team_id: Team ID
Returns:
List of users in the team
"""
try:
users = []
query = self.collection.where("team_id", "==", str(team_id))
results = query.stream()
for doc in results:
user_data = doc.to_dict()
user_data['id'] = doc.id
users.append(UserModel(**user_data))
return users
except Exception as e:
logger.error(f"Error getting users by team: {e}")
raise
async def update(self, user_id: ObjectId, user_data: dict) -> Optional[UserModel]:
"""
Update user
Args:
user_id: User ID
user_data: Update data
Returns:
Updated user if found, None otherwise
"""
try:
# Add updated_at timestamp
user_data["updated_at"] = datetime.utcnow()
# Don't allow updating _id or id
if "_id" in user_data:
del user_data["_id"]
if "id" in user_data:
del user_data["id"]
# Get document reference and update
doc_ref = self.collection.document(str(user_id))
doc_ref.update(user_data)
# Get updated user
updated_user = await self.get_by_id(user_id)
if not updated_user:
logger.warning(f"No user found after update for ID: {user_id}")
return None
return updated_user
except Exception as e:
logger.error(f"Error updating user: {e}")
raise
async def delete(self, user_id: ObjectId) -> bool:
"""
Delete user
Args:
user_id: User ID
Returns:
True if user was deleted, False otherwise
"""
try:
doc_ref = self.collection.document(str(user_id))
doc_ref.delete()
return True
except Exception as e:
logger.error(f"Error deleting user: {e}")
raise
async def update_last_login(self, user_id: ObjectId) -> None:
"""
Update user's last login timestamp
Args:
user_id: User ID
"""
try:
doc_ref = self.collection.document(str(user_id))
doc_ref.update({"last_login": datetime.utcnow()})
except Exception as e:
logger.error(f"Error updating user last login: {e}")
raise
# Create a singleton repository
user_repository = UserRepository()