This commit is contained in:
johnpccd 2025-05-24 06:21:38 +02:00
parent 9a70df0887
commit 56c8f7b944
7 changed files with 399 additions and 47 deletions

2
.gitignore vendored
View File

@ -53,3 +53,5 @@ coverage.xml
# Terraform # Terraform
*.tfvars *.tfvars
.terraform .terraform
firestore-credentials.json

29
main.py
View File

@ -8,9 +8,10 @@ from fastapi.openapi.utils import get_openapi
# Import API routers # Import API routers
from src.api.v1 import teams, users, images, auth, search from src.api.v1 import teams, users, images, auth, search
# Import configuration # Import configuration and database
from src.core.config import settings from src.core.config import settings
from src.core.logging import setup_logging from src.core.logging import setup_logging
from src.db import db
# Setup logging # Setup logging
setup_logging() setup_logging()
@ -26,6 +27,14 @@ app = FastAPI(
openapi_url="/api/v1/openapi.json" openapi_url="/api/v1/openapi.json"
) )
# Connect to database
try:
db.connect_to_database()
logger.info("Database connection initialized")
except Exception as e:
logger.error(f"Failed to connect to database: {e}", exc_info=True)
# We'll continue without database for Swagger UI to work, but operations will fail
# Set up CORS # Set up CORS
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -46,9 +55,19 @@ app.include_router(search.router, prefix="/api/v1")
@app.exception_handler(Exception) @app.exception_handler(Exception)
async def general_exception_handler(request: Request, exc: Exception): async def general_exception_handler(request: Request, exc: Exception):
logger.error(f"Unhandled exception: {exc}", exc_info=True) logger.error(f"Unhandled exception: {exc}", exc_info=True)
# Include more details for easier debugging
error_details = {
"detail": "Internal server error",
"type": type(exc).__name__,
"path": request.url.path
}
# Only include exception message in development mode
if settings.ENVIRONMENT == "development":
error_details["message"] = str(exc)
return JSONResponse( return JSONResponse(
status_code=500, status_code=500,
content={"detail": "Internal server error"} content=error_details
) )
# Custom Swagger UI with API key authentication # Custom Swagger UI with API key authentication
@ -99,6 +118,12 @@ app.openapi = custom_openapi
async def root(): async def root():
return {"message": "Welcome to the Image Management API. Please see /docs for API documentation."} return {"message": "Welcome to the Image Management API. Please see /docs for API documentation."}
# Shutdown handler to close database connections
@app.on_event("shutdown")
async def shutdown_event():
logger.info("Application shutting down")
db.close_database_connection()
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True) uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)

76
scripts/create_admin.py Normal file
View File

@ -0,0 +1,76 @@
import os
import sys
import asyncio
import logging
from bson import ObjectId
# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Import repositories
from src.db.repositories.team_repository import team_repository
from src.db.repositories.user_repository import user_repository
from src.db.repositories.api_key_repository import api_key_repository
# Import models
from src.db.models.team import TeamModel
from src.db.models.user import UserModel
from src.db.models.api_key import ApiKeyModel
# Import security functions
from src.core.security import generate_api_key, calculate_expiry_date
async def create_admin():
# Create a new team
print("Creating admin team...")
team = TeamModel(
name="Admin Team",
description="Default admin team for system administration"
)
created_team = await team_repository.create(team)
print(f"Created team with ID: {created_team.id}")
# Create admin user
print("Creating admin user...")
user = UserModel(
name="Admin User",
email="admin@example.com",
team_id=created_team.id,
is_admin=True,
is_active=True
)
created_user = await user_repository.create(user)
print(f"Created admin user with ID: {created_user.id}")
# Generate API key
print("Generating API key...")
raw_key, hashed_key = generate_api_key(str(created_team.id), str(created_user.id))
expiry_date = calculate_expiry_date()
# Create API key in database
api_key = ApiKeyModel(
key_hash=hashed_key,
user_id=created_user.id,
team_id=created_team.id,
name="Admin API Key",
description="Initial API key for admin user",
expiry_date=expiry_date,
is_active=True
)
created_key = await api_key_repository.create(api_key)
print(f"Created API key with ID: {created_key.id}")
print(f"API Key (save this, it won't be shown again): {raw_key}")
return {
"team_id": str(created_team.id),
"user_id": str(created_user.id),
"api_key": raw_key
}
if __name__ == "__main__":
print("Creating admin user and API key...")
result = asyncio.run(create_admin())
print("\nSetup complete! Use the API key to authenticate API calls.")

View File

@ -0,0 +1,69 @@
import os
import sys
import hmac
import hashlib
import secrets
import string
# Add the project root to the Python path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Get the API key secret from environment
from src.core.config import settings
def generate_api_key(team_id="dev-team", user_id="dev-admin"):
"""
Generate a secure API key and its hashed value
Args:
team_id: Team ID for which the key is generated
user_id: User ID for which the key is generated
Returns:
Tuple of (raw_api_key, hashed_api_key)
"""
# Generate a random key prefix (visible part)
prefix = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(8))
# Generate a secure random token for the key
random_part = secrets.token_hex(16)
# Format: prefix.random_part
raw_api_key = f"{prefix}.{random_part}"
# Hash the API key for storage
hashed_api_key = hash_api_key(raw_api_key)
return raw_api_key, hashed_api_key
def hash_api_key(api_key: str) -> str:
"""
Create a secure hash of the API key for storage
Args:
api_key: The raw API key
Returns:
Hashed API key
"""
return hmac.new(
settings.API_KEY_SECRET.encode(),
api_key.encode(),
hashlib.sha256
).hexdigest()
if __name__ == "__main__":
# Generate a development API key
api_key, key_hash = generate_api_key()
print("\n====== DEVELOPMENT API KEY ======")
print(f"API Key: {api_key}")
print(f"Key Hash: {key_hash}")
print("\nCOPY THIS API KEY AND USE IT IN YOUR SWAGGER UI!")
print("Header Name: X-API-Key")
print("Header Value: <the API key value above>")
print("===============================")
print("\nNote: This is a generated key, but since there's no database setup,")
print("you won't be able to use it with the API until the key is added to the database.")
print("This would be useful if you developed a bypass_auth mode for development.")
print("For now, please check with the development team for API key access.")

View File

@ -0,0 +1,97 @@
#!/usr/bin/env python3
"""
Script to set up Firestore credentials for development and deployment
"""
import os
import json
import sys
import argparse
from pathlib import Path
def create_env_file(project_id, credentials_file="firestore-credentials.json"):
"""Create a .env file with the necessary environment variables"""
env_content = f"""# Firestore Settings
FIRESTORE_PROJECT_ID={project_id}
FIRESTORE_CREDENTIALS_FILE={credentials_file}
# Google Cloud Storage Settings
GCS_BUCKET_NAME={project_id}-storage
GCS_CREDENTIALS_FILE={credentials_file}
# Security settings
API_KEY_SECRET=development-secret-key-change-in-production
API_KEY_EXPIRY_DAYS=365
# Vector Database Settings
VECTOR_DB_API_KEY=
VECTOR_DB_ENVIRONMENT=
VECTOR_DB_INDEX_NAME=image-embeddings
# Other Settings
ENVIRONMENT=development
LOG_LEVEL=INFO
RATE_LIMIT_PER_MINUTE=100
"""
with open(".env", "w") as f:
f.write(env_content)
print("Created .env file with Firestore settings")
def main():
parser = argparse.ArgumentParser(description='Set up Firestore credentials')
parser.add_argument('--key-file', type=str, help='Path to the service account key file')
parser.add_argument('--project-id', type=str, help='Google Cloud project ID')
parser.add_argument('--create-env', action='store_true', help='Create .env file')
args = parser.parse_args()
# Ensure we have a project ID
project_id = args.project_id
if not project_id:
if args.key_file and os.path.exists(args.key_file):
try:
with open(args.key_file, 'r') as f:
key_data = json.load(f)
project_id = key_data.get('project_id')
if project_id:
print(f"Using project ID from key file: {project_id}")
except Exception as e:
print(f"Error reading key file: {e}")
sys.exit(1)
if not project_id:
print("Error: Project ID is required")
parser.print_help()
sys.exit(1)
# Handle key file
target_key_file = "firestore-credentials.json"
if args.key_file and os.path.exists(args.key_file):
# Copy the key file to the target location
try:
with open(args.key_file, 'r') as src, open(target_key_file, 'w') as dst:
key_data = json.load(src)
json.dump(key_data, dst, indent=2)
print(f"Copied service account key to {target_key_file}")
except Exception as e:
print(f"Error copying key file: {e}")
sys.exit(1)
else:
print("Warning: No service account key file provided")
print(f"You need to place your service account key in {target_key_file}")
# Create .env file if requested
if args.create_env:
create_env_file(project_id, target_key_file)
print("\nSetup complete!")
print("\nFor development:")
print(f"1. Make sure {target_key_file} exists in the project root")
print("2. Ensure environment variables are set in .env file")
print("\nFor deployment:")
print("1. For Cloud Run, set environment variables in deployment config")
print("2. Make sure to securely manage service account key")
if __name__ == "__main__":
main()

View File

@ -1,5 +1,8 @@
import logging import logging
import os
import json
from google.cloud import firestore from google.cloud import firestore
from google.oauth2 import service_account
from src.core.config import settings from src.core.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -10,10 +13,52 @@ class Database:
def connect_to_database(self): def connect_to_database(self):
"""Create database connection.""" """Create database connection."""
try: try:
self.client = firestore.Client() # Print project ID for debugging
logger.info("Connected to Firestore") logger.info(f"Attempting to connect to Firestore with project ID: {settings.FIRESTORE_PROJECT_ID}")
# First try using Application Default Credentials (for Cloud environments)
try:
logger.info("Attempting to connect using Application Default Credentials")
self.client = firestore.Client(project=settings.FIRESTORE_PROJECT_ID)
# Test connection by trying to access a collection
self.client.collection('test').limit(1).get()
logger.info(f"Connected to Firestore project using Application Default Credentials: {settings.FIRESTORE_PROJECT_ID}")
return
except Exception as adc_error:
logger.error(f"Application Default Credentials failed: {adc_error}", exc_info=True)
# Fall back to service account file
credentials_path = settings.FIRESTORE_CREDENTIALS_FILE
if not os.path.exists(credentials_path):
logger.error(f"Firestore credentials file not found: {credentials_path}")
raise FileNotFoundError(f"Credentials file not found: {credentials_path}")
# Print key file contents (without sensitive parts) for debugging
try:
with open(credentials_path, 'r') as f:
key_data = json.load(f)
# Log non-sensitive parts of the key
logger.info(f"Using credentials file with project_id: {key_data.get('project_id')}")
logger.info(f"Client email: {key_data.get('client_email')}")
logger.info(f"Key file type: {key_data.get('type')}")
except Exception as e:
logger.error(f"Error reading key file: {e}")
# Load credentials
credentials = service_account.Credentials.from_service_account_file(credentials_path)
# Initialize Firestore client
self.client = firestore.Client(
project=settings.FIRESTORE_PROJECT_ID,
credentials=credentials
)
# Test connection by trying to access a collection
self.client.collection('test').limit(1).get()
logger.info(f"Connected to Firestore project using credentials file: {settings.FIRESTORE_PROJECT_ID}")
except Exception as e: except Exception as e:
logger.error(f"Failed to connect to Firestore: {e}") logger.error(f"Failed to connect to Firestore: {e}", exc_info=True)
raise raise
def close_database_connection(self): def close_database_connection(self):
@ -27,6 +72,20 @@ class Database:
def get_database(self): def get_database(self):
"""Get the database instance.""" """Get the database instance."""
if self.client is None:
logger.warning("Database client is None. Attempting to reconnect...")
try:
self.connect_to_database()
except Exception as e:
logger.error(f"Failed to reconnect to database: {e}")
# Return None to avoid further errors but log this issue
return None
# Verify that client is properly initialized
if self.client is None:
logger.error("Database client is still None after reconnect attempt")
return None
return self.client return self.client
# Create a singleton database instance # Create a singleton database instance

View File

@ -15,7 +15,11 @@ class ApiKeyRepository:
@property @property
def collection(self): def collection(self):
return db.get_database()[self.collection_name] database = db.get_database()
if database is None:
logger.error("Database connection is None, cannot access collection")
raise RuntimeError("Database connection is not available")
return database.collection(self.collection_name)
async def create(self, api_key: ApiKeyModel) -> ApiKeyModel: async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
""" """
@ -28,15 +32,26 @@ class ApiKeyRepository:
Created API key with ID Created API key with ID
""" """
try: try:
result = await self.collection.insert_one(api_key.dict(by_alias=True)) # Convert to dict for Firestore
created_key = await self.get_by_id(result.inserted_id) api_key_dict = api_key.dict(by_alias=True)
logger.info(f"API key created: {result.inserted_id}")
return created_key # Create a new document with an auto-generated ID
doc_ref = self.collection.document()
# Set the ID in the model and data
api_key.id = doc_ref.id
api_key_dict['id'] = doc_ref.id
# Save to Firestore
doc_ref.set(api_key_dict)
logger.info(f"API key created: {doc_ref.id}")
return api_key
except Exception as e: except Exception as e:
logger.error(f"Error creating API key: {e}") logger.error(f"Error creating API key: {e}", exc_info=True)
raise raise
async def get_by_id(self, key_id: ObjectId) -> Optional[ApiKeyModel]: async def get_by_id(self, key_id: str) -> Optional[ApiKeyModel]:
""" """
Get API key by ID Get API key by ID
@ -47,12 +62,15 @@ class ApiKeyRepository:
API key if found, None otherwise API key if found, None otherwise
""" """
try: try:
key = await self.collection.find_one({"_id": key_id}) doc_ref = self.collection.document(key_id)
if key: doc = doc_ref.get()
return ApiKeyModel(**key) if doc.exists:
data = doc.to_dict()
data['id'] = doc.id
return ApiKeyModel(**data)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting API key by ID: {e}") logger.error(f"Error getting API key by ID: {e}", exc_info=True)
raise raise
async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]: async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
@ -66,15 +84,18 @@ class ApiKeyRepository:
API key if found, None otherwise API key if found, None otherwise
""" """
try: try:
key = await self.collection.find_one({"key_hash": key_hash}) query = self.collection.where("key_hash", "==", key_hash).limit(1)
if key: results = query.stream()
return ApiKeyModel(**key) for doc in results:
data = doc.to_dict()
data['id'] = doc.id
return ApiKeyModel(**data)
return None return None
except Exception as e: except Exception as e:
logger.error(f"Error getting API key by hash: {e}") logger.error(f"Error getting API key by hash: {e}", exc_info=True)
raise raise
async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]: async def get_by_user(self, user_id: str) -> List[ApiKeyModel]:
""" """
Get API keys by user ID Get API keys by user ID
@ -86,15 +107,18 @@ class ApiKeyRepository:
""" """
try: try:
keys = [] keys = []
cursor = self.collection.find({"user_id": user_id}) query = self.collection.where("user_id", "==", user_id)
async for document in cursor: results = query.stream()
keys.append(ApiKeyModel(**document)) for doc in results:
data = doc.to_dict()
data['id'] = doc.id
keys.append(ApiKeyModel(**data))
return keys return keys
except Exception as e: except Exception as e:
logger.error(f"Error getting API keys by user: {e}") logger.error(f"Error getting API keys by user: {e}", exc_info=True)
raise raise
async def get_by_team(self, team_id: ObjectId) -> List[ApiKeyModel]: async def get_by_team(self, team_id: str) -> List[ApiKeyModel]:
""" """
Get API keys by team ID Get API keys by team ID
@ -106,15 +130,18 @@ class ApiKeyRepository:
""" """
try: try:
keys = [] keys = []
cursor = self.collection.find({"team_id": team_id}) query = self.collection.where("team_id", "==", team_id)
async for document in cursor: results = query.stream()
keys.append(ApiKeyModel(**document)) for doc in results:
data = doc.to_dict()
data['id'] = doc.id
keys.append(ApiKeyModel(**data))
return keys return keys
except Exception as e: except Exception as e:
logger.error(f"Error getting API keys by team: {e}") logger.error(f"Error getting API keys by team: {e}", exc_info=True)
raise raise
async def update_last_used(self, key_id: ObjectId) -> None: async def update_last_used(self, key_id: str) -> None:
""" """
Update API key's last used timestamp Update API key's last used timestamp
@ -122,15 +149,13 @@ class ApiKeyRepository:
key_id: API key ID key_id: API key ID
""" """
try: try:
await self.collection.update_one( doc_ref = self.collection.document(key_id)
{"_id": key_id}, doc_ref.update({"last_used": datetime.utcnow()})
{"$set": {"last_used": datetime.utcnow()}}
)
except Exception as e: except Exception as e:
logger.error(f"Error updating API key last used: {e}") logger.error(f"Error updating API key last used: {e}", exc_info=True)
raise raise
async def deactivate(self, key_id: ObjectId) -> bool: async def deactivate(self, key_id: str) -> bool:
""" """
Deactivate API key Deactivate API key
@ -141,16 +166,14 @@ class ApiKeyRepository:
True if deactivated, False otherwise True if deactivated, False otherwise
""" """
try: try:
result = await self.collection.update_one( doc_ref = self.collection.document(key_id)
{"_id": key_id}, doc_ref.update({"is_active": False})
{"$set": {"is_active": False}} return True
)
return result.modified_count > 0
except Exception as e: except Exception as e:
logger.error(f"Error deactivating API key: {e}") logger.error(f"Error deactivating API key: {e}", exc_info=True)
raise raise
async def delete(self, key_id: ObjectId) -> bool: async def delete(self, key_id: str) -> bool:
""" """
Delete API key Delete API key
@ -161,10 +184,11 @@ class ApiKeyRepository:
True if deleted, False otherwise True if deleted, False otherwise
""" """
try: try:
result = await self.collection.delete_one({"_id": key_id}) doc_ref = self.collection.document(key_id)
return result.deleted_count > 0 doc_ref.delete()
return True
except Exception as e: except Exception as e:
logger.error(f"Error deleting API key: {e}") logger.error(f"Error deleting API key: {e}", exc_info=True)
raise raise
# Create a singleton repository # Create a singleton repository