cp
This commit is contained in:
parent
d8aec1e6b4
commit
fcbf6e9f85
@ -28,7 +28,12 @@ sereact/
|
||||
│ │ └── v1/ # API version 1 routes
|
||||
│ ├── auth/ # Authentication and authorization
|
||||
│ ├── config/ # Configuration management
|
||||
│ ├── core/ # Core application logic
|
||||
│ ├── db/ # Database layer
|
||||
│ │ ├── providers/ # Database providers (Firestore)
|
||||
│ │ └── repositories/ # Data access repositories
|
||||
│ ├── models/ # Database models
|
||||
│ ├── schemas/ # API request/response schemas
|
||||
│ ├── services/ # Business logic services
|
||||
│ └── utils/ # Utility functions
|
||||
├── tests/ # Test code
|
||||
|
||||
@ -1,70 +1,21 @@
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
from google.cloud import firestore
|
||||
from google.oauth2 import service_account
|
||||
from src.config.config import settings
|
||||
from src.db.providers.firestore_provider import FirestoreProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Database:
|
||||
client = None
|
||||
database_name = None
|
||||
"""Database connection manager for Firestore"""
|
||||
|
||||
def __init__(self):
|
||||
self.provider = FirestoreProvider()
|
||||
|
||||
def connect_to_database(self):
|
||||
"""Create database connection."""
|
||||
try:
|
||||
# Store database name
|
||||
self.database_name = settings.FIRESTORE_DATABASE_NAME
|
||||
|
||||
# Print project ID and database name for debugging
|
||||
logger.info(f"Attempting to connect to Firestore with project ID: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
|
||||
|
||||
# First try using Application Default Credentials (for Cloud environments)
|
||||
try:
|
||||
logger.info("Attempting to connect using Application Default Credentials")
|
||||
self.client = firestore.Client(
|
||||
project=settings.FIRESTORE_PROJECT_ID,
|
||||
database=self.database_name
|
||||
)
|
||||
# Test connection by trying to access a collection
|
||||
self.client.collection('test').limit(1).get()
|
||||
logger.info(f"Connected to Firestore project using Application Default Credentials: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
|
||||
return
|
||||
except Exception as adc_error:
|
||||
logger.error(f"Application Default Credentials failed: {adc_error}", exc_info=True)
|
||||
|
||||
# Fall back to service account file
|
||||
credentials_path = settings.FIRESTORE_CREDENTIALS_FILE
|
||||
if not os.path.exists(credentials_path):
|
||||
logger.error(f"Firestore credentials file not found: {credentials_path}")
|
||||
raise FileNotFoundError(f"Credentials file not found: {credentials_path}")
|
||||
|
||||
# Print key file contents (without sensitive parts) for debugging
|
||||
try:
|
||||
with open(credentials_path, 'r') as f:
|
||||
key_data = json.load(f)
|
||||
# Log non-sensitive parts of the key
|
||||
logger.info(f"Using credentials file with project_id: {key_data.get('project_id')}")
|
||||
logger.info(f"Client email: {key_data.get('client_email')}")
|
||||
logger.info(f"Key file type: {key_data.get('type')}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading key file: {e}")
|
||||
|
||||
# Load credentials
|
||||
credentials = service_account.Credentials.from_service_account_file(credentials_path)
|
||||
|
||||
# Initialize Firestore client
|
||||
self.client = firestore.Client(
|
||||
project=settings.FIRESTORE_PROJECT_ID,
|
||||
database=self.database_name,
|
||||
credentials=credentials
|
||||
)
|
||||
|
||||
# Test connection by trying to access a collection
|
||||
self.client.collection('test').limit(1).get()
|
||||
|
||||
logger.info(f"Connected to Firestore project using credentials file: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
|
||||
logger.info(f"Connecting to Firestore database: {settings.FIRESTORE_DATABASE_NAME}")
|
||||
self.provider.connect()
|
||||
logger.info(f"Successfully connected to Firestore database: {settings.FIRESTORE_DATABASE_NAME}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Firestore: {e}", exc_info=True)
|
||||
raise
|
||||
@ -72,29 +23,22 @@ class Database:
|
||||
def close_database_connection(self):
|
||||
"""Close database connection."""
|
||||
try:
|
||||
# No explicit close method needed for Firestore client
|
||||
self.client = None
|
||||
self.provider.disconnect()
|
||||
logger.info("Closed Firestore connection")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close Firestore connection: {e}")
|
||||
|
||||
def get_database(self):
|
||||
"""Get the database instance."""
|
||||
if self.client is None:
|
||||
if self.provider.client is None:
|
||||
logger.warning("Database client is None. Attempting to reconnect...")
|
||||
try:
|
||||
self.connect_to_database()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to reconnect to database: {e}")
|
||||
# Return None to avoid further errors but log this issue
|
||||
return None
|
||||
|
||||
# Verify that client is properly initialized
|
||||
if self.client is None:
|
||||
logger.error("Database client is still None after reconnect attempt")
|
||||
return None
|
||||
|
||||
return self.client
|
||||
return self.provider.client
|
||||
|
||||
# Create a singleton database instance
|
||||
db = Database()
|
||||
|
||||
5
src/db/providers/__init__.py
Normal file
5
src/db/providers/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
"""Database providers package"""
|
||||
|
||||
from .firestore_provider import FirestoreProvider, firestore_db
|
||||
|
||||
__all__ = ["FirestoreProvider", "firestore_db"]
|
||||
@ -99,7 +99,7 @@ class FirestoreProvider:
|
||||
try:
|
||||
collection = self.get_collection(collection_name)
|
||||
|
||||
# Handle ObjectId conversion for Firestore
|
||||
# Process data for Firestore
|
||||
processed_data = {}
|
||||
for key, value in data.items():
|
||||
if value is None:
|
||||
|
||||
@ -3,36 +3,19 @@ from src.config.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Determine which repository implementation to use
|
||||
# Default to Firestore for now
|
||||
REPOSITORY_TYPE = "firestore"
|
||||
|
||||
def init_repositories():
|
||||
"""Initialize the repository implementations based on configuration."""
|
||||
global REPOSITORY_TYPE
|
||||
"""Initialize the Firestore repository implementations."""
|
||||
logger.info("Initializing Firestore repositories")
|
||||
|
||||
logger.info(f"Initializing repositories with type: {REPOSITORY_TYPE}")
|
||||
|
||||
if REPOSITORY_TYPE == "firestore":
|
||||
# Import and initialize Firestore repositories
|
||||
from src.db.repositories.firestore_team_repository import firestore_team_repository
|
||||
from src.db.providers.firestore_provider import firestore_db
|
||||
from src.db.providers.firestore_provider import FirestoreProvider
|
||||
|
||||
# Make sure Firestore client is initialized
|
||||
if not firestore_db.client:
|
||||
firestore_db.connect()
|
||||
# Initialize the Firestore provider
|
||||
firestore_provider = FirestoreProvider()
|
||||
if not firestore_provider.client:
|
||||
firestore_provider.connect()
|
||||
|
||||
# Replace the default repositories with Firestore implementations
|
||||
from src.db.repositories.team_repository import team_repository as default_team_repo
|
||||
|
||||
# Dynamically update the module to use Firestore repositories
|
||||
import sys
|
||||
sys.modules["src.db.repositories.team_repository"].team_repository = firestore_team_repository
|
||||
|
||||
logger.info("Firestore repositories initialized")
|
||||
else:
|
||||
# Default MongoDB repositories are already imported
|
||||
logger.info("Using default MongoDB repositories")
|
||||
logger.info("Firestore repositories initialized successfully")
|
||||
|
||||
# Initialize repositories at module import time
|
||||
init_repositories()
|
||||
|
||||
@ -1,199 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db import db
|
||||
from src.models.api_key import ApiKeyModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ApiKeyRepository:
|
||||
"""Repository for API key operations"""
|
||||
|
||||
collection_name = "api_keys"
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
client = db.get_database()
|
||||
if client is None:
|
||||
logger.error("Database connection is None, cannot access collection")
|
||||
raise RuntimeError("Database connection is not available")
|
||||
# Use the Firestore client to get the collection
|
||||
return client.collection(self.collection_name)
|
||||
|
||||
async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
|
||||
"""
|
||||
Create a new API key
|
||||
|
||||
Args:
|
||||
api_key: API key data
|
||||
|
||||
Returns:
|
||||
Created API key with ID
|
||||
"""
|
||||
try:
|
||||
# Convert to dict for Firestore
|
||||
api_key_dict = api_key.dict(by_alias=True)
|
||||
|
||||
# Create a new document with an auto-generated ID
|
||||
doc_ref = self.collection.document()
|
||||
|
||||
# Set the ID in the model and data
|
||||
api_key.id = doc_ref.id
|
||||
api_key_dict['id'] = doc_ref.id
|
||||
|
||||
# Save to Firestore
|
||||
doc_ref.set(api_key_dict)
|
||||
|
||||
logger.info(f"API key created: {doc_ref.id}")
|
||||
return api_key
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating API key: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_by_id(self, key_id: str) -> Optional[ApiKeyModel]:
|
||||
"""
|
||||
Get API key by ID
|
||||
|
||||
Args:
|
||||
key_id: API key ID
|
||||
|
||||
Returns:
|
||||
API key if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(key_id))
|
||||
doc = doc_ref.get()
|
||||
if doc.exists:
|
||||
data = doc.to_dict()
|
||||
data['id'] = doc.id
|
||||
return ApiKeyModel(**data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API key by ID: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
|
||||
"""
|
||||
Get API key by hash
|
||||
|
||||
Args:
|
||||
key_hash: API key hash
|
||||
|
||||
Returns:
|
||||
API key if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
query = self.collection.where("key_hash", "==", key_hash).limit(1)
|
||||
results = query.stream()
|
||||
|
||||
for doc in results:
|
||||
data = doc.to_dict()
|
||||
data['id'] = doc.id
|
||||
return ApiKeyModel(**data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API key by hash: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_by_user(self, user_id: str) -> List[ApiKeyModel]:
|
||||
"""
|
||||
Get API keys by user ID
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
List of API keys for the user
|
||||
"""
|
||||
try:
|
||||
keys = []
|
||||
query = self.collection.where("user_id", "==", str(user_id))
|
||||
results = query.stream()
|
||||
|
||||
for doc in results:
|
||||
data = doc.to_dict()
|
||||
data['id'] = doc.id
|
||||
keys.append(ApiKeyModel(**data))
|
||||
return keys
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API keys by user: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_by_team(self, team_id: str) -> List[ApiKeyModel]:
|
||||
"""
|
||||
Get API keys by team ID
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
|
||||
Returns:
|
||||
List of API keys for the team
|
||||
"""
|
||||
try:
|
||||
keys = []
|
||||
query = self.collection.where("team_id", "==", str(team_id))
|
||||
results = query.stream()
|
||||
|
||||
for doc in results:
|
||||
data = doc.to_dict()
|
||||
data['id'] = doc.id
|
||||
keys.append(ApiKeyModel(**data))
|
||||
return keys
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting API keys by team: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def update_last_used(self, key_id: str) -> None:
|
||||
"""
|
||||
Update API key's last used timestamp
|
||||
|
||||
Args:
|
||||
key_id: API key ID
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(key_id))
|
||||
doc_ref.update({"last_used": datetime.utcnow()})
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating API key last used: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def deactivate(self, key_id: str) -> bool:
|
||||
"""
|
||||
Deactivate API key
|
||||
|
||||
Args:
|
||||
key_id: API key ID
|
||||
|
||||
Returns:
|
||||
True if deactivated, False otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(key_id))
|
||||
doc_ref.update({"is_active": False})
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deactivating API key: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def delete(self, key_id: str) -> bool:
|
||||
"""
|
||||
Delete API key
|
||||
|
||||
Args:
|
||||
key_id: API key ID
|
||||
|
||||
Returns:
|
||||
True if deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(key_id))
|
||||
doc_ref.delete()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting API key: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
# Create a singleton repository
|
||||
api_key_repository = ApiKeyRepository()
|
||||
@ -1,6 +1,5 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||
from src.models.team import TeamModel
|
||||
@ -13,20 +12,17 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
||||
def __init__(self):
|
||||
super().__init__("teams", TeamModel)
|
||||
|
||||
# Override methods that need special handling for ObjectId and other MongoDB specific types
|
||||
async def get_by_id(self, team_id) -> Optional[TeamModel]:
|
||||
async def get_by_id(self, team_id: str) -> Optional[TeamModel]:
|
||||
"""
|
||||
Get team by ID
|
||||
|
||||
Args:
|
||||
team_id: Team ID (can be ObjectId or string)
|
||||
team_id: Team ID as string
|
||||
|
||||
Returns:
|
||||
Team if found, None otherwise
|
||||
"""
|
||||
# Convert ObjectId to string if needed
|
||||
doc_id = str(team_id)
|
||||
return await super().get_by_id(doc_id)
|
||||
return await super().get_by_id(team_id)
|
||||
|
||||
async def get_all(self) -> List[TeamModel]:
|
||||
"""
|
||||
@ -37,34 +33,30 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
||||
"""
|
||||
return await super().get_all()
|
||||
|
||||
async def update(self, team_id, team_data: dict) -> Optional[TeamModel]:
|
||||
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
|
||||
"""
|
||||
Update team
|
||||
|
||||
Args:
|
||||
team_id: Team ID (can be ObjectId or string)
|
||||
team_id: Team ID as string
|
||||
team_data: Update data
|
||||
|
||||
Returns:
|
||||
Updated team if found, None otherwise
|
||||
"""
|
||||
# Convert ObjectId to string if needed
|
||||
doc_id = str(team_id)
|
||||
return await super().update(doc_id, team_data)
|
||||
return await super().update(team_id, team_data)
|
||||
|
||||
async def delete(self, team_id) -> bool:
|
||||
async def delete(self, team_id: str) -> bool:
|
||||
"""
|
||||
Delete team
|
||||
|
||||
Args:
|
||||
team_id: Team ID (can be ObjectId or string)
|
||||
team_id: Team ID as string
|
||||
|
||||
Returns:
|
||||
True if team was deleted, False otherwise
|
||||
"""
|
||||
# Convert ObjectId to string if needed
|
||||
doc_id = str(team_id)
|
||||
return await super().delete(doc_id)
|
||||
return await super().delete(team_id)
|
||||
|
||||
# Create a singleton repository
|
||||
firestore_team_repository = FirestoreTeamRepository()
|
||||
@ -1,302 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Dict, Any
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db import db
|
||||
from src.models.image import ImageModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ImageRepository:
|
||||
"""Repository for image operations"""
|
||||
|
||||
collection_name = "images"
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
return db.get_database()[self.collection_name]
|
||||
|
||||
async def create(self, image: ImageModel) -> ImageModel:
|
||||
"""
|
||||
Create a new image record
|
||||
|
||||
Args:
|
||||
image: Image data
|
||||
|
||||
Returns:
|
||||
Created image with ID
|
||||
"""
|
||||
try:
|
||||
result = await self.collection.insert_one(image.dict(by_alias=True))
|
||||
created_image = await self.get_by_id(result.inserted_id)
|
||||
logger.info(f"Image created: {result.inserted_id}")
|
||||
return created_image
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating image: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_id(self, image_id: ObjectId) -> Optional[ImageModel]:
|
||||
"""
|
||||
Get image by ID
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
|
||||
Returns:
|
||||
Image if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
image = await self.collection.find_one({"_id": image_id})
|
||||
if image:
|
||||
return ImageModel(**image)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting image by ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
|
||||
"""
|
||||
Get images by list of IDs
|
||||
|
||||
Args:
|
||||
image_ids: List of image ID strings
|
||||
|
||||
Returns:
|
||||
List of images
|
||||
"""
|
||||
try:
|
||||
# Convert string IDs to ObjectIds
|
||||
object_ids = []
|
||||
for image_id in image_ids:
|
||||
try:
|
||||
object_ids.append(ObjectId(image_id))
|
||||
except:
|
||||
logger.warning(f"Invalid ObjectId: {image_id}")
|
||||
continue
|
||||
|
||||
images = []
|
||||
cursor = self.collection.find({"_id": {"$in": object_ids}})
|
||||
async for document in cursor:
|
||||
images.append(ImageModel(**document))
|
||||
return images
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting images by IDs: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_team(
|
||||
self,
|
||||
team_id: ObjectId,
|
||||
limit: int = 100,
|
||||
skip: int = 0,
|
||||
collection_id: Optional[ObjectId] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> List[ImageModel]:
|
||||
"""
|
||||
Get images by team ID with pagination and filters
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
limit: Max number of results
|
||||
skip: Number of records to skip
|
||||
collection_id: Optional collection filter
|
||||
tags: Optional tags filter
|
||||
|
||||
Returns:
|
||||
List of images for the team
|
||||
"""
|
||||
try:
|
||||
# Build query
|
||||
query = {"team_id": team_id}
|
||||
|
||||
if collection_id:
|
||||
query["collection_id"] = collection_id
|
||||
|
||||
if tags:
|
||||
query["tags"] = {"$in": tags}
|
||||
|
||||
images = []
|
||||
cursor = self.collection.find(query).sort("upload_date", -1).skip(skip).limit(limit)
|
||||
async for document in cursor:
|
||||
images.append(ImageModel(**document))
|
||||
return images
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting images by team: {e}")
|
||||
raise
|
||||
|
||||
async def count_by_team(
|
||||
self,
|
||||
team_id: ObjectId,
|
||||
collection_id: Optional[ObjectId] = None,
|
||||
tags: Optional[List[str]] = None
|
||||
) -> int:
|
||||
"""
|
||||
Count images by team ID with filters
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
collection_id: Optional collection filter
|
||||
tags: Optional tags filter
|
||||
|
||||
Returns:
|
||||
Number of images for the team
|
||||
"""
|
||||
try:
|
||||
# Build query
|
||||
query = {"team_id": team_id}
|
||||
|
||||
if collection_id:
|
||||
query["collection_id"] = collection_id
|
||||
|
||||
if tags:
|
||||
query["tags"] = {"$in": tags}
|
||||
|
||||
return await self.collection.count_documents(query)
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting images by team: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_uploader(self, uploader_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
||||
"""
|
||||
Get images by uploader ID with pagination
|
||||
|
||||
Args:
|
||||
uploader_id: Uploader user ID
|
||||
limit: Max number of results
|
||||
skip: Number of records to skip
|
||||
|
||||
Returns:
|
||||
List of images uploaded by the user
|
||||
"""
|
||||
try:
|
||||
images = []
|
||||
cursor = self.collection.find({"uploader_id": uploader_id}).sort("upload_date", -1).skip(skip).limit(limit)
|
||||
async for document in cursor:
|
||||
images.append(ImageModel(**document))
|
||||
return images
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting images by uploader: {e}")
|
||||
raise
|
||||
|
||||
async def search_by_metadata(self, team_id: ObjectId, query: Dict[str, Any], limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
||||
"""
|
||||
Search images by metadata
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
query: Search query
|
||||
limit: Max number of results
|
||||
skip: Number of records to skip
|
||||
|
||||
Returns:
|
||||
List of matching images
|
||||
"""
|
||||
try:
|
||||
# Ensure we only search within the team's images
|
||||
search_query = {"team_id": team_id, **query}
|
||||
|
||||
images = []
|
||||
cursor = self.collection.find(search_query).sort("upload_date", -1).skip(skip).limit(limit)
|
||||
async for document in cursor:
|
||||
images.append(ImageModel(**document))
|
||||
return images
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching images by metadata: {e}")
|
||||
raise
|
||||
|
||||
async def update(self, image_id: ObjectId, image_data: dict) -> Optional[ImageModel]:
|
||||
"""
|
||||
Update image
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
image_data: Update data
|
||||
|
||||
Returns:
|
||||
Updated image if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Don't allow updating _id
|
||||
if "_id" in image_data:
|
||||
del image_data["_id"]
|
||||
|
||||
result = await self.collection.update_one(
|
||||
{"_id": image_id},
|
||||
{"$set": image_data}
|
||||
)
|
||||
|
||||
if result.modified_count == 0:
|
||||
logger.warning(f"No image updated for ID: {image_id}")
|
||||
return None
|
||||
|
||||
return await self.get_by_id(image_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating image: {e}")
|
||||
raise
|
||||
|
||||
async def update_last_accessed(self, image_id: ObjectId) -> None:
|
||||
"""
|
||||
Update image's last accessed timestamp
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
"""
|
||||
try:
|
||||
await self.collection.update_one(
|
||||
{"_id": image_id},
|
||||
{"$set": {"last_accessed": datetime.utcnow()}}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating image last accessed: {e}")
|
||||
raise
|
||||
|
||||
async def delete(self, image_id: ObjectId) -> bool:
|
||||
"""
|
||||
Delete image
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
|
||||
Returns:
|
||||
True if image was deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.collection.delete_one({"_id": image_id})
|
||||
return result.deleted_count > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting image: {e}")
|
||||
raise
|
||||
|
||||
async def update_embedding_status(self, image_id: ObjectId, embedding_id: str, model: str) -> Optional[ImageModel]:
|
||||
"""
|
||||
Update image embedding status
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
embedding_id: Vector DB embedding ID
|
||||
model: Model used for embedding
|
||||
|
||||
Returns:
|
||||
Updated image if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
result = await self.collection.update_one(
|
||||
{"_id": image_id},
|
||||
{"$set": {
|
||||
"embedding_id": embedding_id,
|
||||
"embedding_model": model,
|
||||
"has_embedding": True
|
||||
}}
|
||||
)
|
||||
|
||||
if result.modified_count == 0:
|
||||
logger.warning(f"No image updated for embedding ID: {image_id}")
|
||||
return None
|
||||
|
||||
return await self.get_by_id(image_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating image embedding status: {e}")
|
||||
raise
|
||||
|
||||
# Create a singleton repository
|
||||
image_repository = ImageRepository()
|
||||
@ -1,151 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db import db
|
||||
from src.models.team import TeamModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TeamRepository:
|
||||
"""Repository for team operations"""
|
||||
|
||||
collection_name = "teams"
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
client = db.get_database()
|
||||
if client is None:
|
||||
logger.error("Database connection is None, cannot access collection")
|
||||
raise RuntimeError("Database connection is not available")
|
||||
# Use the Firestore client to get the collection
|
||||
return client.collection(self.collection_name)
|
||||
|
||||
async def create(self, team: TeamModel) -> TeamModel:
|
||||
"""
|
||||
Create a new team
|
||||
|
||||
Args:
|
||||
team: Team data
|
||||
|
||||
Returns:
|
||||
Created team with ID
|
||||
"""
|
||||
try:
|
||||
# Convert team to dict for Firestore
|
||||
team_dict = team.dict(by_alias=True)
|
||||
|
||||
# Add document to Firestore and get reference
|
||||
doc_ref = self.collection.document()
|
||||
doc_ref.set(team_dict)
|
||||
|
||||
# Update ID and retrieve the created team
|
||||
team.id = doc_ref.id
|
||||
logger.info(f"Team created: {doc_ref.id}")
|
||||
|
||||
return team
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating team: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_id(self, team_id: ObjectId) -> Optional[TeamModel]:
|
||||
"""
|
||||
Get team by ID
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
|
||||
Returns:
|
||||
Team if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(team_id))
|
||||
team_doc = doc_ref.get()
|
||||
|
||||
if team_doc.exists:
|
||||
team_data = team_doc.to_dict()
|
||||
team_data['id'] = team_doc.id
|
||||
return TeamModel(**team_data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting team by ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_all(self) -> List[TeamModel]:
|
||||
"""
|
||||
Get all teams
|
||||
|
||||
Returns:
|
||||
List of teams
|
||||
"""
|
||||
try:
|
||||
teams = []
|
||||
team_docs = self.collection.stream()
|
||||
|
||||
for doc in team_docs:
|
||||
team_data = doc.to_dict()
|
||||
team_data['id'] = doc.id
|
||||
teams.append(TeamModel(**team_data))
|
||||
|
||||
return teams
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting all teams: {e}")
|
||||
raise
|
||||
|
||||
async def update(self, team_id: ObjectId, team_data: dict) -> Optional[TeamModel]:
|
||||
"""
|
||||
Update team
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
team_data: Update data
|
||||
|
||||
Returns:
|
||||
Updated team if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Add updated_at timestamp
|
||||
team_data["updated_at"] = datetime.utcnow()
|
||||
|
||||
# Don't allow updating _id or id
|
||||
if "_id" in team_data:
|
||||
del team_data["_id"]
|
||||
if "id" in team_data:
|
||||
del team_data["id"]
|
||||
|
||||
# Get document reference and update
|
||||
doc_ref = self.collection.document(str(team_id))
|
||||
doc_ref.update(team_data)
|
||||
|
||||
# Get updated team
|
||||
updated_team = await self.get_by_id(team_id)
|
||||
if not updated_team:
|
||||
logger.warning(f"No team found after update for ID: {team_id}")
|
||||
return None
|
||||
|
||||
return updated_team
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating team: {e}")
|
||||
raise
|
||||
|
||||
async def delete(self, team_id: ObjectId) -> bool:
|
||||
"""
|
||||
Delete team
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
|
||||
Returns:
|
||||
True if team was deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(team_id))
|
||||
doc_ref.delete()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting team: {e}")
|
||||
raise
|
||||
|
||||
# Create a singleton repository
|
||||
team_repository = TeamRepository()
|
||||
@ -1,191 +0,0 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db import db
|
||||
from src.models.user import UserModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserRepository:
|
||||
"""Repository for user operations"""
|
||||
|
||||
collection_name = "users"
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
client = db.get_database()
|
||||
if client is None:
|
||||
logger.error("Database connection is None, cannot access collection")
|
||||
raise RuntimeError("Database connection is not available")
|
||||
# Use the Firestore client to get the collection
|
||||
return client.collection(self.collection_name)
|
||||
|
||||
async def create(self, user: UserModel) -> UserModel:
|
||||
"""
|
||||
Create a new user
|
||||
|
||||
Args:
|
||||
user: User data
|
||||
|
||||
Returns:
|
||||
Created user with ID
|
||||
"""
|
||||
try:
|
||||
# Convert user to dict for Firestore
|
||||
user_dict = user.dict(by_alias=True)
|
||||
|
||||
# Add document to Firestore and get reference
|
||||
doc_ref = self.collection.document()
|
||||
doc_ref.set(user_dict)
|
||||
|
||||
# Update ID and retrieve the created user
|
||||
user.id = doc_ref.id
|
||||
logger.info(f"User created: {doc_ref.id}")
|
||||
|
||||
return user
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating user: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_id(self, user_id: ObjectId) -> Optional[UserModel]:
|
||||
"""
|
||||
Get user by ID
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(user_id))
|
||||
user_doc = doc_ref.get()
|
||||
|
||||
if user_doc.exists:
|
||||
user_data = user_doc.to_dict()
|
||||
user_data['id'] = user_doc.id
|
||||
return UserModel(**user_data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_email(self, email: str) -> Optional[UserModel]:
|
||||
"""
|
||||
Get user by email
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
query = self.collection.where("email", "==", email).limit(1)
|
||||
results = query.stream()
|
||||
|
||||
for doc in results:
|
||||
user_data = doc.to_dict()
|
||||
user_data['id'] = doc.id
|
||||
return UserModel(**user_data)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by email: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
|
||||
"""
|
||||
Get users by team ID
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
|
||||
Returns:
|
||||
List of users in the team
|
||||
"""
|
||||
try:
|
||||
users = []
|
||||
query = self.collection.where("team_id", "==", str(team_id))
|
||||
results = query.stream()
|
||||
|
||||
for doc in results:
|
||||
user_data = doc.to_dict()
|
||||
user_data['id'] = doc.id
|
||||
users.append(UserModel(**user_data))
|
||||
return users
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting users by team: {e}")
|
||||
raise
|
||||
|
||||
async def update(self, user_id: ObjectId, user_data: dict) -> Optional[UserModel]:
|
||||
"""
|
||||
Update user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
user_data: Update data
|
||||
|
||||
Returns:
|
||||
Updated user if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Add updated_at timestamp
|
||||
user_data["updated_at"] = datetime.utcnow()
|
||||
|
||||
# Don't allow updating _id or id
|
||||
if "_id" in user_data:
|
||||
del user_data["_id"]
|
||||
if "id" in user_data:
|
||||
del user_data["id"]
|
||||
|
||||
# Get document reference and update
|
||||
doc_ref = self.collection.document(str(user_id))
|
||||
doc_ref.update(user_data)
|
||||
|
||||
# Get updated user
|
||||
updated_user = await self.get_by_id(user_id)
|
||||
if not updated_user:
|
||||
logger.warning(f"No user found after update for ID: {user_id}")
|
||||
return None
|
||||
|
||||
return updated_user
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user: {e}")
|
||||
raise
|
||||
|
||||
async def delete(self, user_id: ObjectId) -> bool:
|
||||
"""
|
||||
Delete user
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
True if user was deleted, False otherwise
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(user_id))
|
||||
doc_ref.delete()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting user: {e}")
|
||||
raise
|
||||
|
||||
async def update_last_login(self, user_id: ObjectId) -> None:
|
||||
"""
|
||||
Update user's last login timestamp
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
"""
|
||||
try:
|
||||
doc_ref = self.collection.document(str(user_id))
|
||||
doc_ref.update({"last_login": datetime.utcnow()})
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating user last login: {e}")
|
||||
raise
|
||||
|
||||
# Create a singleton repository
|
||||
user_repository = UserRepository()
|
||||
Loading…
x
Reference in New Issue
Block a user