cp
This commit is contained in:
parent
d8aec1e6b4
commit
fcbf6e9f85
@ -28,7 +28,12 @@ sereact/
|
|||||||
│ │ └── v1/ # API version 1 routes
|
│ │ └── v1/ # API version 1 routes
|
||||||
│ ├── auth/ # Authentication and authorization
|
│ ├── auth/ # Authentication and authorization
|
||||||
│ ├── config/ # Configuration management
|
│ ├── config/ # Configuration management
|
||||||
|
│ ├── core/ # Core application logic
|
||||||
|
│ ├── db/ # Database layer
|
||||||
|
│ │ ├── providers/ # Database providers (Firestore)
|
||||||
|
│ │ └── repositories/ # Data access repositories
|
||||||
│ ├── models/ # Database models
|
│ ├── models/ # Database models
|
||||||
|
│ ├── schemas/ # API request/response schemas
|
||||||
│ ├── services/ # Business logic services
|
│ ├── services/ # Business logic services
|
||||||
│ └── utils/ # Utility functions
|
│ └── utils/ # Utility functions
|
||||||
├── tests/ # Test code
|
├── tests/ # Test code
|
||||||
|
|||||||
@ -1,70 +1,21 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import json
|
|
||||||
from google.cloud import firestore
|
|
||||||
from google.oauth2 import service_account
|
|
||||||
from src.config.config import settings
|
from src.config.config import settings
|
||||||
|
from src.db.providers.firestore_provider import FirestoreProvider
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class Database:
|
class Database:
|
||||||
client = None
|
"""Database connection manager for Firestore"""
|
||||||
database_name = None
|
|
||||||
|
def __init__(self):
|
||||||
|
self.provider = FirestoreProvider()
|
||||||
|
|
||||||
def connect_to_database(self):
|
def connect_to_database(self):
|
||||||
"""Create database connection."""
|
"""Create database connection."""
|
||||||
try:
|
try:
|
||||||
# Store database name
|
logger.info(f"Connecting to Firestore database: {settings.FIRESTORE_DATABASE_NAME}")
|
||||||
self.database_name = settings.FIRESTORE_DATABASE_NAME
|
self.provider.connect()
|
||||||
|
logger.info(f"Successfully connected to Firestore database: {settings.FIRESTORE_DATABASE_NAME}")
|
||||||
# Print project ID and database name for debugging
|
|
||||||
logger.info(f"Attempting to connect to Firestore with project ID: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
|
|
||||||
|
|
||||||
# First try using Application Default Credentials (for Cloud environments)
|
|
||||||
try:
|
|
||||||
logger.info("Attempting to connect using Application Default Credentials")
|
|
||||||
self.client = firestore.Client(
|
|
||||||
project=settings.FIRESTORE_PROJECT_ID,
|
|
||||||
database=self.database_name
|
|
||||||
)
|
|
||||||
# Test connection by trying to access a collection
|
|
||||||
self.client.collection('test').limit(1).get()
|
|
||||||
logger.info(f"Connected to Firestore project using Application Default Credentials: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
|
|
||||||
return
|
|
||||||
except Exception as adc_error:
|
|
||||||
logger.error(f"Application Default Credentials failed: {adc_error}", exc_info=True)
|
|
||||||
|
|
||||||
# Fall back to service account file
|
|
||||||
credentials_path = settings.FIRESTORE_CREDENTIALS_FILE
|
|
||||||
if not os.path.exists(credentials_path):
|
|
||||||
logger.error(f"Firestore credentials file not found: {credentials_path}")
|
|
||||||
raise FileNotFoundError(f"Credentials file not found: {credentials_path}")
|
|
||||||
|
|
||||||
# Print key file contents (without sensitive parts) for debugging
|
|
||||||
try:
|
|
||||||
with open(credentials_path, 'r') as f:
|
|
||||||
key_data = json.load(f)
|
|
||||||
# Log non-sensitive parts of the key
|
|
||||||
logger.info(f"Using credentials file with project_id: {key_data.get('project_id')}")
|
|
||||||
logger.info(f"Client email: {key_data.get('client_email')}")
|
|
||||||
logger.info(f"Key file type: {key_data.get('type')}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error reading key file: {e}")
|
|
||||||
|
|
||||||
# Load credentials
|
|
||||||
credentials = service_account.Credentials.from_service_account_file(credentials_path)
|
|
||||||
|
|
||||||
# Initialize Firestore client
|
|
||||||
self.client = firestore.Client(
|
|
||||||
project=settings.FIRESTORE_PROJECT_ID,
|
|
||||||
database=self.database_name,
|
|
||||||
credentials=credentials
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test connection by trying to access a collection
|
|
||||||
self.client.collection('test').limit(1).get()
|
|
||||||
|
|
||||||
logger.info(f"Connected to Firestore project using credentials file: {settings.FIRESTORE_PROJECT_ID}, database: {self.database_name}")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to connect to Firestore: {e}", exc_info=True)
|
logger.error(f"Failed to connect to Firestore: {e}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
@ -72,29 +23,22 @@ class Database:
|
|||||||
def close_database_connection(self):
|
def close_database_connection(self):
|
||||||
"""Close database connection."""
|
"""Close database connection."""
|
||||||
try:
|
try:
|
||||||
# No explicit close method needed for Firestore client
|
self.provider.disconnect()
|
||||||
self.client = None
|
|
||||||
logger.info("Closed Firestore connection")
|
logger.info("Closed Firestore connection")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to close Firestore connection: {e}")
|
logger.error(f"Failed to close Firestore connection: {e}")
|
||||||
|
|
||||||
def get_database(self):
|
def get_database(self):
|
||||||
"""Get the database instance."""
|
"""Get the database instance."""
|
||||||
if self.client is None:
|
if self.provider.client is None:
|
||||||
logger.warning("Database client is None. Attempting to reconnect...")
|
logger.warning("Database client is None. Attempting to reconnect...")
|
||||||
try:
|
try:
|
||||||
self.connect_to_database()
|
self.connect_to_database()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to reconnect to database: {e}")
|
logger.error(f"Failed to reconnect to database: {e}")
|
||||||
# Return None to avoid further errors but log this issue
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Verify that client is properly initialized
|
return self.provider.client
|
||||||
if self.client is None:
|
|
||||||
logger.error("Database client is still None after reconnect attempt")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return self.client
|
|
||||||
|
|
||||||
# Create a singleton database instance
|
# Create a singleton database instance
|
||||||
db = Database()
|
db = Database()
|
||||||
|
|||||||
5
src/db/providers/__init__.py
Normal file
5
src/db/providers/__init__.py
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
"""Database providers package"""
|
||||||
|
|
||||||
|
from .firestore_provider import FirestoreProvider, firestore_db
|
||||||
|
|
||||||
|
__all__ = ["FirestoreProvider", "firestore_db"]
|
||||||
@ -99,7 +99,7 @@ class FirestoreProvider:
|
|||||||
try:
|
try:
|
||||||
collection = self.get_collection(collection_name)
|
collection = self.get_collection(collection_name)
|
||||||
|
|
||||||
# Handle ObjectId conversion for Firestore
|
# Process data for Firestore
|
||||||
processed_data = {}
|
processed_data = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if value is None:
|
if value is None:
|
||||||
|
|||||||
@ -3,36 +3,19 @@ from src.config.config import settings
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Determine which repository implementation to use
|
|
||||||
# Default to Firestore for now
|
|
||||||
REPOSITORY_TYPE = "firestore"
|
|
||||||
|
|
||||||
def init_repositories():
|
def init_repositories():
|
||||||
"""Initialize the repository implementations based on configuration."""
|
"""Initialize the Firestore repository implementations."""
|
||||||
global REPOSITORY_TYPE
|
logger.info("Initializing Firestore repositories")
|
||||||
|
|
||||||
logger.info(f"Initializing repositories with type: {REPOSITORY_TYPE}")
|
# Import and initialize Firestore repositories
|
||||||
|
from src.db.providers.firestore_provider import FirestoreProvider
|
||||||
|
|
||||||
if REPOSITORY_TYPE == "firestore":
|
# Initialize the Firestore provider
|
||||||
# Import and initialize Firestore repositories
|
firestore_provider = FirestoreProvider()
|
||||||
from src.db.repositories.firestore_team_repository import firestore_team_repository
|
if not firestore_provider.client:
|
||||||
from src.db.providers.firestore_provider import firestore_db
|
firestore_provider.connect()
|
||||||
|
|
||||||
# Make sure Firestore client is initialized
|
logger.info("Firestore repositories initialized successfully")
|
||||||
if not firestore_db.client:
|
|
||||||
firestore_db.connect()
|
|
||||||
|
|
||||||
# Replace the default repositories with Firestore implementations
|
|
||||||
from src.db.repositories.team_repository import team_repository as default_team_repo
|
|
||||||
|
|
||||||
# Dynamically update the module to use Firestore repositories
|
|
||||||
import sys
|
|
||||||
sys.modules["src.db.repositories.team_repository"].team_repository = firestore_team_repository
|
|
||||||
|
|
||||||
logger.info("Firestore repositories initialized")
|
|
||||||
else:
|
|
||||||
# Default MongoDB repositories are already imported
|
|
||||||
logger.info("Using default MongoDB repositories")
|
|
||||||
|
|
||||||
# Initialize repositories at module import time
|
# Initialize repositories at module import time
|
||||||
init_repositories()
|
init_repositories()
|
||||||
|
|||||||
@ -1,199 +0,0 @@
|
|||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.db import db
|
|
||||||
from src.models.api_key import ApiKeyModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class ApiKeyRepository:
|
|
||||||
"""Repository for API key operations"""
|
|
||||||
|
|
||||||
collection_name = "api_keys"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def collection(self):
|
|
||||||
client = db.get_database()
|
|
||||||
if client is None:
|
|
||||||
logger.error("Database connection is None, cannot access collection")
|
|
||||||
raise RuntimeError("Database connection is not available")
|
|
||||||
# Use the Firestore client to get the collection
|
|
||||||
return client.collection(self.collection_name)
|
|
||||||
|
|
||||||
async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
|
|
||||||
"""
|
|
||||||
Create a new API key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: API key data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created API key with ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert to dict for Firestore
|
|
||||||
api_key_dict = api_key.dict(by_alias=True)
|
|
||||||
|
|
||||||
# Create a new document with an auto-generated ID
|
|
||||||
doc_ref = self.collection.document()
|
|
||||||
|
|
||||||
# Set the ID in the model and data
|
|
||||||
api_key.id = doc_ref.id
|
|
||||||
api_key_dict['id'] = doc_ref.id
|
|
||||||
|
|
||||||
# Save to Firestore
|
|
||||||
doc_ref.set(api_key_dict)
|
|
||||||
|
|
||||||
logger.info(f"API key created: {doc_ref.id}")
|
|
||||||
return api_key
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating API key: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_id(self, key_id: str) -> Optional[ApiKeyModel]:
|
|
||||||
"""
|
|
||||||
Get API key by ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_id: API key ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
API key if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(key_id))
|
|
||||||
doc = doc_ref.get()
|
|
||||||
if doc.exists:
|
|
||||||
data = doc.to_dict()
|
|
||||||
data['id'] = doc.id
|
|
||||||
return ApiKeyModel(**data)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting API key by ID: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
|
|
||||||
"""
|
|
||||||
Get API key by hash
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_hash: API key hash
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
API key if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query = self.collection.where("key_hash", "==", key_hash).limit(1)
|
|
||||||
results = query.stream()
|
|
||||||
|
|
||||||
for doc in results:
|
|
||||||
data = doc.to_dict()
|
|
||||||
data['id'] = doc.id
|
|
||||||
return ApiKeyModel(**data)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting API key by hash: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_user(self, user_id: str) -> List[ApiKeyModel]:
|
|
||||||
"""
|
|
||||||
Get API keys by user ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of API keys for the user
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
keys = []
|
|
||||||
query = self.collection.where("user_id", "==", str(user_id))
|
|
||||||
results = query.stream()
|
|
||||||
|
|
||||||
for doc in results:
|
|
||||||
data = doc.to_dict()
|
|
||||||
data['id'] = doc.id
|
|
||||||
keys.append(ApiKeyModel(**data))
|
|
||||||
return keys
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting API keys by user: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_team(self, team_id: str) -> List[ApiKeyModel]:
|
|
||||||
"""
|
|
||||||
Get API keys by team ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of API keys for the team
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
keys = []
|
|
||||||
query = self.collection.where("team_id", "==", str(team_id))
|
|
||||||
results = query.stream()
|
|
||||||
|
|
||||||
for doc in results:
|
|
||||||
data = doc.to_dict()
|
|
||||||
data['id'] = doc.id
|
|
||||||
keys.append(ApiKeyModel(**data))
|
|
||||||
return keys
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting API keys by team: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_last_used(self, key_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Update API key's last used timestamp
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_id: API key ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(key_id))
|
|
||||||
doc_ref.update({"last_used": datetime.utcnow()})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating API key last used: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def deactivate(self, key_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Deactivate API key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_id: API key ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deactivated, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(key_id))
|
|
||||||
doc_ref.update({"is_active": False})
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deactivating API key: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete(self, key_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete API key
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key_id: API key ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if deleted, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(key_id))
|
|
||||||
doc_ref.delete()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting API key: {e}", exc_info=True)
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Create a singleton repository
|
|
||||||
api_key_repository = ApiKeyRepository()
|
|
||||||
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.db.repositories.firestore_repository import FirestoreRepository
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
||||||
from src.models.team import TeamModel
|
from src.models.team import TeamModel
|
||||||
@ -13,20 +12,17 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__("teams", TeamModel)
|
super().__init__("teams", TeamModel)
|
||||||
|
|
||||||
# Override methods that need special handling for ObjectId and other MongoDB specific types
|
async def get_by_id(self, team_id: str) -> Optional[TeamModel]:
|
||||||
async def get_by_id(self, team_id) -> Optional[TeamModel]:
|
|
||||||
"""
|
"""
|
||||||
Get team by ID
|
Get team by ID
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
team_id: Team ID (can be ObjectId or string)
|
team_id: Team ID as string
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Team if found, None otherwise
|
Team if found, None otherwise
|
||||||
"""
|
"""
|
||||||
# Convert ObjectId to string if needed
|
return await super().get_by_id(team_id)
|
||||||
doc_id = str(team_id)
|
|
||||||
return await super().get_by_id(doc_id)
|
|
||||||
|
|
||||||
async def get_all(self) -> List[TeamModel]:
|
async def get_all(self) -> List[TeamModel]:
|
||||||
"""
|
"""
|
||||||
@ -37,34 +33,30 @@ class FirestoreTeamRepository(FirestoreRepository[TeamModel]):
|
|||||||
"""
|
"""
|
||||||
return await super().get_all()
|
return await super().get_all()
|
||||||
|
|
||||||
async def update(self, team_id, team_data: dict) -> Optional[TeamModel]:
|
async def update(self, team_id: str, team_data: dict) -> Optional[TeamModel]:
|
||||||
"""
|
"""
|
||||||
Update team
|
Update team
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
team_id: Team ID (can be ObjectId or string)
|
team_id: Team ID as string
|
||||||
team_data: Update data
|
team_data: Update data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated team if found, None otherwise
|
Updated team if found, None otherwise
|
||||||
"""
|
"""
|
||||||
# Convert ObjectId to string if needed
|
return await super().update(team_id, team_data)
|
||||||
doc_id = str(team_id)
|
|
||||||
return await super().update(doc_id, team_data)
|
|
||||||
|
|
||||||
async def delete(self, team_id) -> bool:
|
async def delete(self, team_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Delete team
|
Delete team
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
team_id: Team ID (can be ObjectId or string)
|
team_id: Team ID as string
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if team was deleted, False otherwise
|
True if team was deleted, False otherwise
|
||||||
"""
|
"""
|
||||||
# Convert ObjectId to string if needed
|
return await super().delete(team_id)
|
||||||
doc_id = str(team_id)
|
|
||||||
return await super().delete(doc_id)
|
|
||||||
|
|
||||||
# Create a singleton repository
|
# Create a singleton repository
|
||||||
firestore_team_repository = FirestoreTeamRepository()
|
firestore_team_repository = FirestoreTeamRepository()
|
||||||
@ -1,302 +0,0 @@
|
|||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional, Dict, Any
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.db import db
|
|
||||||
from src.models.image import ImageModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class ImageRepository:
|
|
||||||
"""Repository for image operations"""
|
|
||||||
|
|
||||||
collection_name = "images"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def collection(self):
|
|
||||||
return db.get_database()[self.collection_name]
|
|
||||||
|
|
||||||
async def create(self, image: ImageModel) -> ImageModel:
|
|
||||||
"""
|
|
||||||
Create a new image record
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image: Image data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created image with ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.collection.insert_one(image.dict(by_alias=True))
|
|
||||||
created_image = await self.get_by_id(result.inserted_id)
|
|
||||||
logger.info(f"Image created: {result.inserted_id}")
|
|
||||||
return created_image
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating image: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_id(self, image_id: ObjectId) -> Optional[ImageModel]:
|
|
||||||
"""
|
|
||||||
Get image by ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Image if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
image = await self.collection.find_one({"_id": image_id})
|
|
||||||
if image:
|
|
||||||
return ImageModel(**image)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting image by ID: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
|
|
||||||
"""
|
|
||||||
Get images by list of IDs
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_ids: List of image ID strings
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of images
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert string IDs to ObjectIds
|
|
||||||
object_ids = []
|
|
||||||
for image_id in image_ids:
|
|
||||||
try:
|
|
||||||
object_ids.append(ObjectId(image_id))
|
|
||||||
except:
|
|
||||||
logger.warning(f"Invalid ObjectId: {image_id}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
images = []
|
|
||||||
cursor = self.collection.find({"_id": {"$in": object_ids}})
|
|
||||||
async for document in cursor:
|
|
||||||
images.append(ImageModel(**document))
|
|
||||||
return images
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting images by IDs: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_team(
|
|
||||||
self,
|
|
||||||
team_id: ObjectId,
|
|
||||||
limit: int = 100,
|
|
||||||
skip: int = 0,
|
|
||||||
collection_id: Optional[ObjectId] = None,
|
|
||||||
tags: Optional[List[str]] = None
|
|
||||||
) -> List[ImageModel]:
|
|
||||||
"""
|
|
||||||
Get images by team ID with pagination and filters
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
limit: Max number of results
|
|
||||||
skip: Number of records to skip
|
|
||||||
collection_id: Optional collection filter
|
|
||||||
tags: Optional tags filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of images for the team
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Build query
|
|
||||||
query = {"team_id": team_id}
|
|
||||||
|
|
||||||
if collection_id:
|
|
||||||
query["collection_id"] = collection_id
|
|
||||||
|
|
||||||
if tags:
|
|
||||||
query["tags"] = {"$in": tags}
|
|
||||||
|
|
||||||
images = []
|
|
||||||
cursor = self.collection.find(query).sort("upload_date", -1).skip(skip).limit(limit)
|
|
||||||
async for document in cursor:
|
|
||||||
images.append(ImageModel(**document))
|
|
||||||
return images
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting images by team: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def count_by_team(
|
|
||||||
self,
|
|
||||||
team_id: ObjectId,
|
|
||||||
collection_id: Optional[ObjectId] = None,
|
|
||||||
tags: Optional[List[str]] = None
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Count images by team ID with filters
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
collection_id: Optional collection filter
|
|
||||||
tags: Optional tags filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of images for the team
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Build query
|
|
||||||
query = {"team_id": team_id}
|
|
||||||
|
|
||||||
if collection_id:
|
|
||||||
query["collection_id"] = collection_id
|
|
||||||
|
|
||||||
if tags:
|
|
||||||
query["tags"] = {"$in": tags}
|
|
||||||
|
|
||||||
return await self.collection.count_documents(query)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error counting images by team: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_uploader(self, uploader_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
|
||||||
"""
|
|
||||||
Get images by uploader ID with pagination
|
|
||||||
|
|
||||||
Args:
|
|
||||||
uploader_id: Uploader user ID
|
|
||||||
limit: Max number of results
|
|
||||||
skip: Number of records to skip
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of images uploaded by the user
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
images = []
|
|
||||||
cursor = self.collection.find({"uploader_id": uploader_id}).sort("upload_date", -1).skip(skip).limit(limit)
|
|
||||||
async for document in cursor:
|
|
||||||
images.append(ImageModel(**document))
|
|
||||||
return images
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting images by uploader: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def search_by_metadata(self, team_id: ObjectId, query: Dict[str, Any], limit: int = 100, skip: int = 0) -> List[ImageModel]:
|
|
||||||
"""
|
|
||||||
Search images by metadata
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
query: Search query
|
|
||||||
limit: Max number of results
|
|
||||||
skip: Number of records to skip
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of matching images
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Ensure we only search within the team's images
|
|
||||||
search_query = {"team_id": team_id, **query}
|
|
||||||
|
|
||||||
images = []
|
|
||||||
cursor = self.collection.find(search_query).sort("upload_date", -1).skip(skip).limit(limit)
|
|
||||||
async for document in cursor:
|
|
||||||
images.append(ImageModel(**document))
|
|
||||||
return images
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error searching images by metadata: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update(self, image_id: ObjectId, image_data: dict) -> Optional[ImageModel]:
|
|
||||||
"""
|
|
||||||
Update image
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
image_data: Update data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated image if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Don't allow updating _id
|
|
||||||
if "_id" in image_data:
|
|
||||||
del image_data["_id"]
|
|
||||||
|
|
||||||
result = await self.collection.update_one(
|
|
||||||
{"_id": image_id},
|
|
||||||
{"$set": image_data}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.modified_count == 0:
|
|
||||||
logger.warning(f"No image updated for ID: {image_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return await self.get_by_id(image_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating image: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_last_accessed(self, image_id: ObjectId) -> None:
|
|
||||||
"""
|
|
||||||
Update image's last accessed timestamp
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await self.collection.update_one(
|
|
||||||
{"_id": image_id},
|
|
||||||
{"$set": {"last_accessed": datetime.utcnow()}}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating image last accessed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete(self, image_id: ObjectId) -> bool:
|
|
||||||
"""
|
|
||||||
Delete image
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if image was deleted, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.collection.delete_one({"_id": image_id})
|
|
||||||
return result.deleted_count > 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting image: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_embedding_status(self, image_id: ObjectId, embedding_id: str, model: str) -> Optional[ImageModel]:
|
|
||||||
"""
|
|
||||||
Update image embedding status
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
embedding_id: Vector DB embedding ID
|
|
||||||
model: Model used for embedding
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated image if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
result = await self.collection.update_one(
|
|
||||||
{"_id": image_id},
|
|
||||||
{"$set": {
|
|
||||||
"embedding_id": embedding_id,
|
|
||||||
"embedding_model": model,
|
|
||||||
"has_embedding": True
|
|
||||||
}}
|
|
||||||
)
|
|
||||||
|
|
||||||
if result.modified_count == 0:
|
|
||||||
logger.warning(f"No image updated for embedding ID: {image_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return await self.get_by_id(image_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating image embedding status: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Create a singleton repository
|
|
||||||
image_repository = ImageRepository()
|
|
||||||
@ -1,151 +0,0 @@
|
|||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.db import db
|
|
||||||
from src.models.team import TeamModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class TeamRepository:
|
|
||||||
"""Repository for team operations"""
|
|
||||||
|
|
||||||
collection_name = "teams"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def collection(self):
|
|
||||||
client = db.get_database()
|
|
||||||
if client is None:
|
|
||||||
logger.error("Database connection is None, cannot access collection")
|
|
||||||
raise RuntimeError("Database connection is not available")
|
|
||||||
# Use the Firestore client to get the collection
|
|
||||||
return client.collection(self.collection_name)
|
|
||||||
|
|
||||||
async def create(self, team: TeamModel) -> TeamModel:
|
|
||||||
"""
|
|
||||||
Create a new team
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team: Team data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created team with ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert team to dict for Firestore
|
|
||||||
team_dict = team.dict(by_alias=True)
|
|
||||||
|
|
||||||
# Add document to Firestore and get reference
|
|
||||||
doc_ref = self.collection.document()
|
|
||||||
doc_ref.set(team_dict)
|
|
||||||
|
|
||||||
# Update ID and retrieve the created team
|
|
||||||
team.id = doc_ref.id
|
|
||||||
logger.info(f"Team created: {doc_ref.id}")
|
|
||||||
|
|
||||||
return team
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating team: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_id(self, team_id: ObjectId) -> Optional[TeamModel]:
|
|
||||||
"""
|
|
||||||
Get team by ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Team if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(team_id))
|
|
||||||
team_doc = doc_ref.get()
|
|
||||||
|
|
||||||
if team_doc.exists:
|
|
||||||
team_data = team_doc.to_dict()
|
|
||||||
team_data['id'] = team_doc.id
|
|
||||||
return TeamModel(**team_data)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting team by ID: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_all(self) -> List[TeamModel]:
|
|
||||||
"""
|
|
||||||
Get all teams
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of teams
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
teams = []
|
|
||||||
team_docs = self.collection.stream()
|
|
||||||
|
|
||||||
for doc in team_docs:
|
|
||||||
team_data = doc.to_dict()
|
|
||||||
team_data['id'] = doc.id
|
|
||||||
teams.append(TeamModel(**team_data))
|
|
||||||
|
|
||||||
return teams
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting all teams: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update(self, team_id: ObjectId, team_data: dict) -> Optional[TeamModel]:
|
|
||||||
"""
|
|
||||||
Update team
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
team_data: Update data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated team if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Add updated_at timestamp
|
|
||||||
team_data["updated_at"] = datetime.utcnow()
|
|
||||||
|
|
||||||
# Don't allow updating _id or id
|
|
||||||
if "_id" in team_data:
|
|
||||||
del team_data["_id"]
|
|
||||||
if "id" in team_data:
|
|
||||||
del team_data["id"]
|
|
||||||
|
|
||||||
# Get document reference and update
|
|
||||||
doc_ref = self.collection.document(str(team_id))
|
|
||||||
doc_ref.update(team_data)
|
|
||||||
|
|
||||||
# Get updated team
|
|
||||||
updated_team = await self.get_by_id(team_id)
|
|
||||||
if not updated_team:
|
|
||||||
logger.warning(f"No team found after update for ID: {team_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return updated_team
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating team: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete(self, team_id: ObjectId) -> bool:
|
|
||||||
"""
|
|
||||||
Delete team
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if team was deleted, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(team_id))
|
|
||||||
doc_ref.delete()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting team: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Create a singleton repository
|
|
||||||
team_repository = TeamRepository()
|
|
||||||
@ -1,191 +0,0 @@
|
|||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import List, Optional
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.db import db
|
|
||||||
from src.models.user import UserModel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class UserRepository:
|
|
||||||
"""Repository for user operations"""
|
|
||||||
|
|
||||||
collection_name = "users"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def collection(self):
|
|
||||||
client = db.get_database()
|
|
||||||
if client is None:
|
|
||||||
logger.error("Database connection is None, cannot access collection")
|
|
||||||
raise RuntimeError("Database connection is not available")
|
|
||||||
# Use the Firestore client to get the collection
|
|
||||||
return client.collection(self.collection_name)
|
|
||||||
|
|
||||||
async def create(self, user: UserModel) -> UserModel:
|
|
||||||
"""
|
|
||||||
Create a new user
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: User data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created user with ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Convert user to dict for Firestore
|
|
||||||
user_dict = user.dict(by_alias=True)
|
|
||||||
|
|
||||||
# Add document to Firestore and get reference
|
|
||||||
doc_ref = self.collection.document()
|
|
||||||
doc_ref.set(user_dict)
|
|
||||||
|
|
||||||
# Update ID and retrieve the created user
|
|
||||||
user.id = doc_ref.id
|
|
||||||
logger.info(f"User created: {doc_ref.id}")
|
|
||||||
|
|
||||||
return user
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating user: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_id(self, user_id: ObjectId) -> Optional[UserModel]:
|
|
||||||
"""
|
|
||||||
Get user by ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(user_id))
|
|
||||||
user_doc = doc_ref.get()
|
|
||||||
|
|
||||||
if user_doc.exists:
|
|
||||||
user_data = user_doc.to_dict()
|
|
||||||
user_data['id'] = user_doc.id
|
|
||||||
return UserModel(**user_data)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting user by ID: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_email(self, email: str) -> Optional[UserModel]:
|
|
||||||
"""
|
|
||||||
Get user by email
|
|
||||||
|
|
||||||
Args:
|
|
||||||
email: User email
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
query = self.collection.where("email", "==", email).limit(1)
|
|
||||||
results = query.stream()
|
|
||||||
|
|
||||||
for doc in results:
|
|
||||||
user_data = doc.to_dict()
|
|
||||||
user_data['id'] = doc.id
|
|
||||||
return UserModel(**user_data)
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting user by email: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
|
|
||||||
"""
|
|
||||||
Get users by team ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of users in the team
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
users = []
|
|
||||||
query = self.collection.where("team_id", "==", str(team_id))
|
|
||||||
results = query.stream()
|
|
||||||
|
|
||||||
for doc in results:
|
|
||||||
user_data = doc.to_dict()
|
|
||||||
user_data['id'] = doc.id
|
|
||||||
users.append(UserModel(**user_data))
|
|
||||||
return users
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting users by team: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update(self, user_id: ObjectId, user_data: dict) -> Optional[UserModel]:
|
|
||||||
"""
|
|
||||||
Update user
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
user_data: Update data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated user if found, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Add updated_at timestamp
|
|
||||||
user_data["updated_at"] = datetime.utcnow()
|
|
||||||
|
|
||||||
# Don't allow updating _id or id
|
|
||||||
if "_id" in user_data:
|
|
||||||
del user_data["_id"]
|
|
||||||
if "id" in user_data:
|
|
||||||
del user_data["id"]
|
|
||||||
|
|
||||||
# Get document reference and update
|
|
||||||
doc_ref = self.collection.document(str(user_id))
|
|
||||||
doc_ref.update(user_data)
|
|
||||||
|
|
||||||
# Get updated user
|
|
||||||
updated_user = await self.get_by_id(user_id)
|
|
||||||
if not updated_user:
|
|
||||||
logger.warning(f"No user found after update for ID: {user_id}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return updated_user
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating user: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def delete(self, user_id: ObjectId) -> bool:
|
|
||||||
"""
|
|
||||||
Delete user
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if user was deleted, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(user_id))
|
|
||||||
doc_ref.delete()
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting user: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def update_last_login(self, user_id: ObjectId) -> None:
|
|
||||||
"""
|
|
||||||
Update user's last login timestamp
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
doc_ref = self.collection.document(str(user_id))
|
|
||||||
doc_ref.update({"last_login": datetime.utcnow()})
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error updating user last login: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Create a singleton repository
|
|
||||||
user_repository = UserRepository()
|
|
||||||
Loading…
x
Reference in New Issue
Block a user