From 56c8f7b944d2d1720c1e14db7e30f893c7843319 Mon Sep 17 00:00:00 2001 From: johnpccd Date: Sat, 24 May 2025 06:21:38 +0200 Subject: [PATCH] cp --- .gitignore | 2 + main.py | 29 +++++- scripts/create_admin.py | 76 +++++++++++++++ scripts/generate_dev_key.py | 69 ++++++++++++++ scripts/setup_credentials.py | 97 +++++++++++++++++++ src/db/__init__.py | 65 ++++++++++++- src/db/repositories/api_key_repository.py | 108 +++++++++++++--------- 7 files changed, 399 insertions(+), 47 deletions(-) create mode 100644 scripts/create_admin.py create mode 100644 scripts/generate_dev_key.py create mode 100644 scripts/setup_credentials.py diff --git a/.gitignore b/.gitignore index 1631e6e..9be78d1 100644 --- a/.gitignore +++ b/.gitignore @@ -53,3 +53,5 @@ coverage.xml # Terraform *.tfvars .terraform + +firestore-credentials.json diff --git a/main.py b/main.py index 5edfe63..2466d18 100644 --- a/main.py +++ b/main.py @@ -8,9 +8,10 @@ from fastapi.openapi.utils import get_openapi # Import API routers from src.api.v1 import teams, users, images, auth, search -# Import configuration +# Import configuration and database from src.core.config import settings from src.core.logging import setup_logging +from src.db import db # Setup logging setup_logging() @@ -26,6 +27,14 @@ app = FastAPI( openapi_url="/api/v1/openapi.json" ) +# Connect to database +try: + db.connect_to_database() + logger.info("Database connection initialized") +except Exception as e: + logger.error(f"Failed to connect to database: {e}", exc_info=True) + # We'll continue without database for Swagger UI to work, but operations will fail + # Set up CORS app.add_middleware( CORSMiddleware, @@ -46,9 +55,19 @@ app.include_router(search.router, prefix="/api/v1") @app.exception_handler(Exception) async def general_exception_handler(request: Request, exc: Exception): logger.error(f"Unhandled exception: {exc}", exc_info=True) + # Include more details for easier debugging + error_details = { + "detail": "Internal server error", + "type": type(exc).__name__, + "path": request.url.path + } + # Only include exception message in development mode + if settings.ENVIRONMENT == "development": + error_details["message"] = str(exc) + return JSONResponse( status_code=500, - content={"detail": "Internal server error"} + content=error_details ) # Custom Swagger UI with API key authentication @@ -99,6 +118,12 @@ app.openapi = custom_openapi async def root(): return {"message": "Welcome to the Image Management API. Please see /docs for API documentation."} +# Shutdown handler to close database connections +@app.on_event("shutdown") +async def shutdown_event(): + logger.info("Application shutting down") + db.close_database_connection() + if __name__ == "__main__": import uvicorn uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) \ No newline at end of file diff --git a/scripts/create_admin.py b/scripts/create_admin.py new file mode 100644 index 0000000..06b32c7 --- /dev/null +++ b/scripts/create_admin.py @@ -0,0 +1,76 @@ +import os +import sys +import asyncio +import logging +from bson import ObjectId + +# Add the project root to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Import repositories +from src.db.repositories.team_repository import team_repository +from src.db.repositories.user_repository import user_repository +from src.db.repositories.api_key_repository import api_key_repository + +# Import models +from src.db.models.team import TeamModel +from src.db.models.user import UserModel +from src.db.models.api_key import ApiKeyModel + +# Import security functions +from src.core.security import generate_api_key, calculate_expiry_date + +async def create_admin(): + # Create a new team + print("Creating admin team...") + team = TeamModel( + name="Admin Team", + description="Default admin team for system administration" + ) + + created_team = await team_repository.create(team) + print(f"Created team with ID: {created_team.id}") + + # Create admin user + print("Creating admin user...") + user = UserModel( + name="Admin User", + email="admin@example.com", + team_id=created_team.id, + is_admin=True, + is_active=True + ) + + created_user = await user_repository.create(user) + print(f"Created admin user with ID: {created_user.id}") + + # Generate API key + print("Generating API key...") + raw_key, hashed_key = generate_api_key(str(created_team.id), str(created_user.id)) + expiry_date = calculate_expiry_date() + + # Create API key in database + api_key = ApiKeyModel( + key_hash=hashed_key, + user_id=created_user.id, + team_id=created_team.id, + name="Admin API Key", + description="Initial API key for admin user", + expiry_date=expiry_date, + is_active=True + ) + + created_key = await api_key_repository.create(api_key) + print(f"Created API key with ID: {created_key.id}") + print(f"API Key (save this, it won't be shown again): {raw_key}") + + return { + "team_id": str(created_team.id), + "user_id": str(created_user.id), + "api_key": raw_key + } + +if __name__ == "__main__": + print("Creating admin user and API key...") + result = asyncio.run(create_admin()) + print("\nSetup complete! Use the API key to authenticate API calls.") \ No newline at end of file diff --git a/scripts/generate_dev_key.py b/scripts/generate_dev_key.py new file mode 100644 index 0000000..c0e62bc --- /dev/null +++ b/scripts/generate_dev_key.py @@ -0,0 +1,69 @@ +import os +import sys +import hmac +import hashlib +import secrets +import string + +# Add the project root to the Python path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Get the API key secret from environment +from src.core.config import settings + +def generate_api_key(team_id="dev-team", user_id="dev-admin"): + """ + Generate a secure API key and its hashed value + + Args: + team_id: Team ID for which the key is generated + user_id: User ID for which the key is generated + + Returns: + Tuple of (raw_api_key, hashed_api_key) + """ + # Generate a random key prefix (visible part) + prefix = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8)) + + # Generate a secure random token for the key + random_part = secrets.token_hex(16) + + # Format: prefix.random_part + raw_api_key = f"{prefix}.{random_part}" + + # Hash the API key for storage + hashed_api_key = hash_api_key(raw_api_key) + + return raw_api_key, hashed_api_key + +def hash_api_key(api_key: str) -> str: + """ + Create a secure hash of the API key for storage + + Args: + api_key: The raw API key + + Returns: + Hashed API key + """ + return hmac.new( + settings.API_KEY_SECRET.encode(), + api_key.encode(), + hashlib.sha256 + ).hexdigest() + +if __name__ == "__main__": + # Generate a development API key + api_key, key_hash = generate_api_key() + + print("\n====== DEVELOPMENT API KEY ======") + print(f"API Key: {api_key}") + print(f"Key Hash: {key_hash}") + print("\nCOPY THIS API KEY AND USE IT IN YOUR SWAGGER UI!") + print("Header Name: X-API-Key") + print("Header Value: ") + print("===============================") + print("\nNote: This is a generated key, but since there's no database setup,") + print("you won't be able to use it with the API until the key is added to the database.") + print("This would be useful if you developed a bypass_auth mode for development.") + print("For now, please check with the development team for API key access.") \ No newline at end of file diff --git a/scripts/setup_credentials.py b/scripts/setup_credentials.py new file mode 100644 index 0000000..5b35fd9 --- /dev/null +++ b/scripts/setup_credentials.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Script to set up Firestore credentials for development and deployment +""" +import os +import json +import sys +import argparse +from pathlib import Path + +def create_env_file(project_id, credentials_file="firestore-credentials.json"): + """Create a .env file with the necessary environment variables""" + env_content = f"""# Firestore Settings +FIRESTORE_PROJECT_ID={project_id} +FIRESTORE_CREDENTIALS_FILE={credentials_file} + +# Google Cloud Storage Settings +GCS_BUCKET_NAME={project_id}-storage +GCS_CREDENTIALS_FILE={credentials_file} + +# Security settings +API_KEY_SECRET=development-secret-key-change-in-production +API_KEY_EXPIRY_DAYS=365 + +# Vector Database Settings +VECTOR_DB_API_KEY= +VECTOR_DB_ENVIRONMENT= +VECTOR_DB_INDEX_NAME=image-embeddings + +# Other Settings +ENVIRONMENT=development +LOG_LEVEL=INFO +RATE_LIMIT_PER_MINUTE=100 +""" + + with open(".env", "w") as f: + f.write(env_content) + + print("Created .env file with Firestore settings") + +def main(): + parser = argparse.ArgumentParser(description='Set up Firestore credentials') + parser.add_argument('--key-file', type=str, help='Path to the service account key file') + parser.add_argument('--project-id', type=str, help='Google Cloud project ID') + parser.add_argument('--create-env', action='store_true', help='Create .env file') + + args = parser.parse_args() + + # Ensure we have a project ID + project_id = args.project_id + if not project_id: + if args.key_file and os.path.exists(args.key_file): + try: + with open(args.key_file, 'r') as f: + key_data = json.load(f) + project_id = key_data.get('project_id') + if project_id: + print(f"Using project ID from key file: {project_id}") + except Exception as e: + print(f"Error reading key file: {e}") + sys.exit(1) + + if not project_id: + print("Error: Project ID is required") + parser.print_help() + sys.exit(1) + + # Handle key file + target_key_file = "firestore-credentials.json" + if args.key_file and os.path.exists(args.key_file): + # Copy the key file to the target location + try: + with open(args.key_file, 'r') as src, open(target_key_file, 'w') as dst: + key_data = json.load(src) + json.dump(key_data, dst, indent=2) + print(f"Copied service account key to {target_key_file}") + except Exception as e: + print(f"Error copying key file: {e}") + sys.exit(1) + else: + print("Warning: No service account key file provided") + print(f"You need to place your service account key in {target_key_file}") + + # Create .env file if requested + if args.create_env: + create_env_file(project_id, target_key_file) + + print("\nSetup complete!") + print("\nFor development:") + print(f"1. Make sure {target_key_file} exists in the project root") + print("2. Ensure environment variables are set in .env file") + print("\nFor deployment:") + print("1. For Cloud Run, set environment variables in deployment config") + print("2. Make sure to securely manage service account key") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/db/__init__.py b/src/db/__init__.py index 0e6b700..328efd3 100644 --- a/src/db/__init__.py +++ b/src/db/__init__.py @@ -1,5 +1,8 @@ import logging +import os +import json from google.cloud import firestore +from google.oauth2 import service_account from src.core.config import settings logger = logging.getLogger(__name__) @@ -10,10 +13,52 @@ class Database: def connect_to_database(self): """Create database connection.""" try: - self.client = firestore.Client() - logger.info("Connected to Firestore") + # Print project ID for debugging + logger.info(f"Attempting to connect to Firestore with project ID: {settings.FIRESTORE_PROJECT_ID}") + + # First try using Application Default Credentials (for Cloud environments) + try: + logger.info("Attempting to connect using Application Default Credentials") + self.client = firestore.Client(project=settings.FIRESTORE_PROJECT_ID) + # Test connection by trying to access a collection + self.client.collection('test').limit(1).get() + logger.info(f"Connected to Firestore project using Application Default Credentials: {settings.FIRESTORE_PROJECT_ID}") + return + except Exception as adc_error: + logger.error(f"Application Default Credentials failed: {adc_error}", exc_info=True) + + # Fall back to service account file + credentials_path = settings.FIRESTORE_CREDENTIALS_FILE + if not os.path.exists(credentials_path): + logger.error(f"Firestore credentials file not found: {credentials_path}") + raise FileNotFoundError(f"Credentials file not found: {credentials_path}") + + # Print key file contents (without sensitive parts) for debugging + try: + with open(credentials_path, 'r') as f: + key_data = json.load(f) + # Log non-sensitive parts of the key + logger.info(f"Using credentials file with project_id: {key_data.get('project_id')}") + logger.info(f"Client email: {key_data.get('client_email')}") + logger.info(f"Key file type: {key_data.get('type')}") + except Exception as e: + logger.error(f"Error reading key file: {e}") + + # Load credentials + credentials = service_account.Credentials.from_service_account_file(credentials_path) + + # Initialize Firestore client + self.client = firestore.Client( + project=settings.FIRESTORE_PROJECT_ID, + credentials=credentials + ) + + # Test connection by trying to access a collection + self.client.collection('test').limit(1).get() + + logger.info(f"Connected to Firestore project using credentials file: {settings.FIRESTORE_PROJECT_ID}") except Exception as e: - logger.error(f"Failed to connect to Firestore: {e}") + logger.error(f"Failed to connect to Firestore: {e}", exc_info=True) raise def close_database_connection(self): @@ -27,6 +72,20 @@ class Database: def get_database(self): """Get the database instance.""" + if self.client is None: + logger.warning("Database client is None. Attempting to reconnect...") + try: + self.connect_to_database() + except Exception as e: + logger.error(f"Failed to reconnect to database: {e}") + # Return None to avoid further errors but log this issue + return None + + # Verify that client is properly initialized + if self.client is None: + logger.error("Database client is still None after reconnect attempt") + return None + return self.client # Create a singleton database instance diff --git a/src/db/repositories/api_key_repository.py b/src/db/repositories/api_key_repository.py index a0bfaf1..cef2887 100644 --- a/src/db/repositories/api_key_repository.py +++ b/src/db/repositories/api_key_repository.py @@ -15,7 +15,11 @@ class ApiKeyRepository: @property def collection(self): - return db.get_database()[self.collection_name] + database = db.get_database() + if database is None: + logger.error("Database connection is None, cannot access collection") + raise RuntimeError("Database connection is not available") + return database.collection(self.collection_name) async def create(self, api_key: ApiKeyModel) -> ApiKeyModel: """ @@ -28,15 +32,26 @@ class ApiKeyRepository: Created API key with ID """ try: - result = await self.collection.insert_one(api_key.dict(by_alias=True)) - created_key = await self.get_by_id(result.inserted_id) - logger.info(f"API key created: {result.inserted_id}") - return created_key + # Convert to dict for Firestore + api_key_dict = api_key.dict(by_alias=True) + + # Create a new document with an auto-generated ID + doc_ref = self.collection.document() + + # Set the ID in the model and data + api_key.id = doc_ref.id + api_key_dict['id'] = doc_ref.id + + # Save to Firestore + doc_ref.set(api_key_dict) + + logger.info(f"API key created: {doc_ref.id}") + return api_key except Exception as e: - logger.error(f"Error creating API key: {e}") + logger.error(f"Error creating API key: {e}", exc_info=True) raise - async def get_by_id(self, key_id: ObjectId) -> Optional[ApiKeyModel]: + async def get_by_id(self, key_id: str) -> Optional[ApiKeyModel]: """ Get API key by ID @@ -47,12 +62,15 @@ class ApiKeyRepository: API key if found, None otherwise """ try: - key = await self.collection.find_one({"_id": key_id}) - if key: - return ApiKeyModel(**key) + doc_ref = self.collection.document(key_id) + doc = doc_ref.get() + if doc.exists: + data = doc.to_dict() + data['id'] = doc.id + return ApiKeyModel(**data) return None except Exception as e: - logger.error(f"Error getting API key by ID: {e}") + logger.error(f"Error getting API key by ID: {e}", exc_info=True) raise async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]: @@ -66,15 +84,18 @@ class ApiKeyRepository: API key if found, None otherwise """ try: - key = await self.collection.find_one({"key_hash": key_hash}) - if key: - return ApiKeyModel(**key) + query = self.collection.where("key_hash", "==", key_hash).limit(1) + results = query.stream() + for doc in results: + data = doc.to_dict() + data['id'] = doc.id + return ApiKeyModel(**data) return None except Exception as e: - logger.error(f"Error getting API key by hash: {e}") + logger.error(f"Error getting API key by hash: {e}", exc_info=True) raise - async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]: + async def get_by_user(self, user_id: str) -> List[ApiKeyModel]: """ Get API keys by user ID @@ -86,15 +107,18 @@ class ApiKeyRepository: """ try: keys = [] - cursor = self.collection.find({"user_id": user_id}) - async for document in cursor: - keys.append(ApiKeyModel(**document)) + query = self.collection.where("user_id", "==", user_id) + results = query.stream() + for doc in results: + data = doc.to_dict() + data['id'] = doc.id + keys.append(ApiKeyModel(**data)) return keys except Exception as e: - logger.error(f"Error getting API keys by user: {e}") + logger.error(f"Error getting API keys by user: {e}", exc_info=True) raise - async def get_by_team(self, team_id: ObjectId) -> List[ApiKeyModel]: + async def get_by_team(self, team_id: str) -> List[ApiKeyModel]: """ Get API keys by team ID @@ -106,15 +130,18 @@ class ApiKeyRepository: """ try: keys = [] - cursor = self.collection.find({"team_id": team_id}) - async for document in cursor: - keys.append(ApiKeyModel(**document)) + query = self.collection.where("team_id", "==", team_id) + results = query.stream() + for doc in results: + data = doc.to_dict() + data['id'] = doc.id + keys.append(ApiKeyModel(**data)) return keys except Exception as e: - logger.error(f"Error getting API keys by team: {e}") + logger.error(f"Error getting API keys by team: {e}", exc_info=True) raise - async def update_last_used(self, key_id: ObjectId) -> None: + async def update_last_used(self, key_id: str) -> None: """ Update API key's last used timestamp @@ -122,15 +149,13 @@ class ApiKeyRepository: key_id: API key ID """ try: - await self.collection.update_one( - {"_id": key_id}, - {"$set": {"last_used": datetime.utcnow()}} - ) + doc_ref = self.collection.document(key_id) + doc_ref.update({"last_used": datetime.utcnow()}) except Exception as e: - logger.error(f"Error updating API key last used: {e}") + logger.error(f"Error updating API key last used: {e}", exc_info=True) raise - async def deactivate(self, key_id: ObjectId) -> bool: + async def deactivate(self, key_id: str) -> bool: """ Deactivate API key @@ -141,16 +166,14 @@ class ApiKeyRepository: True if deactivated, False otherwise """ try: - result = await self.collection.update_one( - {"_id": key_id}, - {"$set": {"is_active": False}} - ) - return result.modified_count > 0 + doc_ref = self.collection.document(key_id) + doc_ref.update({"is_active": False}) + return True except Exception as e: - logger.error(f"Error deactivating API key: {e}") + logger.error(f"Error deactivating API key: {e}", exc_info=True) raise - async def delete(self, key_id: ObjectId) -> bool: + async def delete(self, key_id: str) -> bool: """ Delete API key @@ -161,10 +184,11 @@ class ApiKeyRepository: True if deleted, False otherwise """ try: - result = await self.collection.delete_one({"_id": key_id}) - return result.deleted_count > 0 + doc_ref = self.collection.document(key_id) + doc_ref.delete() + return True except Exception as e: - logger.error(f"Error deleting API key: {e}") + logger.error(f"Error deleting API key: {e}", exc_info=True) raise # Create a singleton repository