refactor
This commit is contained in:
parent
ba2c00db38
commit
a26bd08d9c
@ -38,7 +38,6 @@ root/
|
||||
│ │ └── v1/ # API version 1 routes
|
||||
│ ├── auth/ # Authentication and authorization
|
||||
│ ├── config/ # Configuration management
|
||||
│ ├── core/ # Core application logic
|
||||
│ ├── db/ # Database layer
|
||||
│ │ ├── providers/ # Database providers (Firestore)
|
||||
│ │ └── repositories/ # Data access repositories
|
||||
@ -215,6 +214,14 @@ Uses Google's Vertex AI multimodal embedding model for generating high-quality i
|
||||
./deployment/deploy.sh --destroy
|
||||
```
|
||||
|
||||
|
||||
7. **Local Development**
|
||||
```bash
|
||||
./scripts/start.sh
|
||||
```
|
||||
|
||||
8. **Local Testing**
|
||||
|
||||
## API Endpoints
|
||||
|
||||
The API provides the following main endpoints with their authentication and pagination support:
|
||||
|
||||
@ -4,13 +4,11 @@ from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db.repositories.api_key_repository import api_key_repository
|
||||
from src.db.repositories.user_repository import user_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.services.auth_service import AuthService
|
||||
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
|
||||
from src.schemas.team import TeamCreate
|
||||
from src.schemas.user import UserCreate
|
||||
from src.auth.security import generate_api_key, verify_api_key, calculate_expiry_date, is_expired, hash_api_key, get_current_user
|
||||
from src.auth.security import get_current_user
|
||||
from src.models.api_key import ApiKeyModel
|
||||
from src.models.team import TeamModel
|
||||
from src.models.user import UserModel
|
||||
@ -20,6 +18,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Authentication"], prefix="/auth")
|
||||
|
||||
# Initialize service
|
||||
auth_service = AuthService()
|
||||
|
||||
@router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201)
|
||||
async def create_api_key(key_data: ApiKeyCreate, request: Request, user_id: str, team_id: str):
|
||||
"""
|
||||
@ -31,67 +32,16 @@ async def create_api_key(key_data: ApiKeyCreate, request: Request, user_id: str,
|
||||
{"path": request.url.path, "method": request.method, "key_data": key_data.dict(), "user_id": user_id, "team_id": team_id}
|
||||
)
|
||||
|
||||
# Validate user_id and team_id
|
||||
try:
|
||||
target_user_id = ObjectId(user_id)
|
||||
target_team_id = ObjectId(team_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID or team ID")
|
||||
|
||||
# Verify user exists
|
||||
target_user = await user_repository.get_by_id(target_user_id)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Verify team exists
|
||||
team = await team_repository.get_by_id(target_team_id)
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
|
||||
# Verify user belongs to the team
|
||||
if target_user.team_id != target_team_id:
|
||||
raise HTTPException(status_code=400, detail="User does not belong to the specified team")
|
||||
|
||||
# If team_id is provided in key_data, validate it matches the parameter
|
||||
if key_data.team_id and key_data.team_id != team_id:
|
||||
raise HTTPException(status_code=400, detail="Team ID in request body does not match parameter")
|
||||
|
||||
# If user_id is provided in key_data, validate it matches the parameter
|
||||
if key_data.user_id and key_data.user_id != user_id:
|
||||
raise HTTPException(status_code=400, detail="User ID in request body does not match parameter")
|
||||
|
||||
# Generate API key with expiry date
|
||||
raw_key, hashed_key = generate_api_key(str(target_team_id), str(target_user_id))
|
||||
expiry_date = calculate_expiry_date()
|
||||
|
||||
# Create API key in database
|
||||
api_key = ApiKeyModel(
|
||||
key_hash=hashed_key,
|
||||
user_id=target_user_id,
|
||||
team_id=target_team_id,
|
||||
name=key_data.name,
|
||||
description=key_data.description,
|
||||
expiry_date=expiry_date,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
created_key = await api_key_repository.create(api_key)
|
||||
|
||||
# Convert to response model
|
||||
response = ApiKeyWithValueResponse(
|
||||
id=str(created_key.id),
|
||||
key=raw_key,
|
||||
name=created_key.name,
|
||||
description=created_key.description,
|
||||
team_id=str(created_key.team_id),
|
||||
user_id=str(created_key.user_id),
|
||||
created_at=created_key.created_at,
|
||||
expiry_date=created_key.expiry_date,
|
||||
last_used=created_key.last_used,
|
||||
is_active=created_key.is_active
|
||||
)
|
||||
|
||||
return response
|
||||
response = await auth_service.create_api_key_for_user_and_team(user_id, team_id, key_data)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("/admin/api-keys/{user_id}", response_model=ApiKeyWithValueResponse, status_code=201)
|
||||
async def create_api_key_for_user(
|
||||
@ -103,10 +53,6 @@ async def create_api_key_for_user(
|
||||
"""
|
||||
Create a new API key for a specific user (admin only)
|
||||
"""
|
||||
# Check if current user is admin
|
||||
if not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Admin access required")
|
||||
|
||||
log_request(
|
||||
{"path": request.url.path, "method": request.method, "target_user_id": user_id, "key_data": key_data.dict()},
|
||||
user_id=str(current_user.id),
|
||||
@ -114,52 +60,17 @@ async def create_api_key_for_user(
|
||||
)
|
||||
|
||||
try:
|
||||
target_user_obj_id = ObjectId(user_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
# Get the target user
|
||||
target_user = await user_repository.get_by_id(target_user_obj_id)
|
||||
if not target_user:
|
||||
raise HTTPException(status_code=404, detail="Target user not found")
|
||||
|
||||
# Check if target user's team exists
|
||||
team = await team_repository.get_by_id(target_user.team_id)
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Target user's team not found")
|
||||
|
||||
# Generate API key with expiry date
|
||||
raw_key, hashed_key = generate_api_key(str(target_user.team_id), str(target_user.id))
|
||||
expiry_date = calculate_expiry_date()
|
||||
|
||||
# Create API key in database
|
||||
api_key = ApiKeyModel(
|
||||
key_hash=hashed_key,
|
||||
user_id=target_user.id,
|
||||
team_id=target_user.team_id,
|
||||
name=key_data.name,
|
||||
description=key_data.description,
|
||||
expiry_date=expiry_date,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
created_key = await api_key_repository.create(api_key)
|
||||
|
||||
# Convert to response model
|
||||
response = ApiKeyWithValueResponse(
|
||||
id=str(created_key.id),
|
||||
key=raw_key,
|
||||
name=created_key.name,
|
||||
description=created_key.description,
|
||||
team_id=str(created_key.team_id),
|
||||
user_id=str(created_key.user_id),
|
||||
created_at=created_key.created_at,
|
||||
expiry_date=created_key.expiry_date,
|
||||
last_used=created_key.last_used,
|
||||
is_active=created_key.is_active
|
||||
)
|
||||
|
||||
return response
|
||||
response = await auth_service.create_api_key_for_user_by_admin(user_id, key_data, current_user)
|
||||
return response
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating API key for user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/api-keys", response_model=ApiKeyListResponse)
|
||||
async def list_api_keys(request: Request, current_user = Depends(get_current_user)):
|
||||
@ -172,25 +83,12 @@ async def list_api_keys(request: Request, current_user = Depends(get_current_use
|
||||
team_id=str(current_user.team_id)
|
||||
)
|
||||
|
||||
# Get API keys for user
|
||||
keys = await api_key_repository.get_by_user(current_user.id)
|
||||
|
||||
# Convert to response models
|
||||
response_keys = []
|
||||
for key in keys:
|
||||
response_keys.append(ApiKeyResponse(
|
||||
id=str(key.id),
|
||||
name=key.name,
|
||||
description=key.description,
|
||||
team_id=str(key.team_id),
|
||||
user_id=str(key.user_id),
|
||||
created_at=key.created_at,
|
||||
expiry_date=key.expiry_date,
|
||||
last_used=key.last_used,
|
||||
is_active=key.is_active
|
||||
))
|
||||
|
||||
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
||||
try:
|
||||
response = await auth_service.list_user_api_keys(current_user)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error listing API keys: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.delete("/api-keys/{key_id}", status_code=204)
|
||||
async def revoke_api_key(key_id: str, request: Request, current_user = Depends(get_current_user)):
|
||||
@ -204,26 +102,17 @@ async def revoke_api_key(key_id: str, request: Request, current_user = Depends(g
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert string ID to ObjectId
|
||||
obj_id = ObjectId(key_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid key ID")
|
||||
|
||||
# Get the API key
|
||||
key = await api_key_repository.get_by_id(obj_id)
|
||||
if not key:
|
||||
raise HTTPException(status_code=404, detail="API key not found")
|
||||
|
||||
# Check if user owns the key or is an admin
|
||||
if key.user_id != current_user.id and not current_user.is_admin:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to revoke this API key")
|
||||
|
||||
# Deactivate the key
|
||||
result = await api_key_repository.deactivate(obj_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=500, detail="Failed to revoke API key")
|
||||
|
||||
return None
|
||||
await auth_service.revoke_api_key(key_id, current_user)
|
||||
return None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error revoking API key: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/verify", status_code=200)
|
||||
async def verify_authentication(request: Request, current_user = Depends(get_current_user)):
|
||||
@ -236,10 +125,9 @@ async def verify_authentication(request: Request, current_user = Depends(get_cur
|
||||
team_id=str(current_user.team_id)
|
||||
)
|
||||
|
||||
return {
|
||||
"user_id": str(current_user.id),
|
||||
"name": current_user.name,
|
||||
"email": current_user.email,
|
||||
"team_id": str(current_user.team_id),
|
||||
"is_admin": current_user.is_admin
|
||||
}
|
||||
try:
|
||||
response = await auth_service.verify_user_authentication(current_user)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error verifying authentication: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@ -6,12 +6,7 @@ from bson import ObjectId
|
||||
import io
|
||||
|
||||
from src.auth.security import get_current_user
|
||||
from src.db.repositories.image_repository import image_repository
|
||||
from src.services.storage import StorageService
|
||||
from src.services.image_processor import ImageProcessor
|
||||
from src.services.embedding_service import EmbeddingService
|
||||
from src.services.pubsub_service import pubsub_service
|
||||
from src.models.image import ImageModel
|
||||
from src.services.image_service import ImageService
|
||||
from src.models.user import UserModel
|
||||
from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
|
||||
from src.utils.logging import log_request
|
||||
@ -20,17 +15,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Images"], prefix="/images")
|
||||
|
||||
# Initialize services
|
||||
storage_service = StorageService()
|
||||
image_processor = ImageProcessor()
|
||||
embedding_service = EmbeddingService()
|
||||
|
||||
def generate_api_download_url(request: Request, image_id: str) -> str:
|
||||
"""
|
||||
Generate API download URL for an image
|
||||
"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
return f"{base_url}/api/v1/images/{image_id}/download"
|
||||
# Initialize service
|
||||
image_service = ImageService()
|
||||
|
||||
@router.post("", response_model=ImageResponse, status_code=201)
|
||||
async def upload_image(
|
||||
@ -49,86 +35,16 @@ async def upload_image(
|
||||
team_id=str(current_user.team_id)
|
||||
)
|
||||
|
||||
# Validate file type
|
||||
if not file.content_type or not file.content_type.startswith('image/'):
|
||||
raise HTTPException(status_code=400, detail="File must be an image")
|
||||
|
||||
# Validate file size (10MB limit)
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
content = await file.read()
|
||||
if len(content) > max_size:
|
||||
raise HTTPException(status_code=400, detail="File size exceeds 10MB limit")
|
||||
|
||||
# Reset file pointer
|
||||
await file.seek(0)
|
||||
|
||||
try:
|
||||
# Upload to storage
|
||||
storage_path, content_type, file_size, metadata = await storage_service.upload_file(
|
||||
file, str(current_user.team_id)
|
||||
)
|
||||
|
||||
|
||||
# Create image record
|
||||
image = ImageModel(
|
||||
filename=file.filename,
|
||||
original_filename=file.filename,
|
||||
file_size=file_size,
|
||||
content_type=content_type,
|
||||
storage_path=storage_path,
|
||||
public_url=None, # Will be set after we have the image ID
|
||||
team_id=current_user.team_id,
|
||||
uploader_id=current_user.id,
|
||||
description=description,
|
||||
metadata=metadata,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None
|
||||
)
|
||||
|
||||
# Save to database
|
||||
created_image = await image_repository.create(image)
|
||||
|
||||
# Generate API download URL now that we have the image ID
|
||||
api_download_url = generate_api_download_url(request, str(created_image.id))
|
||||
|
||||
# Update the image with the API download URL
|
||||
await image_repository.update(created_image.id, {"public_url": api_download_url})
|
||||
created_image.public_url = api_download_url
|
||||
|
||||
# Publish image processing task to Pub/Sub
|
||||
try:
|
||||
task_published = await pubsub_service.publish_image_processing_task(
|
||||
image_id=str(created_image.id),
|
||||
storage_path=storage_path,
|
||||
team_id=str(current_user.team_id)
|
||||
)
|
||||
if not task_published:
|
||||
logger.warning(f"Failed to publish processing task for image {created_image.id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish image processing task: {e}")
|
||||
|
||||
# Convert to response
|
||||
response = ImageResponse(
|
||||
id=str(created_image.id),
|
||||
filename=created_image.filename,
|
||||
original_filename=created_image.original_filename,
|
||||
file_size=created_image.file_size,
|
||||
content_type=created_image.content_type,
|
||||
storage_path=created_image.storage_path,
|
||||
public_url=created_image.public_url,
|
||||
team_id=str(created_image.team_id),
|
||||
uploader_id=str(created_image.uploader_id),
|
||||
upload_date=created_image.upload_date,
|
||||
description=created_image.description,
|
||||
metadata=created_image.metadata,
|
||||
has_embedding=created_image.has_embedding,
|
||||
collection_id=str(created_image.collection_id) if created_image.collection_id else None
|
||||
)
|
||||
|
||||
response = await image_service.upload_image(file, current_user, request, description, collection_id)
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading image: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to upload image")
|
||||
logger.error(f"Unexpected error uploading image: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("", response_model=ImageListResponse)
|
||||
async def list_images(
|
||||
@ -158,58 +74,12 @@ async def list_images(
|
||||
team_id=str(current_user.team_id)
|
||||
)
|
||||
|
||||
# Check if user is admin - if so, get all images across all teams
|
||||
if current_user.is_admin:
|
||||
# Admin users can see all images across all teams
|
||||
images = await image_repository.get_all_with_pagination(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
|
||||
# Get total count for admin
|
||||
total = await image_repository.count_all(
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
else:
|
||||
# Regular users only see images from their team
|
||||
images = await image_repository.get_by_team(
|
||||
current_user.team_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
|
||||
# Get total count for regular user
|
||||
total = await image_repository.count_by_team(
|
||||
current_user.team_id,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
|
||||
# Convert to response
|
||||
response_images = []
|
||||
for image in images:
|
||||
# Generate API download URL
|
||||
api_download_url = generate_api_download_url(request, str(image.id))
|
||||
|
||||
response_images.append(ImageResponse(
|
||||
id=str(image.id),
|
||||
filename=image.filename,
|
||||
original_filename=image.original_filename,
|
||||
file_size=image.file_size,
|
||||
content_type=image.content_type,
|
||||
storage_path=image.storage_path,
|
||||
public_url=api_download_url,
|
||||
team_id=str(image.team_id),
|
||||
uploader_id=str(image.uploader_id),
|
||||
upload_date=image.upload_date,
|
||||
description=image.description,
|
||||
metadata=image.metadata,
|
||||
has_embedding=image.has_embedding,
|
||||
collection_id=str(image.collection_id) if image.collection_id else None
|
||||
))
|
||||
|
||||
return ImageListResponse(images=response_images, total=total, skip=skip, limit=limit)
|
||||
try:
|
||||
response = await image_service.list_images(current_user, request, skip, limit, collection_id)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error listing images: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{image_id}", response_model=ImageResponse)
|
||||
async def get_image(
|
||||
@ -227,42 +97,17 @@ async def get_image(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Check team access (admins can access any image)
|
||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this image")
|
||||
|
||||
# Generate API download URL
|
||||
api_download_url = generate_api_download_url(request, str(image.id))
|
||||
|
||||
# Convert to response
|
||||
response = ImageResponse(
|
||||
id=str(image.id),
|
||||
filename=image.filename,
|
||||
original_filename=image.original_filename,
|
||||
file_size=image.file_size,
|
||||
content_type=image.content_type,
|
||||
storage_path=image.storage_path,
|
||||
public_url=api_download_url,
|
||||
team_id=str(image.team_id),
|
||||
uploader_id=str(image.uploader_id),
|
||||
upload_date=image.upload_date,
|
||||
description=image.description,
|
||||
metadata=image.metadata,
|
||||
has_embedding=image.has_embedding,
|
||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
||||
last_accessed=image.last_accessed
|
||||
)
|
||||
|
||||
return response
|
||||
response = await image_service.get_image(image_id, current_user, request)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting image: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{image_id}/download")
|
||||
async def download_image(
|
||||
@ -280,33 +125,23 @@ async def download_image(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Check team access (admins can access any image)
|
||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to access this image")
|
||||
|
||||
# Get file from storage
|
||||
file_content = storage_service.get_file(image.storage_path)
|
||||
if not file_content:
|
||||
raise HTTPException(status_code=404, detail="Image file not found in storage")
|
||||
|
||||
# Update last accessed
|
||||
await image_repository.update_last_accessed(obj_id)
|
||||
|
||||
# Return file as streaming response
|
||||
return StreamingResponse(
|
||||
io.BytesIO(file_content),
|
||||
media_type=image.content_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={image.original_filename}"}
|
||||
)
|
||||
file_content, content_type, filename = await image_service.download_image(image_id, current_user)
|
||||
|
||||
# Return file as streaming response
|
||||
return StreamingResponse(
|
||||
io.BytesIO(file_content),
|
||||
media_type=content_type,
|
||||
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error downloading image: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.put("/{image_id}", response_model=ImageResponse)
|
||||
async def update_image(
|
||||
@ -325,68 +160,17 @@ async def update_image(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Check team access (admins can update any image)
|
||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to update this image")
|
||||
|
||||
# Update image
|
||||
update_data = image_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
api_download_url = generate_api_download_url(request, str(image.id))
|
||||
response = ImageResponse(
|
||||
id=str(image.id),
|
||||
filename=image.filename,
|
||||
original_filename=image.original_filename,
|
||||
file_size=image.file_size,
|
||||
content_type=image.content_type,
|
||||
storage_path=image.storage_path,
|
||||
public_url=api_download_url,
|
||||
team_id=str(image.team_id),
|
||||
uploader_id=str(image.uploader_id),
|
||||
upload_date=image.upload_date,
|
||||
description=image.description,
|
||||
metadata=image.metadata,
|
||||
has_embedding=image.has_embedding,
|
||||
collection_id=str(image.collection_id) if image.collection_id else None
|
||||
)
|
||||
response = await image_service.update_image(image_id, image_data, current_user, request)
|
||||
return response
|
||||
|
||||
updated_image = await image_repository.update(obj_id, update_data)
|
||||
if not updated_image:
|
||||
raise HTTPException(status_code=500, detail="Failed to update image")
|
||||
|
||||
# Generate API download URL
|
||||
api_download_url = generate_api_download_url(request, str(updated_image.id))
|
||||
|
||||
# Convert to response
|
||||
response = ImageResponse(
|
||||
id=str(updated_image.id),
|
||||
filename=updated_image.filename,
|
||||
original_filename=updated_image.original_filename,
|
||||
file_size=updated_image.file_size,
|
||||
content_type=updated_image.content_type,
|
||||
storage_path=updated_image.storage_path,
|
||||
public_url=api_download_url,
|
||||
team_id=str(updated_image.team_id),
|
||||
uploader_id=str(updated_image.uploader_id),
|
||||
upload_date=updated_image.upload_date,
|
||||
description=updated_image.description,
|
||||
metadata=updated_image.metadata,
|
||||
has_embedding=updated_image.has_embedding,
|
||||
collection_id=str(updated_image.collection_id) if updated_image.collection_id else None
|
||||
)
|
||||
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating image: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.delete("/{image_id}", status_code=204)
|
||||
async def delete_image(
|
||||
@ -404,35 +188,14 @@ async def delete_image(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise HTTPException(status_code=404, detail="Image not found")
|
||||
|
||||
# Check team access (admins can delete any image)
|
||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
||||
raise HTTPException(status_code=403, detail="Not authorized to delete this image")
|
||||
|
||||
# Delete from storage
|
||||
try:
|
||||
storage_service.delete_file(image.storage_path)
|
||||
await image_service.delete_image(image_id, current_user)
|
||||
return None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete file from storage: {e}")
|
||||
|
||||
# Delete from vector database if it has embeddings
|
||||
if image.has_embedding and image.embedding_id:
|
||||
try:
|
||||
await embedding_service.delete_embedding(image.embedding_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete embedding: {e}")
|
||||
|
||||
# Delete from database
|
||||
result = await image_repository.delete(obj_id)
|
||||
if not result:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete image")
|
||||
|
||||
return None
|
||||
logger.error(f"Unexpected error deleting image: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@ -3,12 +3,8 @@ from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
||||
|
||||
from src.auth.security import get_current_user
|
||||
from src.services.vector_db import VectorDatabaseService
|
||||
from src.services.embedding_service import EmbeddingService
|
||||
from src.db.repositories.image_repository import image_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.services.search_service import SearchService
|
||||
from src.models.user import UserModel
|
||||
from src.schemas.image import ImageResponse
|
||||
from src.schemas.search import SearchResponse, SearchRequest
|
||||
from src.utils.logging import log_request
|
||||
|
||||
@ -16,17 +12,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Search"], prefix="/search")
|
||||
|
||||
# Initialize services - delay VectorDatabaseService instantiation
|
||||
vector_db_service = None
|
||||
embedding_service = EmbeddingService()
|
||||
|
||||
def get_vector_db_service():
|
||||
"""Get or create the vector database service instance"""
|
||||
global vector_db_service
|
||||
if vector_db_service is None:
|
||||
logger.info("Initializing VectorDatabaseService...")
|
||||
vector_db_service = VectorDatabaseService()
|
||||
return vector_db_service
|
||||
# Initialize service
|
||||
search_service = SearchService()
|
||||
|
||||
@router.get("", response_model=SearchResponse)
|
||||
async def search_images(
|
||||
@ -53,82 +40,22 @@ async def search_images(
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate embedding for the search query
|
||||
query_embedding = await embedding_service.generate_text_embedding(q)
|
||||
if not query_embedding:
|
||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||
|
||||
# Search in vector database
|
||||
search_results = get_vector_db_service().search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
response = await search_service.search_images(
|
||||
query=q,
|
||||
user=current_user,
|
||||
request=request,
|
||||
limit=limit,
|
||||
similarity_threshold=similarity_threshold,
|
||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||
collection_id=collection_id
|
||||
)
|
||||
|
||||
if not search_results:
|
||||
return SearchResponse(
|
||||
query=q,
|
||||
results=[],
|
||||
total=0,
|
||||
limit=limit,
|
||||
similarity_threshold=similarity_threshold
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
||||
|
||||
# Get image metadata from database
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
|
||||
# Filter by collection if specified
|
||||
filtered_images = []
|
||||
for image in images:
|
||||
# Check collection filter
|
||||
if collection_id and str(image.collection_id) != collection_id:
|
||||
continue
|
||||
|
||||
filtered_images.append(image)
|
||||
|
||||
# Convert to response format with similarity scores
|
||||
results = []
|
||||
for image in filtered_images:
|
||||
image_id = str(image.id)
|
||||
similarity_score = scores.get(image_id, 0.0)
|
||||
|
||||
result = ImageResponse(
|
||||
id=image_id,
|
||||
filename=image.filename,
|
||||
original_filename=image.original_filename,
|
||||
file_size=image.file_size,
|
||||
content_type=image.content_type,
|
||||
storage_path=image.storage_path,
|
||||
team_id=str(image.team_id),
|
||||
uploader_id=str(image.uploader_id),
|
||||
upload_date=image.upload_date,
|
||||
description=image.description,
|
||||
metadata=image.metadata,
|
||||
has_embedding=image.has_embedding,
|
||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
||||
similarity_score=similarity_score
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Sort by similarity score (highest first)
|
||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||
|
||||
return SearchResponse(
|
||||
query=q,
|
||||
results=results,
|
||||
total=len(results),
|
||||
limit=limit,
|
||||
similarity_threshold=similarity_threshold
|
||||
)
|
||||
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching images: {e}")
|
||||
raise HTTPException(status_code=500, detail="Search failed")
|
||||
logger.error(f"Unexpected error in search: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("", response_model=SearchResponse)
|
||||
async def search_images_advanced(
|
||||
@ -150,108 +77,16 @@ async def search_images_advanced(
|
||||
)
|
||||
|
||||
try:
|
||||
# Generate embedding for the search query
|
||||
logger.info(f"Generating embedding for query: {search_request.query}")
|
||||
query_embedding = await embedding_service.generate_text_embedding(search_request.query)
|
||||
if not query_embedding:
|
||||
logger.error("Failed to generate search embedding")
|
||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||
|
||||
logger.info(f"Generated embedding with length: {len(query_embedding)}")
|
||||
|
||||
# Search in vector database
|
||||
logger.info(f"Searching vector database with similarity_threshold: {search_request.similarity_threshold}")
|
||||
search_results = get_vector_db_service().search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
limit=search_request.limit,
|
||||
similarity_threshold=search_request.similarity_threshold,
|
||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||
response = await search_service.search_images_advanced(
|
||||
search_request=search_request,
|
||||
user=current_user,
|
||||
request=request
|
||||
)
|
||||
|
||||
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
|
||||
|
||||
if not search_results:
|
||||
logger.info("No search results from vector database, returning empty response")
|
||||
return SearchResponse(
|
||||
query=search_request.query,
|
||||
results=[],
|
||||
total=0,
|
||||
limit=search_request.limit,
|
||||
similarity_threshold=search_request.similarity_threshold
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
||||
|
||||
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
|
||||
|
||||
# Get image metadata from database
|
||||
logger.info("Fetching image metadata from database...")
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
logger.info(f"Retrieved {len(images)} images from database")
|
||||
|
||||
# Apply filters
|
||||
filtered_images = []
|
||||
for image in images:
|
||||
# Check collection filter
|
||||
if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
|
||||
continue
|
||||
|
||||
# Check date range filter
|
||||
if search_request.date_from and image.upload_date < search_request.date_from:
|
||||
continue
|
||||
|
||||
if search_request.date_to and image.upload_date > search_request.date_to:
|
||||
continue
|
||||
|
||||
# Check uploader filter
|
||||
if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
|
||||
continue
|
||||
|
||||
filtered_images.append(image)
|
||||
|
||||
logger.info(f"After filtering: {len(filtered_images)} images remain")
|
||||
|
||||
# Convert to response format with similarity scores
|
||||
results = []
|
||||
for image in filtered_images:
|
||||
image_id = str(image.id)
|
||||
similarity_score = scores.get(image_id, 0.0)
|
||||
|
||||
result = ImageResponse(
|
||||
id=image_id,
|
||||
filename=image.filename,
|
||||
original_filename=image.original_filename,
|
||||
file_size=image.file_size,
|
||||
content_type=image.content_type,
|
||||
storage_path=image.storage_path,
|
||||
team_id=str(image.team_id),
|
||||
uploader_id=str(image.uploader_id),
|
||||
upload_date=image.upload_date,
|
||||
description=image.description,
|
||||
metadata=image.metadata,
|
||||
has_embedding=image.has_embedding,
|
||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
||||
similarity_score=similarity_score
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
# Sort by similarity score (highest first)
|
||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||
|
||||
logger.info(f"Returning {len(results)} results")
|
||||
|
||||
return SearchResponse(
|
||||
query=search_request.query,
|
||||
results=results,
|
||||
total=len(results),
|
||||
limit=search_request.limit,
|
||||
similarity_threshold=search_request.similarity_threshold
|
||||
)
|
||||
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Error in advanced search: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail="Advanced search failed")
|
||||
logger.error(f"Unexpected error in advanced search: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@ -2,15 +2,17 @@ import logging
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.services.team_service import TeamService
|
||||
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
|
||||
from src.models.team import TeamModel
|
||||
from src.utils.logging import log_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Teams"], prefix="/teams")
|
||||
|
||||
# Initialize service
|
||||
team_service = TeamService()
|
||||
|
||||
@router.post("", response_model=TeamResponse, status_code=201)
|
||||
async def create_team(team_data: TeamCreate, request: Request):
|
||||
"""
|
||||
@ -22,24 +24,12 @@ async def create_team(team_data: TeamCreate, request: Request):
|
||||
{"path": request.url.path, "method": request.method, "team_data": team_data.dict()}
|
||||
)
|
||||
|
||||
# Create team
|
||||
team = TeamModel(
|
||||
name=team_data.name,
|
||||
description=team_data.description
|
||||
)
|
||||
|
||||
created_team = await team_repository.create(team)
|
||||
|
||||
# Convert to response model
|
||||
response = TeamResponse(
|
||||
id=str(created_team.id),
|
||||
name=created_team.name,
|
||||
description=created_team.description,
|
||||
created_at=created_team.created_at,
|
||||
updated_at=created_team.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
try:
|
||||
response = await team_service.create_team(team_data)
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating team: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("", response_model=TeamListResponse)
|
||||
async def list_teams(request: Request):
|
||||
@ -52,21 +42,12 @@ async def list_teams(request: Request):
|
||||
{"path": request.url.path, "method": request.method}
|
||||
)
|
||||
|
||||
# Get all teams
|
||||
teams = await team_repository.get_all()
|
||||
|
||||
# Convert to response models
|
||||
response_teams = []
|
||||
for team in teams:
|
||||
response_teams.append(TeamResponse(
|
||||
id=str(team.id),
|
||||
name=team.name,
|
||||
description=team.description,
|
||||
created_at=team.created_at,
|
||||
updated_at=team.updated_at
|
||||
))
|
||||
|
||||
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
||||
try:
|
||||
response = await team_service.list_teams()
|
||||
return response
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error listing teams: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{team_id}", response_model=TeamResponse)
|
||||
async def get_team(team_id: str, request: Request):
|
||||
@ -80,26 +61,15 @@ async def get_team(team_id: str, request: Request):
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert string ID to ObjectId
|
||||
obj_id = ObjectId(team_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||
|
||||
# Get the team
|
||||
team = await team_repository.get_by_id(obj_id)
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
|
||||
# Convert to response model
|
||||
response = TeamResponse(
|
||||
id=str(team.id),
|
||||
name=team.name,
|
||||
description=team.description,
|
||||
created_at=team.created_at,
|
||||
updated_at=team.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
response = await team_service.get_team(team_id)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting team: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.put("/{team_id}", response_model=TeamResponse)
|
||||
async def update_team(team_id: str, team_data: TeamUpdate, request: Request):
|
||||
@ -113,42 +83,15 @@ async def update_team(team_id: str, team_data: TeamUpdate, request: Request):
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert string ID to ObjectId
|
||||
obj_id = ObjectId(team_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||
|
||||
# Get the team
|
||||
team = await team_repository.get_by_id(obj_id)
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
|
||||
# Update the team
|
||||
update_data = team_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
return TeamResponse(
|
||||
id=str(team.id),
|
||||
name=team.name,
|
||||
description=team.description,
|
||||
created_at=team.created_at,
|
||||
updated_at=team.updated_at
|
||||
)
|
||||
|
||||
updated_team = await team_repository.update(obj_id, update_data)
|
||||
if not updated_team:
|
||||
raise HTTPException(status_code=500, detail="Failed to update team")
|
||||
|
||||
# Convert to response model
|
||||
response = TeamResponse(
|
||||
id=str(updated_team.id),
|
||||
name=updated_team.name,
|
||||
description=updated_team.description,
|
||||
created_at=updated_team.created_at,
|
||||
updated_at=updated_team.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
response = await team_service.update_team(team_id, team_data)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating team: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.delete("/{team_id}", status_code=204)
|
||||
async def delete_team(team_id: str, request: Request):
|
||||
@ -162,17 +105,12 @@ async def delete_team(team_id: str, request: Request):
|
||||
)
|
||||
|
||||
try:
|
||||
# Convert string ID to ObjectId
|
||||
obj_id = ObjectId(team_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||
|
||||
# Get the team
|
||||
team = await team_repository.get_by_id(obj_id)
|
||||
if not team:
|
||||
raise HTTPException(status_code=404, detail="Team not found")
|
||||
|
||||
# Delete the team
|
||||
success = await team_repository.delete(obj_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete team")
|
||||
await team_service.delete_team(team_id)
|
||||
return None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting team: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@ -1,13 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from bson import ObjectId
|
||||
|
||||
# Remove the auth import since we're removing authentication
|
||||
# from src.api.v1.auth import get_current_user
|
||||
from src.db.repositories.user_repository import user_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.models.user import UserModel
|
||||
from src.services.user_service import UserService
|
||||
from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
|
||||
from src.utils.logging import log_request
|
||||
|
||||
@ -15,6 +10,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["Users"], prefix="/users")
|
||||
|
||||
# Initialize service
|
||||
user_service = UserService()
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def read_users_me(
|
||||
request: Request,
|
||||
@ -26,26 +24,15 @@ async def read_users_me(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
current_user = await user_repository.get_by_id(obj_id)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
response = UserResponse(
|
||||
id=str(current_user.id),
|
||||
name=current_user.name,
|
||||
email=current_user.email,
|
||||
team_id=str(current_user.team_id),
|
||||
is_admin=current_user.is_admin,
|
||||
is_active=current_user.is_active,
|
||||
created_at=current_user.created_at,
|
||||
updated_at=current_user.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
response = await user_service.get_user_by_id(user_id)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.put("/me", response_model=UserResponse)
|
||||
async def update_current_user(
|
||||
@ -59,46 +46,15 @@ async def update_current_user(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
current_user = await user_repository.get_by_id(obj_id)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Update user
|
||||
update_data = user_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
response = UserResponse(
|
||||
id=str(current_user.id),
|
||||
name=current_user.name,
|
||||
email=current_user.email,
|
||||
team_id=str(current_user.team_id),
|
||||
is_admin=current_user.is_admin,
|
||||
is_active=current_user.is_active,
|
||||
created_at=current_user.created_at,
|
||||
updated_at=current_user.updated_at
|
||||
)
|
||||
response = await user_service.update_user_by_id(user_id, user_data)
|
||||
return response
|
||||
|
||||
updated_user = await user_repository.update(current_user.id, update_data)
|
||||
if not updated_user:
|
||||
raise HTTPException(status_code=500, detail="Failed to update user")
|
||||
|
||||
response = UserResponse(
|
||||
id=str(updated_user.id),
|
||||
name=updated_user.name,
|
||||
email=updated_user.email,
|
||||
team_id=str(updated_user.team_id),
|
||||
is_admin=updated_user.is_admin,
|
||||
is_active=updated_user.is_active,
|
||||
created_at=updated_user.created_at,
|
||||
updated_at=updated_user.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.post("", response_model=UserResponse, status_code=201)
|
||||
async def create_user(
|
||||
@ -114,43 +70,16 @@ async def create_user(
|
||||
{"path": request.url.path, "method": request.method, "user_data": user_data.dict()}
|
||||
)
|
||||
|
||||
# Check if user with email already exists
|
||||
existing_user = await user_repository.get_by_email(user_data.email)
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=400, detail="User with this email already exists")
|
||||
|
||||
# Validate team exists if specified
|
||||
if user_data.team_id:
|
||||
team = await team_repository.get_by_id(ObjectId(user_data.team_id))
|
||||
if not team:
|
||||
raise HTTPException(status_code=400, detail="Team not found")
|
||||
team_id = user_data.team_id
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Team ID is required")
|
||||
|
||||
# Create user
|
||||
user = UserModel(
|
||||
name=user_data.name,
|
||||
email=user_data.email,
|
||||
team_id=ObjectId(team_id),
|
||||
is_admin=user_data.is_admin or False,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
created_user = await user_repository.create(user)
|
||||
|
||||
response = UserResponse(
|
||||
id=str(created_user.id),
|
||||
name=created_user.name,
|
||||
email=created_user.email,
|
||||
team_id=str(created_user.team_id),
|
||||
is_admin=created_user.is_admin,
|
||||
is_active=created_user.is_active,
|
||||
created_at=created_user.created_at,
|
||||
updated_at=created_user.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
try:
|
||||
response = await user_service.create_user(user_data)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error creating user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("", response_model=UserListResponse)
|
||||
async def list_users(
|
||||
@ -166,31 +95,14 @@ async def list_users(
|
||||
{"path": request.url.path, "method": request.method, "team_id": team_id}
|
||||
)
|
||||
|
||||
# Get users
|
||||
if team_id:
|
||||
try:
|
||||
filter_team_id = ObjectId(team_id)
|
||||
users = await user_repository.get_by_team(filter_team_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
||||
else:
|
||||
users = await user_repository.get_all()
|
||||
|
||||
# Convert to response
|
||||
response_users = []
|
||||
for user in users:
|
||||
response_users.append(UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
))
|
||||
|
||||
return UserListResponse(users=response_users, total=len(response_users))
|
||||
try:
|
||||
response = await user_service.list_users(team_id)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error listing users: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
@ -207,26 +119,15 @@ async def get_user(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
response = UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
response = await user_service.get_user(user_id)
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error getting user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.put("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
@ -244,47 +145,15 @@ async def update_user(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
# Check if user exists
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Update user
|
||||
update_data = user_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
response = UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
)
|
||||
response = await user_service.update_user(user_id, user_data)
|
||||
return response
|
||||
|
||||
updated_user = await user_repository.update(obj_id, update_data)
|
||||
if not updated_user:
|
||||
raise HTTPException(status_code=500, detail="Failed to update user")
|
||||
|
||||
response = UserResponse(
|
||||
id=str(updated_user.id),
|
||||
name=updated_user.name,
|
||||
email=updated_user.email,
|
||||
team_id=str(updated_user.team_id),
|
||||
is_admin=updated_user.is_admin,
|
||||
is_active=updated_user.is_active,
|
||||
created_at=updated_user.created_at,
|
||||
updated_at=updated_user.updated_at
|
||||
)
|
||||
|
||||
return response
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error updating user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@router.delete("/{user_id}", status_code=204)
|
||||
async def delete_user(
|
||||
@ -301,16 +170,12 @@ async def delete_user(
|
||||
)
|
||||
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
# Check if user exists
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Delete user
|
||||
success = await user_repository.delete(obj_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=500, detail="Failed to delete user")
|
||||
await user_service.delete_user(user_id)
|
||||
return None
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error deleting user: {e}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
255
src/services/auth_service.py
Normal file
255
src/services/auth_service.py
Normal file
@ -0,0 +1,255 @@
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db.repositories.api_key_repository import api_key_repository
|
||||
from src.db.repositories.user_repository import user_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
|
||||
from src.auth.security import generate_api_key, verify_api_key, calculate_expiry_date, is_expired, hash_api_key
|
||||
from src.models.api_key import ApiKeyModel
|
||||
from src.models.team import TeamModel
|
||||
from src.models.user import UserModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AuthService:
|
||||
"""Service class for handling authentication-related business logic"""
|
||||
|
||||
async def create_api_key_for_user_and_team(
|
||||
self,
|
||||
user_id: str,
|
||||
team_id: str,
|
||||
key_data: ApiKeyCreate
|
||||
) -> ApiKeyWithValueResponse:
|
||||
"""
|
||||
Create a new API key for a specific user and team
|
||||
|
||||
Args:
|
||||
user_id: The user ID to create the key for
|
||||
team_id: The team ID the user belongs to
|
||||
key_data: The API key creation data
|
||||
|
||||
Returns:
|
||||
ApiKeyWithValueResponse: The created API key with the raw key value
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id or team_id are invalid
|
||||
RuntimeError: If user or team not found, or user doesn't belong to team
|
||||
"""
|
||||
# Validate user_id and team_id
|
||||
try:
|
||||
target_user_id = ObjectId(user_id)
|
||||
target_team_id = ObjectId(team_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID or team ID")
|
||||
|
||||
# Verify user exists
|
||||
target_user = await user_repository.get_by_id(target_user_id)
|
||||
if not target_user:
|
||||
raise RuntimeError("User not found")
|
||||
|
||||
# Verify team exists
|
||||
team = await team_repository.get_by_id(target_team_id)
|
||||
if not team:
|
||||
raise RuntimeError("Team not found")
|
||||
|
||||
# Verify user belongs to the team
|
||||
if target_user.team_id != target_team_id:
|
||||
raise RuntimeError("User does not belong to the specified team")
|
||||
|
||||
# Validate key_data consistency
|
||||
if key_data.team_id and key_data.team_id != team_id:
|
||||
raise ValueError("Team ID in request body does not match parameter")
|
||||
|
||||
if key_data.user_id and key_data.user_id != user_id:
|
||||
raise ValueError("User ID in request body does not match parameter")
|
||||
|
||||
# Generate API key with expiry date
|
||||
raw_key, hashed_key = generate_api_key(str(target_team_id), str(target_user_id))
|
||||
expiry_date = calculate_expiry_date()
|
||||
|
||||
# Create API key in database
|
||||
api_key = ApiKeyModel(
|
||||
key_hash=hashed_key,
|
||||
user_id=target_user_id,
|
||||
team_id=target_team_id,
|
||||
name=key_data.name,
|
||||
description=key_data.description,
|
||||
expiry_date=expiry_date,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
created_key = await api_key_repository.create(api_key)
|
||||
|
||||
# Convert to response model
|
||||
return ApiKeyWithValueResponse(
|
||||
id=str(created_key.id),
|
||||
key=raw_key,
|
||||
name=created_key.name,
|
||||
description=created_key.description,
|
||||
team_id=str(created_key.team_id),
|
||||
user_id=str(created_key.user_id),
|
||||
created_at=created_key.created_at,
|
||||
expiry_date=created_key.expiry_date,
|
||||
last_used=created_key.last_used,
|
||||
is_active=created_key.is_active
|
||||
)
|
||||
|
||||
async def create_api_key_for_user_by_admin(
|
||||
self,
|
||||
target_user_id: str,
|
||||
key_data: ApiKeyCreate,
|
||||
admin_user: UserModel
|
||||
) -> ApiKeyWithValueResponse:
|
||||
"""
|
||||
Create a new API key for a specific user (admin only)
|
||||
|
||||
Args:
|
||||
target_user_id: The user ID to create the key for
|
||||
key_data: The API key creation data
|
||||
admin_user: The admin user performing the action
|
||||
|
||||
Returns:
|
||||
ApiKeyWithValueResponse: The created API key with the raw key value
|
||||
|
||||
Raises:
|
||||
PermissionError: If the admin user doesn't have admin privileges
|
||||
ValueError: If target_user_id is invalid
|
||||
RuntimeError: If target user or team not found
|
||||
"""
|
||||
# Check if current user is admin
|
||||
if not admin_user.is_admin:
|
||||
raise PermissionError("Admin access required")
|
||||
|
||||
try:
|
||||
target_user_obj_id = ObjectId(target_user_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID")
|
||||
|
||||
# Get the target user
|
||||
target_user = await user_repository.get_by_id(target_user_obj_id)
|
||||
if not target_user:
|
||||
raise RuntimeError("Target user not found")
|
||||
|
||||
# Check if target user's team exists
|
||||
team = await team_repository.get_by_id(target_user.team_id)
|
||||
if not team:
|
||||
raise RuntimeError("Target user's team not found")
|
||||
|
||||
# Generate API key with expiry date
|
||||
raw_key, hashed_key = generate_api_key(str(target_user.team_id), str(target_user.id))
|
||||
expiry_date = calculate_expiry_date()
|
||||
|
||||
# Create API key in database
|
||||
api_key = ApiKeyModel(
|
||||
key_hash=hashed_key,
|
||||
user_id=target_user.id,
|
||||
team_id=target_user.team_id,
|
||||
name=key_data.name,
|
||||
description=key_data.description,
|
||||
expiry_date=expiry_date,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
created_key = await api_key_repository.create(api_key)
|
||||
|
||||
# Convert to response model
|
||||
return ApiKeyWithValueResponse(
|
||||
id=str(created_key.id),
|
||||
key=raw_key,
|
||||
name=created_key.name,
|
||||
description=created_key.description,
|
||||
team_id=str(created_key.team_id),
|
||||
user_id=str(created_key.user_id),
|
||||
created_at=created_key.created_at,
|
||||
expiry_date=created_key.expiry_date,
|
||||
last_used=created_key.last_used,
|
||||
is_active=created_key.is_active
|
||||
)
|
||||
|
||||
async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse:
|
||||
"""
|
||||
List API keys for a specific user
|
||||
|
||||
Args:
|
||||
user: The user to list API keys for
|
||||
|
||||
Returns:
|
||||
ApiKeyListResponse: List of API keys for the user
|
||||
"""
|
||||
# Get API keys for user
|
||||
keys = await api_key_repository.get_by_user(user.id)
|
||||
|
||||
# Convert to response models
|
||||
response_keys = []
|
||||
for key in keys:
|
||||
response_keys.append(ApiKeyResponse(
|
||||
id=str(key.id),
|
||||
name=key.name,
|
||||
description=key.description,
|
||||
team_id=str(key.team_id),
|
||||
user_id=str(key.user_id),
|
||||
created_at=key.created_at,
|
||||
expiry_date=key.expiry_date,
|
||||
last_used=key.last_used,
|
||||
is_active=key.is_active
|
||||
))
|
||||
|
||||
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
||||
|
||||
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
|
||||
"""
|
||||
Revoke (deactivate) an API key
|
||||
|
||||
Args:
|
||||
key_id: The ID of the API key to revoke
|
||||
user: The user attempting to revoke the key
|
||||
|
||||
Returns:
|
||||
bool: True if successfully revoked
|
||||
|
||||
Raises:
|
||||
ValueError: If key_id is invalid
|
||||
RuntimeError: If key not found
|
||||
PermissionError: If user not authorized to revoke the key
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(key_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid key ID")
|
||||
|
||||
# Get the API key
|
||||
key = await api_key_repository.get_by_id(obj_id)
|
||||
if not key:
|
||||
raise RuntimeError("API key not found")
|
||||
|
||||
# Check if user owns the key or is an admin
|
||||
if key.user_id != user.id and not user.is_admin:
|
||||
raise PermissionError("Not authorized to revoke this API key")
|
||||
|
||||
# Deactivate the key
|
||||
result = await api_key_repository.deactivate(obj_id)
|
||||
if not result:
|
||||
raise RuntimeError("Failed to revoke API key")
|
||||
|
||||
return True
|
||||
|
||||
async def verify_user_authentication(self, user: UserModel) -> dict:
|
||||
"""
|
||||
Verify and return user authentication information
|
||||
|
||||
Args:
|
||||
user: The authenticated user
|
||||
|
||||
Returns:
|
||||
dict: User authentication information
|
||||
"""
|
||||
return {
|
||||
"user_id": str(user.id),
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
"team_id": str(user.team_id),
|
||||
"is_admin": user.is_admin
|
||||
}
|
||||
370
src/services/image_service.py
Normal file
370
src/services/image_service.py
Normal file
@ -0,0 +1,370 @@
|
||||
import logging
|
||||
from typing import Optional, List, Tuple
|
||||
from fastapi import UploadFile, Request
|
||||
from bson import ObjectId
|
||||
import io
|
||||
|
||||
from src.db.repositories.image_repository import image_repository
|
||||
from src.services.storage import StorageService
|
||||
from src.services.image_processor import ImageProcessor
|
||||
from src.services.embedding_service import EmbeddingService
|
||||
from src.services.pubsub_service import pubsub_service
|
||||
from src.models.image import ImageModel
|
||||
from src.models.user import UserModel
|
||||
from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ImageService:
|
||||
"""Service class for handling image-related business logic"""
|
||||
|
||||
def __init__(self):
|
||||
self.storage_service = StorageService()
|
||||
self.image_processor = ImageProcessor()
|
||||
self.embedding_service = EmbeddingService()
|
||||
|
||||
def _generate_api_download_url(self, request: Request, image_id: str) -> str:
|
||||
"""Generate API download URL for an image"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
return f"{base_url}/api/v1/images/{image_id}/download"
|
||||
|
||||
async def upload_image(
|
||||
self,
|
||||
file: UploadFile,
|
||||
user: UserModel,
|
||||
request: Request,
|
||||
description: Optional[str] = None,
|
||||
collection_id: Optional[str] = None
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Upload a new image
|
||||
|
||||
Args:
|
||||
file: The uploaded file
|
||||
user: The user uploading the image
|
||||
request: The FastAPI request object for URL generation
|
||||
description: Optional description for the image
|
||||
collection_id: Optional collection ID to associate with the image
|
||||
|
||||
Returns:
|
||||
ImageResponse: The created image metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If file validation fails
|
||||
RuntimeError: If upload fails
|
||||
"""
|
||||
# Validate file type
|
||||
if not file.content_type or not file.content_type.startswith('image/'):
|
||||
raise ValueError("File must be an image")
|
||||
|
||||
# Validate file size (10MB limit)
|
||||
max_size = 10 * 1024 * 1024 # 10MB
|
||||
content = await file.read()
|
||||
if len(content) > max_size:
|
||||
raise ValueError("File size exceeds 10MB limit")
|
||||
|
||||
# Reset file pointer
|
||||
await file.seek(0)
|
||||
|
||||
try:
|
||||
# Upload to storage
|
||||
storage_path, content_type, file_size, metadata = await self.storage_service.upload_file(
|
||||
file, str(user.team_id)
|
||||
)
|
||||
|
||||
# Create image record
|
||||
image = ImageModel(
|
||||
filename=file.filename,
|
||||
original_filename=file.filename,
|
||||
file_size=file_size,
|
||||
content_type=content_type,
|
||||
storage_path=storage_path,
|
||||
public_url=None, # Will be set after we have the image ID
|
||||
team_id=user.team_id,
|
||||
uploader_id=user.id,
|
||||
description=description,
|
||||
metadata=metadata,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None
|
||||
)
|
||||
|
||||
# Save to database
|
||||
created_image = await image_repository.create(image)
|
||||
|
||||
# Generate API download URL now that we have the image ID
|
||||
api_download_url = self._generate_api_download_url(request, str(created_image.id))
|
||||
|
||||
# Update the image with the API download URL
|
||||
await image_repository.update(created_image.id, {"public_url": api_download_url})
|
||||
created_image.public_url = api_download_url
|
||||
|
||||
# Publish image processing task to Pub/Sub
|
||||
try:
|
||||
task_published = await pubsub_service.publish_image_processing_task(
|
||||
image_id=str(created_image.id),
|
||||
storage_path=storage_path,
|
||||
team_id=str(user.team_id)
|
||||
)
|
||||
if not task_published:
|
||||
logger.warning(f"Failed to publish processing task for image {created_image.id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to publish image processing task: {e}")
|
||||
|
||||
# Convert to response
|
||||
return self._convert_to_response(created_image, request)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading image: {e}")
|
||||
raise RuntimeError("Failed to upload image")
|
||||
|
||||
async def list_images(
|
||||
self,
|
||||
user: UserModel,
|
||||
request: Request,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
collection_id: Optional[str] = None
|
||||
) -> ImageListResponse:
|
||||
"""
|
||||
List images for the user's team or all images if user is admin
|
||||
|
||||
Args:
|
||||
user: The requesting user
|
||||
request: The FastAPI request object for URL generation
|
||||
skip: Number of records to skip for pagination
|
||||
limit: Maximum number of records to return
|
||||
collection_id: Optional filter by collection ID
|
||||
|
||||
Returns:
|
||||
ImageListResponse: List of images with pagination metadata
|
||||
"""
|
||||
# Check if user is admin - if so, get all images across all teams
|
||||
if user.is_admin:
|
||||
images = await image_repository.get_all_with_pagination(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
total = await image_repository.count_all(
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
else:
|
||||
# Regular users only see images from their team
|
||||
images = await image_repository.get_by_team(
|
||||
user.team_id,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
total = await image_repository.count_by_team(
|
||||
user.team_id,
|
||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||
)
|
||||
|
||||
# Convert to response
|
||||
response_images = [self._convert_to_response(image, request) for image in images]
|
||||
|
||||
return ImageListResponse(images=response_images, total=total, skip=skip, limit=limit)
|
||||
|
||||
async def get_image(self, image_id: str, user: UserModel, request: Request) -> ImageResponse:
|
||||
"""
|
||||
Get image metadata by ID
|
||||
|
||||
Args:
|
||||
image_id: The image ID to retrieve
|
||||
user: The requesting user
|
||||
request: The FastAPI request object for URL generation
|
||||
|
||||
Returns:
|
||||
ImageResponse: The image metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If image_id is invalid
|
||||
RuntimeError: If image not found
|
||||
PermissionError: If user not authorized to access the image
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise RuntimeError("Image not found")
|
||||
|
||||
# Check team access (admins can access any image)
|
||||
if not user.is_admin and image.team_id != user.team_id:
|
||||
raise PermissionError("Not authorized to access this image")
|
||||
|
||||
return self._convert_to_response(image, request, include_last_accessed=True)
|
||||
|
||||
async def download_image(self, image_id: str, user: UserModel) -> Tuple[bytes, str, str]:
|
||||
"""
|
||||
Download image file
|
||||
|
||||
Args:
|
||||
image_id: The image ID to download
|
||||
user: The requesting user
|
||||
|
||||
Returns:
|
||||
Tuple[bytes, str, str]: File content, content type, and filename
|
||||
|
||||
Raises:
|
||||
ValueError: If image_id is invalid
|
||||
RuntimeError: If image not found or file not found in storage
|
||||
PermissionError: If user not authorized to access the image
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise RuntimeError("Image not found")
|
||||
|
||||
# Check team access (admins can access any image)
|
||||
if not user.is_admin and image.team_id != user.team_id:
|
||||
raise PermissionError("Not authorized to access this image")
|
||||
|
||||
# Get file from storage
|
||||
file_content = self.storage_service.get_file(image.storage_path)
|
||||
if not file_content:
|
||||
raise RuntimeError("Image file not found in storage")
|
||||
|
||||
# Update last accessed
|
||||
await image_repository.update_last_accessed(obj_id)
|
||||
|
||||
return file_content, image.content_type, image.original_filename
|
||||
|
||||
async def update_image(
|
||||
self,
|
||||
image_id: str,
|
||||
image_data: ImageUpdate,
|
||||
user: UserModel,
|
||||
request: Request
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Update image metadata
|
||||
|
||||
Args:
|
||||
image_id: The image ID to update
|
||||
image_data: The update data
|
||||
user: The requesting user
|
||||
request: The FastAPI request object for URL generation
|
||||
|
||||
Returns:
|
||||
ImageResponse: The updated image metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If image_id is invalid
|
||||
RuntimeError: If image not found or update fails
|
||||
PermissionError: If user not authorized to update the image
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise RuntimeError("Image not found")
|
||||
|
||||
# Check team access (admins can update any image)
|
||||
if not user.is_admin and image.team_id != user.team_id:
|
||||
raise PermissionError("Not authorized to update this image")
|
||||
|
||||
# Update image
|
||||
update_data = image_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
return self._convert_to_response(image, request)
|
||||
|
||||
updated_image = await image_repository.update(obj_id, update_data)
|
||||
if not updated_image:
|
||||
raise RuntimeError("Failed to update image")
|
||||
|
||||
return self._convert_to_response(updated_image, request)
|
||||
|
||||
async def delete_image(self, image_id: str, user: UserModel) -> bool:
|
||||
"""
|
||||
Delete an image
|
||||
|
||||
Args:
|
||||
image_id: The image ID to delete
|
||||
user: The requesting user
|
||||
|
||||
Returns:
|
||||
bool: True if successfully deleted
|
||||
|
||||
Raises:
|
||||
ValueError: If image_id is invalid
|
||||
RuntimeError: If image not found or deletion fails
|
||||
PermissionError: If user not authorized to delete the image
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(image_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid image ID")
|
||||
|
||||
# Get image
|
||||
image = await image_repository.get_by_id(obj_id)
|
||||
if not image:
|
||||
raise RuntimeError("Image not found")
|
||||
|
||||
# Check team access (admins can delete any image)
|
||||
if not user.is_admin and image.team_id != user.team_id:
|
||||
raise PermissionError("Not authorized to delete this image")
|
||||
|
||||
# Delete from storage
|
||||
try:
|
||||
self.storage_service.delete_file(image.storage_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete file from storage: {e}")
|
||||
|
||||
# Delete from vector database if it has embeddings
|
||||
if image.has_embedding and image.embedding_id:
|
||||
try:
|
||||
await self.embedding_service.delete_embedding(image.embedding_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete embedding: {e}")
|
||||
|
||||
# Delete from database
|
||||
result = await image_repository.delete(obj_id)
|
||||
if not result:
|
||||
raise RuntimeError("Failed to delete image")
|
||||
|
||||
return True
|
||||
|
||||
def _convert_to_response(
|
||||
self,
|
||||
image: ImageModel,
|
||||
request: Request,
|
||||
include_last_accessed: bool = False
|
||||
) -> ImageResponse:
|
||||
"""Convert ImageModel to ImageResponse"""
|
||||
api_download_url = self._generate_api_download_url(request, str(image.id))
|
||||
|
||||
response_data = {
|
||||
"id": str(image.id),
|
||||
"filename": image.filename,
|
||||
"original_filename": image.original_filename,
|
||||
"file_size": image.file_size,
|
||||
"content_type": image.content_type,
|
||||
"storage_path": image.storage_path,
|
||||
"public_url": api_download_url,
|
||||
"team_id": str(image.team_id),
|
||||
"uploader_id": str(image.uploader_id),
|
||||
"upload_date": image.upload_date,
|
||||
"description": image.description,
|
||||
"metadata": image.metadata,
|
||||
"has_embedding": image.has_embedding,
|
||||
"collection_id": str(image.collection_id) if image.collection_id else None
|
||||
}
|
||||
|
||||
if include_last_accessed:
|
||||
response_data["last_accessed"] = image.last_accessed
|
||||
|
||||
return ImageResponse(**response_data)
|
||||
271
src/services/search_service.py
Normal file
271
src/services/search_service.py
Normal file
@ -0,0 +1,271 @@
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from fastapi import Request
|
||||
from bson import ObjectId
|
||||
|
||||
from src.services.vector_db import VectorDatabaseService
|
||||
from src.services.embedding_service import EmbeddingService
|
||||
from src.db.repositories.image_repository import image_repository
|
||||
from src.models.user import UserModel
|
||||
from src.schemas.image import ImageResponse
|
||||
from src.schemas.search import SearchResponse, SearchRequest
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SearchService:
|
||||
"""Service class for handling search-related business logic"""
|
||||
|
||||
def __init__(self):
|
||||
self.vector_db_service = None
|
||||
self.embedding_service = EmbeddingService()
|
||||
|
||||
def _get_vector_db_service(self):
|
||||
"""Get or create the vector database service instance"""
|
||||
if self.vector_db_service is None:
|
||||
logger.info("Initializing VectorDatabaseService...")
|
||||
self.vector_db_service = VectorDatabaseService()
|
||||
return self.vector_db_service
|
||||
|
||||
def _generate_api_download_url(self, request: Request, image_id: str) -> str:
|
||||
"""Generate API download URL for an image"""
|
||||
base_url = str(request.base_url).rstrip('/')
|
||||
return f"{base_url}/api/v1/images/{image_id}/download"
|
||||
|
||||
async def search_images(
|
||||
self,
|
||||
query: str,
|
||||
user: UserModel,
|
||||
request: Request,
|
||||
limit: int = 10,
|
||||
similarity_threshold: float = 0.65,
|
||||
collection_id: Optional[str] = None
|
||||
) -> SearchResponse:
|
||||
"""
|
||||
Search for images using semantic similarity
|
||||
|
||||
Args:
|
||||
query: The search query text
|
||||
user: The requesting user
|
||||
request: The FastAPI request object for URL generation
|
||||
limit: Number of results to return
|
||||
similarity_threshold: Similarity threshold for filtering results
|
||||
collection_id: Optional filter by collection ID
|
||||
|
||||
Returns:
|
||||
SearchResponse: Search results with similarity scores
|
||||
|
||||
Raises:
|
||||
ValueError: If query embedding generation fails
|
||||
RuntimeError: If search fails
|
||||
"""
|
||||
try:
|
||||
# Generate embedding for the search query
|
||||
query_embedding = await self.embedding_service.generate_text_embedding(query)
|
||||
if not query_embedding:
|
||||
raise ValueError("Failed to generate search embedding")
|
||||
|
||||
# Search in vector database
|
||||
search_results = self._get_vector_db_service().search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
similarity_threshold=similarity_threshold,
|
||||
filter_conditions={"team_id": str(user.team_id)} if user.team_id else None
|
||||
)
|
||||
|
||||
if not search_results:
|
||||
return SearchResponse(
|
||||
query=query,
|
||||
results=[],
|
||||
total=0,
|
||||
limit=limit,
|
||||
similarity_threshold=similarity_threshold
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
||||
|
||||
# Get image metadata from database
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
|
||||
# Filter by collection if specified
|
||||
filtered_images = self._filter_images_by_collection(images, collection_id)
|
||||
|
||||
# Convert to response format with similarity scores
|
||||
results = self._convert_images_to_search_results(filtered_images, scores, request)
|
||||
|
||||
# Sort by similarity score (highest first)
|
||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||
|
||||
return SearchResponse(
|
||||
query=query,
|
||||
results=results,
|
||||
total=len(results),
|
||||
limit=limit,
|
||||
similarity_threshold=similarity_threshold
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching images: {e}")
|
||||
raise RuntimeError("Search failed")
|
||||
|
||||
async def search_images_advanced(
|
||||
self,
|
||||
search_request: SearchRequest,
|
||||
user: UserModel,
|
||||
request: Request
|
||||
) -> SearchResponse:
|
||||
"""
|
||||
Advanced search for images with more filtering options
|
||||
|
||||
Args:
|
||||
search_request: The advanced search request with filters
|
||||
user: The requesting user
|
||||
request: The FastAPI request object for URL generation
|
||||
|
||||
Returns:
|
||||
SearchResponse: Search results with similarity scores
|
||||
|
||||
Raises:
|
||||
ValueError: If query embedding generation fails
|
||||
RuntimeError: If search fails
|
||||
"""
|
||||
try:
|
||||
# Generate embedding for the search query
|
||||
logger.info(f"Generating embedding for query: {search_request.query}")
|
||||
query_embedding = await self.embedding_service.generate_text_embedding(search_request.query)
|
||||
if not query_embedding:
|
||||
logger.error("Failed to generate search embedding")
|
||||
raise ValueError("Failed to generate search embedding")
|
||||
|
||||
logger.info(f"Generated embedding with length: {len(query_embedding)}")
|
||||
|
||||
# Search in vector database
|
||||
logger.info(f"Searching vector database with similarity_threshold: {search_request.similarity_threshold}")
|
||||
search_results = self._get_vector_db_service().search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
limit=search_request.limit,
|
||||
similarity_threshold=search_request.similarity_threshold,
|
||||
filter_conditions={"team_id": str(user.team_id)} if user.team_id else None
|
||||
)
|
||||
|
||||
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
|
||||
|
||||
if not search_results:
|
||||
logger.info("No search results from vector database, returning empty response")
|
||||
return SearchResponse(
|
||||
query=search_request.query,
|
||||
results=[],
|
||||
total=0,
|
||||
limit=search_request.limit,
|
||||
similarity_threshold=search_request.similarity_threshold
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
||||
|
||||
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
|
||||
|
||||
# Get image metadata from database
|
||||
logger.info("Fetching image metadata from database...")
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
logger.info(f"Retrieved {len(images)} images from database")
|
||||
|
||||
# Apply advanced filters
|
||||
filtered_images = self._apply_advanced_filters(images, search_request)
|
||||
|
||||
logger.info(f"After filtering: {len(filtered_images)} images remain")
|
||||
|
||||
# Convert to response format with similarity scores
|
||||
results = self._convert_images_to_search_results(filtered_images, scores, request)
|
||||
|
||||
# Sort by similarity score (highest first)
|
||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||
|
||||
logger.info(f"Returning {len(results)} results")
|
||||
|
||||
return SearchResponse(
|
||||
query=search_request.query,
|
||||
results=results,
|
||||
total=len(results),
|
||||
limit=search_request.limit,
|
||||
similarity_threshold=search_request.similarity_threshold
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in advanced search: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
raise RuntimeError("Advanced search failed")
|
||||
|
||||
def _filter_images_by_collection(self, images: List, collection_id: Optional[str]) -> List:
|
||||
"""Filter images by collection ID"""
|
||||
if not collection_id:
|
||||
return images
|
||||
|
||||
filtered_images = []
|
||||
for image in images:
|
||||
if str(image.collection_id) == collection_id:
|
||||
filtered_images.append(image)
|
||||
|
||||
return filtered_images
|
||||
|
||||
def _apply_advanced_filters(self, images: List, search_request: SearchRequest) -> List:
|
||||
"""Apply advanced filters to the image list"""
|
||||
filtered_images = []
|
||||
|
||||
for image in images:
|
||||
# Check collection filter
|
||||
if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
|
||||
continue
|
||||
|
||||
# Check date range filter
|
||||
if search_request.date_from and image.upload_date < search_request.date_from:
|
||||
continue
|
||||
|
||||
if search_request.date_to and image.upload_date > search_request.date_to:
|
||||
continue
|
||||
|
||||
# Check uploader filter
|
||||
if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
|
||||
continue
|
||||
|
||||
filtered_images.append(image)
|
||||
|
||||
return filtered_images
|
||||
|
||||
def _convert_images_to_search_results(
|
||||
self,
|
||||
images: List,
|
||||
scores: Dict[str, float],
|
||||
request: Request
|
||||
) -> List[ImageResponse]:
|
||||
"""Convert images to search result format with similarity scores"""
|
||||
results = []
|
||||
|
||||
for image in images:
|
||||
image_id = str(image.id)
|
||||
similarity_score = scores.get(image_id, 0.0)
|
||||
|
||||
result = ImageResponse(
|
||||
id=image_id,
|
||||
filename=image.filename,
|
||||
original_filename=image.original_filename,
|
||||
file_size=image.file_size,
|
||||
content_type=image.content_type,
|
||||
storage_path=image.storage_path,
|
||||
public_url=self._generate_api_download_url(request, image_id),
|
||||
team_id=str(image.team_id),
|
||||
uploader_id=str(image.uploader_id),
|
||||
upload_date=image.upload_date,
|
||||
description=image.description,
|
||||
metadata=image.metadata,
|
||||
has_embedding=image.has_embedding,
|
||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
||||
similarity_score=similarity_score
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
176
src/services/team_service.py
Normal file
176
src/services/team_service.py
Normal file
@ -0,0 +1,176 @@
|
||||
import logging
|
||||
from typing import List
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
|
||||
from src.models.team import TeamModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TeamService:
|
||||
"""Service class for handling team-related business logic"""
|
||||
|
||||
async def create_team(self, team_data: TeamCreate) -> TeamResponse:
|
||||
"""
|
||||
Create a new team
|
||||
|
||||
Args:
|
||||
team_data: The team creation data
|
||||
|
||||
Returns:
|
||||
TeamResponse: The created team
|
||||
"""
|
||||
# Create team
|
||||
team = TeamModel(
|
||||
name=team_data.name,
|
||||
description=team_data.description
|
||||
)
|
||||
|
||||
created_team = await team_repository.create(team)
|
||||
|
||||
# Convert to response model
|
||||
return TeamResponse(
|
||||
id=str(created_team.id),
|
||||
name=created_team.name,
|
||||
description=created_team.description,
|
||||
created_at=created_team.created_at,
|
||||
updated_at=created_team.updated_at
|
||||
)
|
||||
|
||||
async def list_teams(self) -> TeamListResponse:
|
||||
"""
|
||||
List all teams
|
||||
|
||||
Returns:
|
||||
TeamListResponse: List of all teams
|
||||
"""
|
||||
# Get all teams
|
||||
teams = await team_repository.get_all()
|
||||
|
||||
# Convert to response models
|
||||
response_teams = []
|
||||
for team in teams:
|
||||
response_teams.append(TeamResponse(
|
||||
id=str(team.id),
|
||||
name=team.name,
|
||||
description=team.description,
|
||||
created_at=team.created_at,
|
||||
updated_at=team.updated_at
|
||||
))
|
||||
|
||||
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
||||
|
||||
async def get_team(self, team_id: str) -> TeamResponse:
|
||||
"""
|
||||
Get a team by ID
|
||||
|
||||
Args:
|
||||
team_id: The team ID to retrieve
|
||||
|
||||
Returns:
|
||||
TeamResponse: The team data
|
||||
|
||||
Raises:
|
||||
ValueError: If team_id is invalid
|
||||
RuntimeError: If team not found
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(team_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid team ID")
|
||||
|
||||
# Get the team
|
||||
team = await team_repository.get_by_id(obj_id)
|
||||
if not team:
|
||||
raise RuntimeError("Team not found")
|
||||
|
||||
# Convert to response model
|
||||
return TeamResponse(
|
||||
id=str(team.id),
|
||||
name=team.name,
|
||||
description=team.description,
|
||||
created_at=team.created_at,
|
||||
updated_at=team.updated_at
|
||||
)
|
||||
|
||||
async def update_team(self, team_id: str, team_data: TeamUpdate) -> TeamResponse:
|
||||
"""
|
||||
Update a team
|
||||
|
||||
Args:
|
||||
team_id: The team ID to update
|
||||
team_data: The update data
|
||||
|
||||
Returns:
|
||||
TeamResponse: The updated team
|
||||
|
||||
Raises:
|
||||
ValueError: If team_id is invalid
|
||||
RuntimeError: If team not found or update fails
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(team_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid team ID")
|
||||
|
||||
# Get the team
|
||||
team = await team_repository.get_by_id(obj_id)
|
||||
if not team:
|
||||
raise RuntimeError("Team not found")
|
||||
|
||||
# Update the team
|
||||
update_data = team_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
return TeamResponse(
|
||||
id=str(team.id),
|
||||
name=team.name,
|
||||
description=team.description,
|
||||
created_at=team.created_at,
|
||||
updated_at=team.updated_at
|
||||
)
|
||||
|
||||
updated_team = await team_repository.update(obj_id, update_data)
|
||||
if not updated_team:
|
||||
raise RuntimeError("Failed to update team")
|
||||
|
||||
# Convert to response model
|
||||
return TeamResponse(
|
||||
id=str(updated_team.id),
|
||||
name=updated_team.name,
|
||||
description=updated_team.description,
|
||||
created_at=updated_team.created_at,
|
||||
updated_at=updated_team.updated_at
|
||||
)
|
||||
|
||||
async def delete_team(self, team_id: str) -> bool:
|
||||
"""
|
||||
Delete a team
|
||||
|
||||
Args:
|
||||
team_id: The team ID to delete
|
||||
|
||||
Returns:
|
||||
bool: True if successfully deleted
|
||||
|
||||
Raises:
|
||||
ValueError: If team_id is invalid
|
||||
RuntimeError: If team not found or deletion fails
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(team_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid team ID")
|
||||
|
||||
# Get the team
|
||||
team = await team_repository.get_by_id(obj_id)
|
||||
if not team:
|
||||
raise RuntimeError("Team not found")
|
||||
|
||||
# Delete the team
|
||||
success = await team_repository.delete(obj_id)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to delete team")
|
||||
|
||||
return True
|
||||
310
src/services/user_service.py
Normal file
310
src/services/user_service.py
Normal file
@ -0,0 +1,310 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from bson import ObjectId
|
||||
|
||||
from src.db.repositories.user_repository import user_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.models.user import UserModel
|
||||
from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UserService:
|
||||
"""Service class for handling user-related business logic"""
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> UserResponse:
|
||||
"""
|
||||
Get user information by user ID
|
||||
|
||||
Args:
|
||||
user_id: The user ID to retrieve
|
||||
|
||||
Returns:
|
||||
UserResponse: The user data
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is invalid
|
||||
RuntimeError: If user not found
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID")
|
||||
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise RuntimeError("User not found")
|
||||
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
)
|
||||
|
||||
async def update_user_by_id(self, user_id: str, user_data: UserUpdate) -> UserResponse:
|
||||
"""
|
||||
Update user information by user ID
|
||||
|
||||
Args:
|
||||
user_id: The user ID to update
|
||||
user_data: The update data
|
||||
|
||||
Returns:
|
||||
UserResponse: The updated user data
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is invalid
|
||||
RuntimeError: If user not found or update fails
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID")
|
||||
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise RuntimeError("User not found")
|
||||
|
||||
# Update user
|
||||
update_data = user_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
)
|
||||
|
||||
updated_user = await user_repository.update(user.id, update_data)
|
||||
if not updated_user:
|
||||
raise RuntimeError("Failed to update user")
|
||||
|
||||
return UserResponse(
|
||||
id=str(updated_user.id),
|
||||
name=updated_user.name,
|
||||
email=updated_user.email,
|
||||
team_id=str(updated_user.team_id),
|
||||
is_admin=updated_user.is_admin,
|
||||
is_active=updated_user.is_active,
|
||||
created_at=updated_user.created_at,
|
||||
updated_at=updated_user.updated_at
|
||||
)
|
||||
|
||||
async def create_user(self, user_data: UserCreate) -> UserResponse:
|
||||
"""
|
||||
Create a new user
|
||||
|
||||
Args:
|
||||
user_data: The user creation data
|
||||
|
||||
Returns:
|
||||
UserResponse: The created user
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
RuntimeError: If team not found or user creation fails
|
||||
"""
|
||||
# Check if user with email already exists
|
||||
existing_user = await user_repository.get_by_email(user_data.email)
|
||||
if existing_user:
|
||||
raise ValueError("User with this email already exists")
|
||||
|
||||
# Validate team exists if specified
|
||||
if user_data.team_id:
|
||||
team = await team_repository.get_by_id(ObjectId(user_data.team_id))
|
||||
if not team:
|
||||
raise RuntimeError("Team not found")
|
||||
team_id = user_data.team_id
|
||||
else:
|
||||
raise ValueError("Team ID is required")
|
||||
|
||||
# Create user
|
||||
user = UserModel(
|
||||
name=user_data.name,
|
||||
email=user_data.email,
|
||||
team_id=ObjectId(team_id),
|
||||
is_admin=user_data.is_admin or False,
|
||||
is_active=True
|
||||
)
|
||||
|
||||
created_user = await user_repository.create(user)
|
||||
|
||||
return UserResponse(
|
||||
id=str(created_user.id),
|
||||
name=created_user.name,
|
||||
email=created_user.email,
|
||||
team_id=str(created_user.team_id),
|
||||
is_admin=created_user.is_admin,
|
||||
is_active=created_user.is_active,
|
||||
created_at=created_user.created_at,
|
||||
updated_at=created_user.updated_at
|
||||
)
|
||||
|
||||
async def list_users(self, team_id: Optional[str] = None) -> UserListResponse:
|
||||
"""
|
||||
List users, optionally filtered by team
|
||||
|
||||
Args:
|
||||
team_id: Optional team ID to filter by
|
||||
|
||||
Returns:
|
||||
UserListResponse: List of users
|
||||
|
||||
Raises:
|
||||
ValueError: If team_id is invalid
|
||||
"""
|
||||
# Get users
|
||||
if team_id:
|
||||
try:
|
||||
filter_team_id = ObjectId(team_id)
|
||||
users = await user_repository.get_by_team(filter_team_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid team ID")
|
||||
else:
|
||||
users = await user_repository.get_all()
|
||||
|
||||
# Convert to response
|
||||
response_users = []
|
||||
for user in users:
|
||||
response_users.append(UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
))
|
||||
|
||||
return UserListResponse(users=response_users, total=len(response_users))
|
||||
|
||||
async def get_user(self, user_id: str) -> UserResponse:
|
||||
"""
|
||||
Get user by ID
|
||||
|
||||
Args:
|
||||
user_id: The user ID to retrieve
|
||||
|
||||
Returns:
|
||||
UserResponse: The user data
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is invalid
|
||||
RuntimeError: If user not found
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID")
|
||||
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise RuntimeError("User not found")
|
||||
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
)
|
||||
|
||||
async def update_user(self, user_id: str, user_data: UserUpdate) -> UserResponse:
|
||||
"""
|
||||
Update user by ID
|
||||
|
||||
Args:
|
||||
user_id: The user ID to update
|
||||
user_data: The update data
|
||||
|
||||
Returns:
|
||||
UserResponse: The updated user data
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is invalid
|
||||
RuntimeError: If user not found or update fails
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID")
|
||||
|
||||
# Check if user exists
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise RuntimeError("User not found")
|
||||
|
||||
# Update user
|
||||
update_data = user_data.dict(exclude_unset=True)
|
||||
if not update_data:
|
||||
# No fields to update
|
||||
return UserResponse(
|
||||
id=str(user.id),
|
||||
name=user.name,
|
||||
email=user.email,
|
||||
team_id=str(user.team_id),
|
||||
is_admin=user.is_admin,
|
||||
is_active=user.is_active,
|
||||
created_at=user.created_at,
|
||||
updated_at=user.updated_at
|
||||
)
|
||||
|
||||
updated_user = await user_repository.update(obj_id, update_data)
|
||||
if not updated_user:
|
||||
raise RuntimeError("Failed to update user")
|
||||
|
||||
return UserResponse(
|
||||
id=str(updated_user.id),
|
||||
name=updated_user.name,
|
||||
email=updated_user.email,
|
||||
team_id=str(updated_user.team_id),
|
||||
is_admin=updated_user.is_admin,
|
||||
is_active=updated_user.is_active,
|
||||
created_at=updated_user.created_at,
|
||||
updated_at=updated_user.updated_at
|
||||
)
|
||||
|
||||
async def delete_user(self, user_id: str) -> bool:
|
||||
"""
|
||||
Delete user by ID
|
||||
|
||||
Args:
|
||||
user_id: The user ID to delete
|
||||
|
||||
Returns:
|
||||
bool: True if successfully deleted
|
||||
|
||||
Raises:
|
||||
ValueError: If user_id is invalid
|
||||
RuntimeError: If user not found or deletion fails
|
||||
"""
|
||||
try:
|
||||
obj_id = ObjectId(user_id)
|
||||
except Exception:
|
||||
raise ValueError("Invalid user ID")
|
||||
|
||||
# Check if user exists
|
||||
user = await user_repository.get_by_id(obj_id)
|
||||
if not user:
|
||||
raise RuntimeError("User not found")
|
||||
|
||||
# Delete user
|
||||
success = await user_repository.delete(obj_id)
|
||||
if not success:
|
||||
raise RuntimeError("Failed to delete user")
|
||||
|
||||
return True
|
||||
Loading…
x
Reference in New Issue
Block a user