refactor
This commit is contained in:
parent
ba2c00db38
commit
a26bd08d9c
@ -38,7 +38,6 @@ root/
|
|||||||
│ │ └── v1/ # API version 1 routes
|
│ │ └── v1/ # API version 1 routes
|
||||||
│ ├── auth/ # Authentication and authorization
|
│ ├── auth/ # Authentication and authorization
|
||||||
│ ├── config/ # Configuration management
|
│ ├── config/ # Configuration management
|
||||||
│ ├── core/ # Core application logic
|
|
||||||
│ ├── db/ # Database layer
|
│ ├── db/ # Database layer
|
||||||
│ │ ├── providers/ # Database providers (Firestore)
|
│ │ ├── providers/ # Database providers (Firestore)
|
||||||
│ │ └── repositories/ # Data access repositories
|
│ │ └── repositories/ # Data access repositories
|
||||||
@ -215,6 +214,14 @@ Uses Google's Vertex AI multimodal embedding model for generating high-quality i
|
|||||||
./deployment/deploy.sh --destroy
|
./deployment/deploy.sh --destroy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
7. **Local Development**
|
||||||
|
```bash
|
||||||
|
./scripts/start.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
8. **Local Testing**
|
||||||
|
|
||||||
## API Endpoints
|
## API Endpoints
|
||||||
|
|
||||||
The API provides the following main endpoints with their authentication and pagination support:
|
The API provides the following main endpoints with their authentication and pagination support:
|
||||||
|
|||||||
@ -4,13 +4,11 @@ from datetime import datetime
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
from src.db.repositories.api_key_repository import api_key_repository
|
from src.services.auth_service import AuthService
|
||||||
from src.db.repositories.user_repository import user_repository
|
|
||||||
from src.db.repositories.team_repository import team_repository
|
|
||||||
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
|
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
|
||||||
from src.schemas.team import TeamCreate
|
from src.schemas.team import TeamCreate
|
||||||
from src.schemas.user import UserCreate
|
from src.schemas.user import UserCreate
|
||||||
from src.auth.security import generate_api_key, verify_api_key, calculate_expiry_date, is_expired, hash_api_key, get_current_user
|
from src.auth.security import get_current_user
|
||||||
from src.models.api_key import ApiKeyModel
|
from src.models.api_key import ApiKeyModel
|
||||||
from src.models.team import TeamModel
|
from src.models.team import TeamModel
|
||||||
from src.models.user import UserModel
|
from src.models.user import UserModel
|
||||||
@ -20,6 +18,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(tags=["Authentication"], prefix="/auth")
|
router = APIRouter(tags=["Authentication"], prefix="/auth")
|
||||||
|
|
||||||
|
# Initialize service
|
||||||
|
auth_service = AuthService()
|
||||||
|
|
||||||
@router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201)
|
@router.post("/api-keys", response_model=ApiKeyWithValueResponse, status_code=201)
|
||||||
async def create_api_key(key_data: ApiKeyCreate, request: Request, user_id: str, team_id: str):
|
async def create_api_key(key_data: ApiKeyCreate, request: Request, user_id: str, team_id: str):
|
||||||
"""
|
"""
|
||||||
@ -31,67 +32,16 @@ async def create_api_key(key_data: ApiKeyCreate, request: Request, user_id: str,
|
|||||||
{"path": request.url.path, "method": request.method, "key_data": key_data.dict(), "user_id": user_id, "team_id": team_id}
|
{"path": request.url.path, "method": request.method, "key_data": key_data.dict(), "user_id": user_id, "team_id": team_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate user_id and team_id
|
|
||||||
try:
|
try:
|
||||||
target_user_id = ObjectId(user_id)
|
response = await auth_service.create_api_key_for_user_and_team(user_id, team_id, key_data)
|
||||||
target_team_id = ObjectId(team_id)
|
return response
|
||||||
except:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID or team ID")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
# Verify user exists
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
target_user = await user_repository.get_by_id(target_user_id)
|
except Exception as e:
|
||||||
if not target_user:
|
logger.error(f"Unexpected error creating API key: {e}")
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
# Verify team exists
|
|
||||||
team = await team_repository.get_by_id(target_team_id)
|
|
||||||
if not team:
|
|
||||||
raise HTTPException(status_code=404, detail="Team not found")
|
|
||||||
|
|
||||||
# Verify user belongs to the team
|
|
||||||
if target_user.team_id != target_team_id:
|
|
||||||
raise HTTPException(status_code=400, detail="User does not belong to the specified team")
|
|
||||||
|
|
||||||
# If team_id is provided in key_data, validate it matches the parameter
|
|
||||||
if key_data.team_id and key_data.team_id != team_id:
|
|
||||||
raise HTTPException(status_code=400, detail="Team ID in request body does not match parameter")
|
|
||||||
|
|
||||||
# If user_id is provided in key_data, validate it matches the parameter
|
|
||||||
if key_data.user_id and key_data.user_id != user_id:
|
|
||||||
raise HTTPException(status_code=400, detail="User ID in request body does not match parameter")
|
|
||||||
|
|
||||||
# Generate API key with expiry date
|
|
||||||
raw_key, hashed_key = generate_api_key(str(target_team_id), str(target_user_id))
|
|
||||||
expiry_date = calculate_expiry_date()
|
|
||||||
|
|
||||||
# Create API key in database
|
|
||||||
api_key = ApiKeyModel(
|
|
||||||
key_hash=hashed_key,
|
|
||||||
user_id=target_user_id,
|
|
||||||
team_id=target_team_id,
|
|
||||||
name=key_data.name,
|
|
||||||
description=key_data.description,
|
|
||||||
expiry_date=expiry_date,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
|
|
||||||
created_key = await api_key_repository.create(api_key)
|
|
||||||
|
|
||||||
# Convert to response model
|
|
||||||
response = ApiKeyWithValueResponse(
|
|
||||||
id=str(created_key.id),
|
|
||||||
key=raw_key,
|
|
||||||
name=created_key.name,
|
|
||||||
description=created_key.description,
|
|
||||||
team_id=str(created_key.team_id),
|
|
||||||
user_id=str(created_key.user_id),
|
|
||||||
created_at=created_key.created_at,
|
|
||||||
expiry_date=created_key.expiry_date,
|
|
||||||
last_used=created_key.last_used,
|
|
||||||
is_active=created_key.is_active
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.post("/admin/api-keys/{user_id}", response_model=ApiKeyWithValueResponse, status_code=201)
|
@router.post("/admin/api-keys/{user_id}", response_model=ApiKeyWithValueResponse, status_code=201)
|
||||||
async def create_api_key_for_user(
|
async def create_api_key_for_user(
|
||||||
@ -103,10 +53,6 @@ async def create_api_key_for_user(
|
|||||||
"""
|
"""
|
||||||
Create a new API key for a specific user (admin only)
|
Create a new API key for a specific user (admin only)
|
||||||
"""
|
"""
|
||||||
# Check if current user is admin
|
|
||||||
if not current_user.is_admin:
|
|
||||||
raise HTTPException(status_code=403, detail="Admin access required")
|
|
||||||
|
|
||||||
log_request(
|
log_request(
|
||||||
{"path": request.url.path, "method": request.method, "target_user_id": user_id, "key_data": key_data.dict()},
|
{"path": request.url.path, "method": request.method, "target_user_id": user_id, "key_data": key_data.dict()},
|
||||||
user_id=str(current_user.id),
|
user_id=str(current_user.id),
|
||||||
@ -114,52 +60,17 @@ async def create_api_key_for_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
target_user_obj_id = ObjectId(user_id)
|
response = await auth_service.create_api_key_for_user_by_admin(user_id, key_data, current_user)
|
||||||
except:
|
return response
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
except PermissionError as e:
|
||||||
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
# Get the target user
|
except ValueError as e:
|
||||||
target_user = await user_repository.get_by_id(target_user_obj_id)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
if not target_user:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=404, detail="Target user not found")
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
# Check if target user's team exists
|
logger.error(f"Unexpected error creating API key for user: {e}")
|
||||||
team = await team_repository.get_by_id(target_user.team_id)
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
if not team:
|
|
||||||
raise HTTPException(status_code=404, detail="Target user's team not found")
|
|
||||||
|
|
||||||
# Generate API key with expiry date
|
|
||||||
raw_key, hashed_key = generate_api_key(str(target_user.team_id), str(target_user.id))
|
|
||||||
expiry_date = calculate_expiry_date()
|
|
||||||
|
|
||||||
# Create API key in database
|
|
||||||
api_key = ApiKeyModel(
|
|
||||||
key_hash=hashed_key,
|
|
||||||
user_id=target_user.id,
|
|
||||||
team_id=target_user.team_id,
|
|
||||||
name=key_data.name,
|
|
||||||
description=key_data.description,
|
|
||||||
expiry_date=expiry_date,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
|
|
||||||
created_key = await api_key_repository.create(api_key)
|
|
||||||
|
|
||||||
# Convert to response model
|
|
||||||
response = ApiKeyWithValueResponse(
|
|
||||||
id=str(created_key.id),
|
|
||||||
key=raw_key,
|
|
||||||
name=created_key.name,
|
|
||||||
description=created_key.description,
|
|
||||||
team_id=str(created_key.team_id),
|
|
||||||
user_id=str(created_key.user_id),
|
|
||||||
created_at=created_key.created_at,
|
|
||||||
expiry_date=created_key.expiry_date,
|
|
||||||
last_used=created_key.last_used,
|
|
||||||
is_active=created_key.is_active
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.get("/api-keys", response_model=ApiKeyListResponse)
|
@router.get("/api-keys", response_model=ApiKeyListResponse)
|
||||||
async def list_api_keys(request: Request, current_user = Depends(get_current_user)):
|
async def list_api_keys(request: Request, current_user = Depends(get_current_user)):
|
||||||
@ -172,25 +83,12 @@ async def list_api_keys(request: Request, current_user = Depends(get_current_use
|
|||||||
team_id=str(current_user.team_id)
|
team_id=str(current_user.team_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get API keys for user
|
try:
|
||||||
keys = await api_key_repository.get_by_user(current_user.id)
|
response = await auth_service.list_user_api_keys(current_user)
|
||||||
|
return response
|
||||||
# Convert to response models
|
except Exception as e:
|
||||||
response_keys = []
|
logger.error(f"Unexpected error listing API keys: {e}")
|
||||||
for key in keys:
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
response_keys.append(ApiKeyResponse(
|
|
||||||
id=str(key.id),
|
|
||||||
name=key.name,
|
|
||||||
description=key.description,
|
|
||||||
team_id=str(key.team_id),
|
|
||||||
user_id=str(key.user_id),
|
|
||||||
created_at=key.created_at,
|
|
||||||
expiry_date=key.expiry_date,
|
|
||||||
last_used=key.last_used,
|
|
||||||
is_active=key.is_active
|
|
||||||
))
|
|
||||||
|
|
||||||
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
|
||||||
|
|
||||||
@router.delete("/api-keys/{key_id}", status_code=204)
|
@router.delete("/api-keys/{key_id}", status_code=204)
|
||||||
async def revoke_api_key(key_id: str, request: Request, current_user = Depends(get_current_user)):
|
async def revoke_api_key(key_id: str, request: Request, current_user = Depends(get_current_user)):
|
||||||
@ -204,26 +102,17 @@ async def revoke_api_key(key_id: str, request: Request, current_user = Depends(g
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert string ID to ObjectId
|
await auth_service.revoke_api_key(key_id, current_user)
|
||||||
obj_id = ObjectId(key_id)
|
return None
|
||||||
except:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail="Invalid key ID")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
# Get the API key
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
key = await api_key_repository.get_by_id(obj_id)
|
except PermissionError as e:
|
||||||
if not key:
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
raise HTTPException(status_code=404, detail="API key not found")
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error revoking API key: {e}")
|
||||||
# Check if user owns the key or is an admin
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
if key.user_id != current_user.id and not current_user.is_admin:
|
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to revoke this API key")
|
|
||||||
|
|
||||||
# Deactivate the key
|
|
||||||
result = await api_key_repository.deactivate(obj_id)
|
|
||||||
if not result:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to revoke API key")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
@router.get("/verify", status_code=200)
|
@router.get("/verify", status_code=200)
|
||||||
async def verify_authentication(request: Request, current_user = Depends(get_current_user)):
|
async def verify_authentication(request: Request, current_user = Depends(get_current_user)):
|
||||||
@ -236,10 +125,9 @@ async def verify_authentication(request: Request, current_user = Depends(get_cur
|
|||||||
team_id=str(current_user.team_id)
|
team_id=str(current_user.team_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
try:
|
||||||
"user_id": str(current_user.id),
|
response = await auth_service.verify_user_authentication(current_user)
|
||||||
"name": current_user.name,
|
return response
|
||||||
"email": current_user.email,
|
except Exception as e:
|
||||||
"team_id": str(current_user.team_id),
|
logger.error(f"Unexpected error verifying authentication: {e}")
|
||||||
"is_admin": current_user.is_admin
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
}
|
|
||||||
@ -6,12 +6,7 @@ from bson import ObjectId
|
|||||||
import io
|
import io
|
||||||
|
|
||||||
from src.auth.security import get_current_user
|
from src.auth.security import get_current_user
|
||||||
from src.db.repositories.image_repository import image_repository
|
from src.services.image_service import ImageService
|
||||||
from src.services.storage import StorageService
|
|
||||||
from src.services.image_processor import ImageProcessor
|
|
||||||
from src.services.embedding_service import EmbeddingService
|
|
||||||
from src.services.pubsub_service import pubsub_service
|
|
||||||
from src.models.image import ImageModel
|
|
||||||
from src.models.user import UserModel
|
from src.models.user import UserModel
|
||||||
from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
|
from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
|
||||||
from src.utils.logging import log_request
|
from src.utils.logging import log_request
|
||||||
@ -20,17 +15,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(tags=["Images"], prefix="/images")
|
router = APIRouter(tags=["Images"], prefix="/images")
|
||||||
|
|
||||||
# Initialize services
|
# Initialize service
|
||||||
storage_service = StorageService()
|
image_service = ImageService()
|
||||||
image_processor = ImageProcessor()
|
|
||||||
embedding_service = EmbeddingService()
|
|
||||||
|
|
||||||
def generate_api_download_url(request: Request, image_id: str) -> str:
|
|
||||||
"""
|
|
||||||
Generate API download URL for an image
|
|
||||||
"""
|
|
||||||
base_url = str(request.base_url).rstrip('/')
|
|
||||||
return f"{base_url}/api/v1/images/{image_id}/download"
|
|
||||||
|
|
||||||
@router.post("", response_model=ImageResponse, status_code=201)
|
@router.post("", response_model=ImageResponse, status_code=201)
|
||||||
async def upload_image(
|
async def upload_image(
|
||||||
@ -49,86 +35,16 @@ async def upload_image(
|
|||||||
team_id=str(current_user.team_id)
|
team_id=str(current_user.team_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate file type
|
|
||||||
if not file.content_type or not file.content_type.startswith('image/'):
|
|
||||||
raise HTTPException(status_code=400, detail="File must be an image")
|
|
||||||
|
|
||||||
# Validate file size (10MB limit)
|
|
||||||
max_size = 10 * 1024 * 1024 # 10MB
|
|
||||||
content = await file.read()
|
|
||||||
if len(content) > max_size:
|
|
||||||
raise HTTPException(status_code=400, detail="File size exceeds 10MB limit")
|
|
||||||
|
|
||||||
# Reset file pointer
|
|
||||||
await file.seek(0)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Upload to storage
|
response = await image_service.upload_image(file, current_user, request, description, collection_id)
|
||||||
storage_path, content_type, file_size, metadata = await storage_service.upload_file(
|
|
||||||
file, str(current_user.team_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Create image record
|
|
||||||
image = ImageModel(
|
|
||||||
filename=file.filename,
|
|
||||||
original_filename=file.filename,
|
|
||||||
file_size=file_size,
|
|
||||||
content_type=content_type,
|
|
||||||
storage_path=storage_path,
|
|
||||||
public_url=None, # Will be set after we have the image ID
|
|
||||||
team_id=current_user.team_id,
|
|
||||||
uploader_id=current_user.id,
|
|
||||||
description=description,
|
|
||||||
metadata=metadata,
|
|
||||||
collection_id=ObjectId(collection_id) if collection_id else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save to database
|
|
||||||
created_image = await image_repository.create(image)
|
|
||||||
|
|
||||||
# Generate API download URL now that we have the image ID
|
|
||||||
api_download_url = generate_api_download_url(request, str(created_image.id))
|
|
||||||
|
|
||||||
# Update the image with the API download URL
|
|
||||||
await image_repository.update(created_image.id, {"public_url": api_download_url})
|
|
||||||
created_image.public_url = api_download_url
|
|
||||||
|
|
||||||
# Publish image processing task to Pub/Sub
|
|
||||||
try:
|
|
||||||
task_published = await pubsub_service.publish_image_processing_task(
|
|
||||||
image_id=str(created_image.id),
|
|
||||||
storage_path=storage_path,
|
|
||||||
team_id=str(current_user.team_id)
|
|
||||||
)
|
|
||||||
if not task_published:
|
|
||||||
logger.warning(f"Failed to publish processing task for image {created_image.id}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to publish image processing task: {e}")
|
|
||||||
|
|
||||||
# Convert to response
|
|
||||||
response = ImageResponse(
|
|
||||||
id=str(created_image.id),
|
|
||||||
filename=created_image.filename,
|
|
||||||
original_filename=created_image.original_filename,
|
|
||||||
file_size=created_image.file_size,
|
|
||||||
content_type=created_image.content_type,
|
|
||||||
storage_path=created_image.storage_path,
|
|
||||||
public_url=created_image.public_url,
|
|
||||||
team_id=str(created_image.team_id),
|
|
||||||
uploader_id=str(created_image.uploader_id),
|
|
||||||
upload_date=created_image.upload_date,
|
|
||||||
description=created_image.description,
|
|
||||||
metadata=created_image.metadata,
|
|
||||||
has_embedding=created_image.has_embedding,
|
|
||||||
collection_id=str(created_image.collection_id) if created_image.collection_id else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error uploading image: {e}")
|
logger.error(f"Unexpected error uploading image: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Failed to upload image")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
@router.get("", response_model=ImageListResponse)
|
@router.get("", response_model=ImageListResponse)
|
||||||
async def list_images(
|
async def list_images(
|
||||||
@ -158,58 +74,12 @@ async def list_images(
|
|||||||
team_id=str(current_user.team_id)
|
team_id=str(current_user.team_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if user is admin - if so, get all images across all teams
|
try:
|
||||||
if current_user.is_admin:
|
response = await image_service.list_images(current_user, request, skip, limit, collection_id)
|
||||||
# Admin users can see all images across all teams
|
return response
|
||||||
images = await image_repository.get_all_with_pagination(
|
except Exception as e:
|
||||||
skip=skip,
|
logger.error(f"Unexpected error listing images: {e}")
|
||||||
limit=limit,
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get total count for admin
|
|
||||||
total = await image_repository.count_all(
|
|
||||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Regular users only see images from their team
|
|
||||||
images = await image_repository.get_by_team(
|
|
||||||
current_user.team_id,
|
|
||||||
skip=skip,
|
|
||||||
limit=limit,
|
|
||||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get total count for regular user
|
|
||||||
total = await image_repository.count_by_team(
|
|
||||||
current_user.team_id,
|
|
||||||
collection_id=ObjectId(collection_id) if collection_id else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert to response
|
|
||||||
response_images = []
|
|
||||||
for image in images:
|
|
||||||
# Generate API download URL
|
|
||||||
api_download_url = generate_api_download_url(request, str(image.id))
|
|
||||||
|
|
||||||
response_images.append(ImageResponse(
|
|
||||||
id=str(image.id),
|
|
||||||
filename=image.filename,
|
|
||||||
original_filename=image.original_filename,
|
|
||||||
file_size=image.file_size,
|
|
||||||
content_type=image.content_type,
|
|
||||||
storage_path=image.storage_path,
|
|
||||||
public_url=api_download_url,
|
|
||||||
team_id=str(image.team_id),
|
|
||||||
uploader_id=str(image.uploader_id),
|
|
||||||
upload_date=image.upload_date,
|
|
||||||
description=image.description,
|
|
||||||
metadata=image.metadata,
|
|
||||||
has_embedding=image.has_embedding,
|
|
||||||
collection_id=str(image.collection_id) if image.collection_id else None
|
|
||||||
))
|
|
||||||
|
|
||||||
return ImageListResponse(images=response_images, total=total, skip=skip, limit=limit)
|
|
||||||
|
|
||||||
@router.get("/{image_id}", response_model=ImageResponse)
|
@router.get("/{image_id}", response_model=ImageResponse)
|
||||||
async def get_image(
|
async def get_image(
|
||||||
@ -227,42 +97,17 @@ async def get_image(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(image_id)
|
response = await image_service.get_image(image_id, current_user, request)
|
||||||
except:
|
return response
|
||||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# Get image
|
except RuntimeError as e:
|
||||||
image = await image_repository.get_by_id(obj_id)
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
if not image:
|
except PermissionError as e:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
# Check team access (admins can access any image)
|
logger.error(f"Unexpected error getting image: {e}")
|
||||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to access this image")
|
|
||||||
|
|
||||||
# Generate API download URL
|
|
||||||
api_download_url = generate_api_download_url(request, str(image.id))
|
|
||||||
|
|
||||||
# Convert to response
|
|
||||||
response = ImageResponse(
|
|
||||||
id=str(image.id),
|
|
||||||
filename=image.filename,
|
|
||||||
original_filename=image.original_filename,
|
|
||||||
file_size=image.file_size,
|
|
||||||
content_type=image.content_type,
|
|
||||||
storage_path=image.storage_path,
|
|
||||||
public_url=api_download_url,
|
|
||||||
team_id=str(image.team_id),
|
|
||||||
uploader_id=str(image.uploader_id),
|
|
||||||
upload_date=image.upload_date,
|
|
||||||
description=image.description,
|
|
||||||
metadata=image.metadata,
|
|
||||||
has_embedding=image.has_embedding,
|
|
||||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
|
||||||
last_accessed=image.last_accessed
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.get("/{image_id}/download")
|
@router.get("/{image_id}/download")
|
||||||
async def download_image(
|
async def download_image(
|
||||||
@ -280,33 +125,23 @@ async def download_image(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(image_id)
|
file_content, content_type, filename = await image_service.download_image(image_id, current_user)
|
||||||
except:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
|
||||||
|
|
||||||
# Get image
|
# Return file as streaming response
|
||||||
image = await image_repository.get_by_id(obj_id)
|
return StreamingResponse(
|
||||||
if not image:
|
io.BytesIO(file_content),
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
media_type=content_type,
|
||||||
|
headers={"Content-Disposition": f"attachment; filename={filename}"}
|
||||||
# Check team access (admins can access any image)
|
)
|
||||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to access this image")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
# Get file from storage
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
file_content = storage_service.get_file(image.storage_path)
|
except PermissionError as e:
|
||||||
if not file_content:
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
raise HTTPException(status_code=404, detail="Image file not found in storage")
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error downloading image: {e}")
|
||||||
# Update last accessed
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
await image_repository.update_last_accessed(obj_id)
|
|
||||||
|
|
||||||
# Return file as streaming response
|
|
||||||
return StreamingResponse(
|
|
||||||
io.BytesIO(file_content),
|
|
||||||
media_type=image.content_type,
|
|
||||||
headers={"Content-Disposition": f"attachment; filename={image.original_filename}"}
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.put("/{image_id}", response_model=ImageResponse)
|
@router.put("/{image_id}", response_model=ImageResponse)
|
||||||
async def update_image(
|
async def update_image(
|
||||||
@ -325,68 +160,17 @@ async def update_image(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(image_id)
|
response = await image_service.update_image(image_id, image_data, current_user, request)
|
||||||
except:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
|
||||||
|
|
||||||
# Get image
|
|
||||||
image = await image_repository.get_by_id(obj_id)
|
|
||||||
if not image:
|
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
|
||||||
|
|
||||||
# Check team access (admins can update any image)
|
|
||||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to update this image")
|
|
||||||
|
|
||||||
# Update image
|
|
||||||
update_data = image_data.dict(exclude_unset=True)
|
|
||||||
if not update_data:
|
|
||||||
# No fields to update
|
|
||||||
api_download_url = generate_api_download_url(request, str(image.id))
|
|
||||||
response = ImageResponse(
|
|
||||||
id=str(image.id),
|
|
||||||
filename=image.filename,
|
|
||||||
original_filename=image.original_filename,
|
|
||||||
file_size=image.file_size,
|
|
||||||
content_type=image.content_type,
|
|
||||||
storage_path=image.storage_path,
|
|
||||||
public_url=api_download_url,
|
|
||||||
team_id=str(image.team_id),
|
|
||||||
uploader_id=str(image.uploader_id),
|
|
||||||
upload_date=image.upload_date,
|
|
||||||
description=image.description,
|
|
||||||
metadata=image.metadata,
|
|
||||||
has_embedding=image.has_embedding,
|
|
||||||
collection_id=str(image.collection_id) if image.collection_id else None
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
except ValueError as e:
|
||||||
updated_image = await image_repository.update(obj_id, update_data)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
if not updated_image:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update image")
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except PermissionError as e:
|
||||||
# Generate API download URL
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
api_download_url = generate_api_download_url(request, str(updated_image.id))
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error updating image: {e}")
|
||||||
# Convert to response
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
response = ImageResponse(
|
|
||||||
id=str(updated_image.id),
|
|
||||||
filename=updated_image.filename,
|
|
||||||
original_filename=updated_image.original_filename,
|
|
||||||
file_size=updated_image.file_size,
|
|
||||||
content_type=updated_image.content_type,
|
|
||||||
storage_path=updated_image.storage_path,
|
|
||||||
public_url=api_download_url,
|
|
||||||
team_id=str(updated_image.team_id),
|
|
||||||
uploader_id=str(updated_image.uploader_id),
|
|
||||||
upload_date=updated_image.upload_date,
|
|
||||||
description=updated_image.description,
|
|
||||||
metadata=updated_image.metadata,
|
|
||||||
has_embedding=updated_image.has_embedding,
|
|
||||||
collection_id=str(updated_image.collection_id) if updated_image.collection_id else None
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.delete("/{image_id}", status_code=204)
|
@router.delete("/{image_id}", status_code=204)
|
||||||
async def delete_image(
|
async def delete_image(
|
||||||
@ -404,35 +188,14 @@ async def delete_image(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(image_id)
|
await image_service.delete_image(image_id, current_user)
|
||||||
except:
|
return None
|
||||||
raise HTTPException(status_code=400, detail="Invalid image ID")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# Get image
|
except RuntimeError as e:
|
||||||
image = await image_repository.get_by_id(obj_id)
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
if not image:
|
except PermissionError as e:
|
||||||
raise HTTPException(status_code=404, detail="Image not found")
|
raise HTTPException(status_code=403, detail=str(e))
|
||||||
|
|
||||||
# Check team access (admins can delete any image)
|
|
||||||
if not current_user.is_admin and image.team_id != current_user.team_id:
|
|
||||||
raise HTTPException(status_code=403, detail="Not authorized to delete this image")
|
|
||||||
|
|
||||||
# Delete from storage
|
|
||||||
try:
|
|
||||||
storage_service.delete_file(image.storage_path)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to delete file from storage: {e}")
|
logger.error(f"Unexpected error deleting image: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
# Delete from vector database if it has embeddings
|
|
||||||
if image.has_embedding and image.embedding_id:
|
|
||||||
try:
|
|
||||||
await embedding_service.delete_embedding(image.embedding_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to delete embedding: {e}")
|
|
||||||
|
|
||||||
# Delete from database
|
|
||||||
result = await image_repository.delete(obj_id)
|
|
||||||
if not result:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete image")
|
|
||||||
|
|
||||||
return None
|
|
||||||
@ -3,12 +3,8 @@ from typing import Optional, List, Dict, Any
|
|||||||
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
||||||
|
|
||||||
from src.auth.security import get_current_user
|
from src.auth.security import get_current_user
|
||||||
from src.services.vector_db import VectorDatabaseService
|
from src.services.search_service import SearchService
|
||||||
from src.services.embedding_service import EmbeddingService
|
|
||||||
from src.db.repositories.image_repository import image_repository
|
|
||||||
from src.db.repositories.team_repository import team_repository
|
|
||||||
from src.models.user import UserModel
|
from src.models.user import UserModel
|
||||||
from src.schemas.image import ImageResponse
|
|
||||||
from src.schemas.search import SearchResponse, SearchRequest
|
from src.schemas.search import SearchResponse, SearchRequest
|
||||||
from src.utils.logging import log_request
|
from src.utils.logging import log_request
|
||||||
|
|
||||||
@ -16,17 +12,8 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(tags=["Search"], prefix="/search")
|
router = APIRouter(tags=["Search"], prefix="/search")
|
||||||
|
|
||||||
# Initialize services - delay VectorDatabaseService instantiation
|
# Initialize service
|
||||||
vector_db_service = None
|
search_service = SearchService()
|
||||||
embedding_service = EmbeddingService()
|
|
||||||
|
|
||||||
def get_vector_db_service():
|
|
||||||
"""Get or create the vector database service instance"""
|
|
||||||
global vector_db_service
|
|
||||||
if vector_db_service is None:
|
|
||||||
logger.info("Initializing VectorDatabaseService...")
|
|
||||||
vector_db_service = VectorDatabaseService()
|
|
||||||
return vector_db_service
|
|
||||||
|
|
||||||
@router.get("", response_model=SearchResponse)
|
@router.get("", response_model=SearchResponse)
|
||||||
async def search_images(
|
async def search_images(
|
||||||
@ -53,82 +40,22 @@ async def search_images(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate embedding for the search query
|
response = await search_service.search_images(
|
||||||
query_embedding = await embedding_service.generate_text_embedding(q)
|
query=q,
|
||||||
if not query_embedding:
|
user=current_user,
|
||||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
request=request,
|
||||||
|
|
||||||
# Search in vector database
|
|
||||||
search_results = get_vector_db_service().search_similar_images(
|
|
||||||
query_vector=query_embedding,
|
|
||||||
limit=limit,
|
limit=limit,
|
||||||
similarity_threshold=similarity_threshold,
|
similarity_threshold=similarity_threshold,
|
||||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
collection_id=collection_id
|
||||||
)
|
)
|
||||||
|
return response
|
||||||
if not search_results:
|
except ValueError as e:
|
||||||
return SearchResponse(
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
query=q,
|
except RuntimeError as e:
|
||||||
results=[],
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
total=0,
|
|
||||||
limit=limit,
|
|
||||||
similarity_threshold=similarity_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get image IDs and scores from search results
|
|
||||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
|
||||||
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
|
||||||
|
|
||||||
# Get image metadata from database
|
|
||||||
images = await image_repository.get_by_ids(image_ids)
|
|
||||||
|
|
||||||
# Filter by collection if specified
|
|
||||||
filtered_images = []
|
|
||||||
for image in images:
|
|
||||||
# Check collection filter
|
|
||||||
if collection_id and str(image.collection_id) != collection_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
filtered_images.append(image)
|
|
||||||
|
|
||||||
# Convert to response format with similarity scores
|
|
||||||
results = []
|
|
||||||
for image in filtered_images:
|
|
||||||
image_id = str(image.id)
|
|
||||||
similarity_score = scores.get(image_id, 0.0)
|
|
||||||
|
|
||||||
result = ImageResponse(
|
|
||||||
id=image_id,
|
|
||||||
filename=image.filename,
|
|
||||||
original_filename=image.original_filename,
|
|
||||||
file_size=image.file_size,
|
|
||||||
content_type=image.content_type,
|
|
||||||
storage_path=image.storage_path,
|
|
||||||
team_id=str(image.team_id),
|
|
||||||
uploader_id=str(image.uploader_id),
|
|
||||||
upload_date=image.upload_date,
|
|
||||||
description=image.description,
|
|
||||||
metadata=image.metadata,
|
|
||||||
has_embedding=image.has_embedding,
|
|
||||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
|
||||||
similarity_score=similarity_score
|
|
||||||
)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
# Sort by similarity score (highest first)
|
|
||||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
|
||||||
|
|
||||||
return SearchResponse(
|
|
||||||
query=q,
|
|
||||||
results=results,
|
|
||||||
total=len(results),
|
|
||||||
limit=limit,
|
|
||||||
similarity_threshold=similarity_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error searching images: {e}")
|
logger.error(f"Unexpected error in search: {e}")
|
||||||
raise HTTPException(status_code=500, detail="Search failed")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
@router.post("", response_model=SearchResponse)
|
@router.post("", response_model=SearchResponse)
|
||||||
async def search_images_advanced(
|
async def search_images_advanced(
|
||||||
@ -150,108 +77,16 @@ async def search_images_advanced(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate embedding for the search query
|
response = await search_service.search_images_advanced(
|
||||||
logger.info(f"Generating embedding for query: {search_request.query}")
|
search_request=search_request,
|
||||||
query_embedding = await embedding_service.generate_text_embedding(search_request.query)
|
user=current_user,
|
||||||
if not query_embedding:
|
request=request
|
||||||
logger.error("Failed to generate search embedding")
|
|
||||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
|
||||||
|
|
||||||
logger.info(f"Generated embedding with length: {len(query_embedding)}")
|
|
||||||
|
|
||||||
# Search in vector database
|
|
||||||
logger.info(f"Searching vector database with similarity_threshold: {search_request.similarity_threshold}")
|
|
||||||
search_results = get_vector_db_service().search_similar_images(
|
|
||||||
query_vector=query_embedding,
|
|
||||||
limit=search_request.limit,
|
|
||||||
similarity_threshold=search_request.similarity_threshold,
|
|
||||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
|
||||||
)
|
)
|
||||||
|
return response
|
||||||
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
if not search_results:
|
except RuntimeError as e:
|
||||||
logger.info("No search results from vector database, returning empty response")
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
return SearchResponse(
|
|
||||||
query=search_request.query,
|
|
||||||
results=[],
|
|
||||||
total=0,
|
|
||||||
limit=search_request.limit,
|
|
||||||
similarity_threshold=search_request.similarity_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get image IDs and scores from search results
|
|
||||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
|
||||||
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
|
||||||
|
|
||||||
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
|
|
||||||
|
|
||||||
# Get image metadata from database
|
|
||||||
logger.info("Fetching image metadata from database...")
|
|
||||||
images = await image_repository.get_by_ids(image_ids)
|
|
||||||
logger.info(f"Retrieved {len(images)} images from database")
|
|
||||||
|
|
||||||
# Apply filters
|
|
||||||
filtered_images = []
|
|
||||||
for image in images:
|
|
||||||
# Check collection filter
|
|
||||||
if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check date range filter
|
|
||||||
if search_request.date_from and image.upload_date < search_request.date_from:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if search_request.date_to and image.upload_date > search_request.date_to:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check uploader filter
|
|
||||||
if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
|
|
||||||
continue
|
|
||||||
|
|
||||||
filtered_images.append(image)
|
|
||||||
|
|
||||||
logger.info(f"After filtering: {len(filtered_images)} images remain")
|
|
||||||
|
|
||||||
# Convert to response format with similarity scores
|
|
||||||
results = []
|
|
||||||
for image in filtered_images:
|
|
||||||
image_id = str(image.id)
|
|
||||||
similarity_score = scores.get(image_id, 0.0)
|
|
||||||
|
|
||||||
result = ImageResponse(
|
|
||||||
id=image_id,
|
|
||||||
filename=image.filename,
|
|
||||||
original_filename=image.original_filename,
|
|
||||||
file_size=image.file_size,
|
|
||||||
content_type=image.content_type,
|
|
||||||
storage_path=image.storage_path,
|
|
||||||
team_id=str(image.team_id),
|
|
||||||
uploader_id=str(image.uploader_id),
|
|
||||||
upload_date=image.upload_date,
|
|
||||||
description=image.description,
|
|
||||||
metadata=image.metadata,
|
|
||||||
has_embedding=image.has_embedding,
|
|
||||||
collection_id=str(image.collection_id) if image.collection_id else None,
|
|
||||||
similarity_score=similarity_score
|
|
||||||
)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
# Sort by similarity score (highest first)
|
|
||||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
|
||||||
|
|
||||||
logger.info(f"Returning {len(results)} results")
|
|
||||||
|
|
||||||
return SearchResponse(
|
|
||||||
query=search_request.query,
|
|
||||||
results=results,
|
|
||||||
total=len(results),
|
|
||||||
limit=search_request.limit,
|
|
||||||
similarity_threshold=search_request.similarity_threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in advanced search: {e}")
|
logger.error(f"Unexpected error in advanced search: {e}")
|
||||||
import traceback
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
|
||||||
raise HTTPException(status_code=500, detail="Advanced search failed")
|
|
||||||
|
|||||||
@ -2,15 +2,17 @@ import logging
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
|
|
||||||
from src.db.repositories.team_repository import team_repository
|
from src.services.team_service import TeamService
|
||||||
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
|
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
|
||||||
from src.models.team import TeamModel
|
|
||||||
from src.utils.logging import log_request
|
from src.utils.logging import log_request
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(tags=["Teams"], prefix="/teams")
|
router = APIRouter(tags=["Teams"], prefix="/teams")
|
||||||
|
|
||||||
|
# Initialize service
|
||||||
|
team_service = TeamService()
|
||||||
|
|
||||||
@router.post("", response_model=TeamResponse, status_code=201)
|
@router.post("", response_model=TeamResponse, status_code=201)
|
||||||
async def create_team(team_data: TeamCreate, request: Request):
|
async def create_team(team_data: TeamCreate, request: Request):
|
||||||
"""
|
"""
|
||||||
@ -22,24 +24,12 @@ async def create_team(team_data: TeamCreate, request: Request):
|
|||||||
{"path": request.url.path, "method": request.method, "team_data": team_data.dict()}
|
{"path": request.url.path, "method": request.method, "team_data": team_data.dict()}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create team
|
try:
|
||||||
team = TeamModel(
|
response = await team_service.create_team(team_data)
|
||||||
name=team_data.name,
|
return response
|
||||||
description=team_data.description
|
except Exception as e:
|
||||||
)
|
logger.error(f"Unexpected error creating team: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
created_team = await team_repository.create(team)
|
|
||||||
|
|
||||||
# Convert to response model
|
|
||||||
response = TeamResponse(
|
|
||||||
id=str(created_team.id),
|
|
||||||
name=created_team.name,
|
|
||||||
description=created_team.description,
|
|
||||||
created_at=created_team.created_at,
|
|
||||||
updated_at=created_team.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.get("", response_model=TeamListResponse)
|
@router.get("", response_model=TeamListResponse)
|
||||||
async def list_teams(request: Request):
|
async def list_teams(request: Request):
|
||||||
@ -52,21 +42,12 @@ async def list_teams(request: Request):
|
|||||||
{"path": request.url.path, "method": request.method}
|
{"path": request.url.path, "method": request.method}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all teams
|
try:
|
||||||
teams = await team_repository.get_all()
|
response = await team_service.list_teams()
|
||||||
|
return response
|
||||||
# Convert to response models
|
except Exception as e:
|
||||||
response_teams = []
|
logger.error(f"Unexpected error listing teams: {e}")
|
||||||
for team in teams:
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
response_teams.append(TeamResponse(
|
|
||||||
id=str(team.id),
|
|
||||||
name=team.name,
|
|
||||||
description=team.description,
|
|
||||||
created_at=team.created_at,
|
|
||||||
updated_at=team.updated_at
|
|
||||||
))
|
|
||||||
|
|
||||||
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
|
||||||
|
|
||||||
@router.get("/{team_id}", response_model=TeamResponse)
|
@router.get("/{team_id}", response_model=TeamResponse)
|
||||||
async def get_team(team_id: str, request: Request):
|
async def get_team(team_id: str, request: Request):
|
||||||
@ -80,26 +61,15 @@ async def get_team(team_id: str, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert string ID to ObjectId
|
response = await team_service.get_team(team_id)
|
||||||
obj_id = ObjectId(team_id)
|
return response
|
||||||
except:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
# Get the team
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
team = await team_repository.get_by_id(obj_id)
|
except Exception as e:
|
||||||
if not team:
|
logger.error(f"Unexpected error getting team: {e}")
|
||||||
raise HTTPException(status_code=404, detail="Team not found")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
# Convert to response model
|
|
||||||
response = TeamResponse(
|
|
||||||
id=str(team.id),
|
|
||||||
name=team.name,
|
|
||||||
description=team.description,
|
|
||||||
created_at=team.created_at,
|
|
||||||
updated_at=team.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.put("/{team_id}", response_model=TeamResponse)
|
@router.put("/{team_id}", response_model=TeamResponse)
|
||||||
async def update_team(team_id: str, team_data: TeamUpdate, request: Request):
|
async def update_team(team_id: str, team_data: TeamUpdate, request: Request):
|
||||||
@ -113,42 +83,15 @@ async def update_team(team_id: str, team_data: TeamUpdate, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert string ID to ObjectId
|
response = await team_service.update_team(team_id, team_data)
|
||||||
obj_id = ObjectId(team_id)
|
return response
|
||||||
except:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
# Get the team
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
team = await team_repository.get_by_id(obj_id)
|
except Exception as e:
|
||||||
if not team:
|
logger.error(f"Unexpected error updating team: {e}")
|
||||||
raise HTTPException(status_code=404, detail="Team not found")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
# Update the team
|
|
||||||
update_data = team_data.dict(exclude_unset=True)
|
|
||||||
if not update_data:
|
|
||||||
# No fields to update
|
|
||||||
return TeamResponse(
|
|
||||||
id=str(team.id),
|
|
||||||
name=team.name,
|
|
||||||
description=team.description,
|
|
||||||
created_at=team.created_at,
|
|
||||||
updated_at=team.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_team = await team_repository.update(obj_id, update_data)
|
|
||||||
if not updated_team:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to update team")
|
|
||||||
|
|
||||||
# Convert to response model
|
|
||||||
response = TeamResponse(
|
|
||||||
id=str(updated_team.id),
|
|
||||||
name=updated_team.name,
|
|
||||||
description=updated_team.description,
|
|
||||||
created_at=updated_team.created_at,
|
|
||||||
updated_at=updated_team.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.delete("/{team_id}", status_code=204)
|
@router.delete("/{team_id}", status_code=204)
|
||||||
async def delete_team(team_id: str, request: Request):
|
async def delete_team(team_id: str, request: Request):
|
||||||
@ -162,17 +105,12 @@ async def delete_team(team_id: str, request: Request):
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Convert string ID to ObjectId
|
await team_service.delete_team(team_id)
|
||||||
obj_id = ObjectId(team_id)
|
return None
|
||||||
except:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except RuntimeError as e:
|
||||||
# Get the team
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
team = await team_repository.get_by_id(obj_id)
|
except Exception as e:
|
||||||
if not team:
|
logger.error(f"Unexpected error deleting team: {e}")
|
||||||
raise HTTPException(status_code=404, detail="Team not found")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
|
|
||||||
# Delete the team
|
|
||||||
success = await team_repository.delete(obj_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete team")
|
|
||||||
@ -1,13 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
# Remove the auth import since we're removing authentication
|
from src.services.user_service import UserService
|
||||||
# from src.api.v1.auth import get_current_user
|
|
||||||
from src.db.repositories.user_repository import user_repository
|
|
||||||
from src.db.repositories.team_repository import team_repository
|
|
||||||
from src.models.user import UserModel
|
|
||||||
from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
|
from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
|
||||||
from src.utils.logging import log_request
|
from src.utils.logging import log_request
|
||||||
|
|
||||||
@ -15,6 +10,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
router = APIRouter(tags=["Users"], prefix="/users")
|
router = APIRouter(tags=["Users"], prefix="/users")
|
||||||
|
|
||||||
|
# Initialize service
|
||||||
|
user_service = UserService()
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
@router.get("/me", response_model=UserResponse)
|
||||||
async def read_users_me(
|
async def read_users_me(
|
||||||
request: Request,
|
request: Request,
|
||||||
@ -26,26 +24,15 @@ async def read_users_me(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(user_id)
|
response = await user_service.get_user_by_id(user_id)
|
||||||
except:
|
return response
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
current_user = await user_repository.get_by_id(obj_id)
|
except RuntimeError as e:
|
||||||
if not current_user:
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error getting user: {e}")
|
||||||
response = UserResponse(
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
id=str(current_user.id),
|
|
||||||
name=current_user.name,
|
|
||||||
email=current_user.email,
|
|
||||||
team_id=str(current_user.team_id),
|
|
||||||
is_admin=current_user.is_admin,
|
|
||||||
is_active=current_user.is_active,
|
|
||||||
created_at=current_user.created_at,
|
|
||||||
updated_at=current_user.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.put("/me", response_model=UserResponse)
|
@router.put("/me", response_model=UserResponse)
|
||||||
async def update_current_user(
|
async def update_current_user(
|
||||||
@ -59,46 +46,15 @@ async def update_current_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(user_id)
|
response = await user_service.update_user_by_id(user_id, user_data)
|
||||||
except:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
|
||||||
|
|
||||||
current_user = await user_repository.get_by_id(obj_id)
|
|
||||||
if not current_user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
|
|
||||||
# Update user
|
|
||||||
update_data = user_data.dict(exclude_unset=True)
|
|
||||||
if not update_data:
|
|
||||||
# No fields to update
|
|
||||||
response = UserResponse(
|
|
||||||
id=str(current_user.id),
|
|
||||||
name=current_user.name,
|
|
||||||
email=current_user.email,
|
|
||||||
team_id=str(current_user.team_id),
|
|
||||||
is_admin=current_user.is_admin,
|
|
||||||
is_active=current_user.is_active,
|
|
||||||
created_at=current_user.created_at,
|
|
||||||
updated_at=current_user.updated_at
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
except ValueError as e:
|
||||||
updated_user = await user_repository.update(current_user.id, update_data)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
if not updated_user:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update user")
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
response = UserResponse(
|
logger.error(f"Unexpected error updating user: {e}")
|
||||||
id=str(updated_user.id),
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
name=updated_user.name,
|
|
||||||
email=updated_user.email,
|
|
||||||
team_id=str(updated_user.team_id),
|
|
||||||
is_admin=updated_user.is_admin,
|
|
||||||
is_active=updated_user.is_active,
|
|
||||||
created_at=updated_user.created_at,
|
|
||||||
updated_at=updated_user.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.post("", response_model=UserResponse, status_code=201)
|
@router.post("", response_model=UserResponse, status_code=201)
|
||||||
async def create_user(
|
async def create_user(
|
||||||
@ -114,43 +70,16 @@ async def create_user(
|
|||||||
{"path": request.url.path, "method": request.method, "user_data": user_data.dict()}
|
{"path": request.url.path, "method": request.method, "user_data": user_data.dict()}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if user with email already exists
|
try:
|
||||||
existing_user = await user_repository.get_by_email(user_data.email)
|
response = await user_service.create_user(user_data)
|
||||||
if existing_user:
|
return response
|
||||||
raise HTTPException(status_code=400, detail="User with this email already exists")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# Validate team exists if specified
|
except RuntimeError as e:
|
||||||
if user_data.team_id:
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
team = await team_repository.get_by_id(ObjectId(user_data.team_id))
|
except Exception as e:
|
||||||
if not team:
|
logger.error(f"Unexpected error creating user: {e}")
|
||||||
raise HTTPException(status_code=400, detail="Team not found")
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
team_id = user_data.team_id
|
|
||||||
else:
|
|
||||||
raise HTTPException(status_code=400, detail="Team ID is required")
|
|
||||||
|
|
||||||
# Create user
|
|
||||||
user = UserModel(
|
|
||||||
name=user_data.name,
|
|
||||||
email=user_data.email,
|
|
||||||
team_id=ObjectId(team_id),
|
|
||||||
is_admin=user_data.is_admin or False,
|
|
||||||
is_active=True
|
|
||||||
)
|
|
||||||
|
|
||||||
created_user = await user_repository.create(user)
|
|
||||||
|
|
||||||
response = UserResponse(
|
|
||||||
id=str(created_user.id),
|
|
||||||
name=created_user.name,
|
|
||||||
email=created_user.email,
|
|
||||||
team_id=str(created_user.team_id),
|
|
||||||
is_admin=created_user.is_admin,
|
|
||||||
is_active=created_user.is_active,
|
|
||||||
created_at=created_user.created_at,
|
|
||||||
updated_at=created_user.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.get("", response_model=UserListResponse)
|
@router.get("", response_model=UserListResponse)
|
||||||
async def list_users(
|
async def list_users(
|
||||||
@ -166,31 +95,14 @@ async def list_users(
|
|||||||
{"path": request.url.path, "method": request.method, "team_id": team_id}
|
{"path": request.url.path, "method": request.method, "team_id": team_id}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get users
|
try:
|
||||||
if team_id:
|
response = await user_service.list_users(team_id)
|
||||||
try:
|
return response
|
||||||
filter_team_id = ObjectId(team_id)
|
except ValueError as e:
|
||||||
users = await user_repository.get_by_team(filter_team_id)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=400, detail="Invalid team ID")
|
logger.error(f"Unexpected error listing users: {e}")
|
||||||
else:
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
users = await user_repository.get_all()
|
|
||||||
|
|
||||||
# Convert to response
|
|
||||||
response_users = []
|
|
||||||
for user in users:
|
|
||||||
response_users.append(UserResponse(
|
|
||||||
id=str(user.id),
|
|
||||||
name=user.name,
|
|
||||||
email=user.email,
|
|
||||||
team_id=str(user.team_id),
|
|
||||||
is_admin=user.is_admin,
|
|
||||||
is_active=user.is_active,
|
|
||||||
created_at=user.created_at,
|
|
||||||
updated_at=user.updated_at
|
|
||||||
))
|
|
||||||
|
|
||||||
return UserListResponse(users=response_users, total=len(response_users))
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=UserResponse)
|
@router.get("/{user_id}", response_model=UserResponse)
|
||||||
async def get_user(
|
async def get_user(
|
||||||
@ -207,26 +119,15 @@ async def get_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(user_id)
|
response = await user_service.get_user(user_id)
|
||||||
except:
|
return response
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
user = await user_repository.get_by_id(obj_id)
|
except RuntimeError as e:
|
||||||
if not user:
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error getting user: {e}")
|
||||||
response = UserResponse(
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
id=str(user.id),
|
|
||||||
name=user.name,
|
|
||||||
email=user.email,
|
|
||||||
team_id=str(user.team_id),
|
|
||||||
is_admin=user.is_admin,
|
|
||||||
is_active=user.is_active,
|
|
||||||
created_at=user.created_at,
|
|
||||||
updated_at=user.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.put("/{user_id}", response_model=UserResponse)
|
@router.put("/{user_id}", response_model=UserResponse)
|
||||||
async def update_user(
|
async def update_user(
|
||||||
@ -244,47 +145,15 @@ async def update_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(user_id)
|
response = await user_service.update_user(user_id, user_data)
|
||||||
except:
|
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
|
||||||
|
|
||||||
# Check if user exists
|
|
||||||
user = await user_repository.get_by_id(obj_id)
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
|
||||||
|
|
||||||
# Update user
|
|
||||||
update_data = user_data.dict(exclude_unset=True)
|
|
||||||
if not update_data:
|
|
||||||
# No fields to update
|
|
||||||
response = UserResponse(
|
|
||||||
id=str(user.id),
|
|
||||||
name=user.name,
|
|
||||||
email=user.email,
|
|
||||||
team_id=str(user.team_id),
|
|
||||||
is_admin=user.is_admin,
|
|
||||||
is_active=user.is_active,
|
|
||||||
created_at=user.created_at,
|
|
||||||
updated_at=user.updated_at
|
|
||||||
)
|
|
||||||
return response
|
return response
|
||||||
|
except ValueError as e:
|
||||||
updated_user = await user_repository.update(obj_id, update_data)
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
if not updated_user:
|
except RuntimeError as e:
|
||||||
raise HTTPException(status_code=500, detail="Failed to update user")
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
response = UserResponse(
|
logger.error(f"Unexpected error updating user: {e}")
|
||||||
id=str(updated_user.id),
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
name=updated_user.name,
|
|
||||||
email=updated_user.email,
|
|
||||||
team_id=str(updated_user.team_id),
|
|
||||||
is_admin=updated_user.is_admin,
|
|
||||||
is_active=updated_user.is_active,
|
|
||||||
created_at=updated_user.created_at,
|
|
||||||
updated_at=updated_user.updated_at
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", status_code=204)
|
@router.delete("/{user_id}", status_code=204)
|
||||||
async def delete_user(
|
async def delete_user(
|
||||||
@ -301,16 +170,12 @@ async def delete_user(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
obj_id = ObjectId(user_id)
|
await user_service.delete_user(user_id)
|
||||||
except:
|
return None
|
||||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# Check if user exists
|
except RuntimeError as e:
|
||||||
user = await user_repository.get_by_id(obj_id)
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
if not user:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=404, detail="User not found")
|
logger.error(f"Unexpected error deleting user: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail="Internal server error")
|
||||||
# Delete user
|
|
||||||
success = await user_repository.delete(obj_id)
|
|
||||||
if not success:
|
|
||||||
raise HTTPException(status_code=500, detail="Failed to delete user")
|
|
||||||
255
src/services/auth_service.py
Normal file
255
src/services/auth_service.py
Normal file
@ -0,0 +1,255 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.repositories.api_key_repository import api_key_repository
|
||||||
|
from src.db.repositories.user_repository import user_repository
|
||||||
|
from src.db.repositories.team_repository import team_repository
|
||||||
|
from src.schemas.api_key import ApiKeyCreate, ApiKeyResponse, ApiKeyWithValueResponse, ApiKeyListResponse
|
||||||
|
from src.auth.security import generate_api_key, verify_api_key, calculate_expiry_date, is_expired, hash_api_key
|
||||||
|
from src.models.api_key import ApiKeyModel
|
||||||
|
from src.models.team import TeamModel
|
||||||
|
from src.models.user import UserModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
"""Service class for handling authentication-related business logic"""
|
||||||
|
|
||||||
|
async def create_api_key_for_user_and_team(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
team_id: str,
|
||||||
|
key_data: ApiKeyCreate
|
||||||
|
) -> ApiKeyWithValueResponse:
|
||||||
|
"""
|
||||||
|
Create a new API key for a specific user and team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to create the key for
|
||||||
|
team_id: The team ID the user belongs to
|
||||||
|
key_data: The API key creation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiKeyWithValueResponse: The created API key with the raw key value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user_id or team_id are invalid
|
||||||
|
RuntimeError: If user or team not found, or user doesn't belong to team
|
||||||
|
"""
|
||||||
|
# Validate user_id and team_id
|
||||||
|
try:
|
||||||
|
target_user_id = ObjectId(user_id)
|
||||||
|
target_team_id = ObjectId(team_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID or team ID")
|
||||||
|
|
||||||
|
# Verify user exists
|
||||||
|
target_user = await user_repository.get_by_id(target_user_id)
|
||||||
|
if not target_user:
|
||||||
|
raise RuntimeError("User not found")
|
||||||
|
|
||||||
|
# Verify team exists
|
||||||
|
team = await team_repository.get_by_id(target_team_id)
|
||||||
|
if not team:
|
||||||
|
raise RuntimeError("Team not found")
|
||||||
|
|
||||||
|
# Verify user belongs to the team
|
||||||
|
if target_user.team_id != target_team_id:
|
||||||
|
raise RuntimeError("User does not belong to the specified team")
|
||||||
|
|
||||||
|
# Validate key_data consistency
|
||||||
|
if key_data.team_id and key_data.team_id != team_id:
|
||||||
|
raise ValueError("Team ID in request body does not match parameter")
|
||||||
|
|
||||||
|
if key_data.user_id and key_data.user_id != user_id:
|
||||||
|
raise ValueError("User ID in request body does not match parameter")
|
||||||
|
|
||||||
|
# Generate API key with expiry date
|
||||||
|
raw_key, hashed_key = generate_api_key(str(target_team_id), str(target_user_id))
|
||||||
|
expiry_date = calculate_expiry_date()
|
||||||
|
|
||||||
|
# Create API key in database
|
||||||
|
api_key = ApiKeyModel(
|
||||||
|
key_hash=hashed_key,
|
||||||
|
user_id=target_user_id,
|
||||||
|
team_id=target_team_id,
|
||||||
|
name=key_data.name,
|
||||||
|
description=key_data.description,
|
||||||
|
expiry_date=expiry_date,
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
|
||||||
|
created_key = await api_key_repository.create(api_key)
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
return ApiKeyWithValueResponse(
|
||||||
|
id=str(created_key.id),
|
||||||
|
key=raw_key,
|
||||||
|
name=created_key.name,
|
||||||
|
description=created_key.description,
|
||||||
|
team_id=str(created_key.team_id),
|
||||||
|
user_id=str(created_key.user_id),
|
||||||
|
created_at=created_key.created_at,
|
||||||
|
expiry_date=created_key.expiry_date,
|
||||||
|
last_used=created_key.last_used,
|
||||||
|
is_active=created_key.is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_api_key_for_user_by_admin(
|
||||||
|
self,
|
||||||
|
target_user_id: str,
|
||||||
|
key_data: ApiKeyCreate,
|
||||||
|
admin_user: UserModel
|
||||||
|
) -> ApiKeyWithValueResponse:
|
||||||
|
"""
|
||||||
|
Create a new API key for a specific user (admin only)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_user_id: The user ID to create the key for
|
||||||
|
key_data: The API key creation data
|
||||||
|
admin_user: The admin user performing the action
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiKeyWithValueResponse: The created API key with the raw key value
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
PermissionError: If the admin user doesn't have admin privileges
|
||||||
|
ValueError: If target_user_id is invalid
|
||||||
|
RuntimeError: If target user or team not found
|
||||||
|
"""
|
||||||
|
# Check if current user is admin
|
||||||
|
if not admin_user.is_admin:
|
||||||
|
raise PermissionError("Admin access required")
|
||||||
|
|
||||||
|
try:
|
||||||
|
target_user_obj_id = ObjectId(target_user_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID")
|
||||||
|
|
||||||
|
# Get the target user
|
||||||
|
target_user = await user_repository.get_by_id(target_user_obj_id)
|
||||||
|
if not target_user:
|
||||||
|
raise RuntimeError("Target user not found")
|
||||||
|
|
||||||
|
# Check if target user's team exists
|
||||||
|
team = await team_repository.get_by_id(target_user.team_id)
|
||||||
|
if not team:
|
||||||
|
raise RuntimeError("Target user's team not found")
|
||||||
|
|
||||||
|
# Generate API key with expiry date
|
||||||
|
raw_key, hashed_key = generate_api_key(str(target_user.team_id), str(target_user.id))
|
||||||
|
expiry_date = calculate_expiry_date()
|
||||||
|
|
||||||
|
# Create API key in database
|
||||||
|
api_key = ApiKeyModel(
|
||||||
|
key_hash=hashed_key,
|
||||||
|
user_id=target_user.id,
|
||||||
|
team_id=target_user.team_id,
|
||||||
|
name=key_data.name,
|
||||||
|
description=key_data.description,
|
||||||
|
expiry_date=expiry_date,
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
|
||||||
|
created_key = await api_key_repository.create(api_key)
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
return ApiKeyWithValueResponse(
|
||||||
|
id=str(created_key.id),
|
||||||
|
key=raw_key,
|
||||||
|
name=created_key.name,
|
||||||
|
description=created_key.description,
|
||||||
|
team_id=str(created_key.team_id),
|
||||||
|
user_id=str(created_key.user_id),
|
||||||
|
created_at=created_key.created_at,
|
||||||
|
expiry_date=created_key.expiry_date,
|
||||||
|
last_used=created_key.last_used,
|
||||||
|
is_active=created_key.is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_user_api_keys(self, user: UserModel) -> ApiKeyListResponse:
|
||||||
|
"""
|
||||||
|
List API keys for a specific user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The user to list API keys for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiKeyListResponse: List of API keys for the user
|
||||||
|
"""
|
||||||
|
# Get API keys for user
|
||||||
|
keys = await api_key_repository.get_by_user(user.id)
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
response_keys = []
|
||||||
|
for key in keys:
|
||||||
|
response_keys.append(ApiKeyResponse(
|
||||||
|
id=str(key.id),
|
||||||
|
name=key.name,
|
||||||
|
description=key.description,
|
||||||
|
team_id=str(key.team_id),
|
||||||
|
user_id=str(key.user_id),
|
||||||
|
created_at=key.created_at,
|
||||||
|
expiry_date=key.expiry_date,
|
||||||
|
last_used=key.last_used,
|
||||||
|
is_active=key.is_active
|
||||||
|
))
|
||||||
|
|
||||||
|
return ApiKeyListResponse(api_keys=response_keys, total=len(response_keys))
|
||||||
|
|
||||||
|
async def revoke_api_key(self, key_id: str, user: UserModel) -> bool:
|
||||||
|
"""
|
||||||
|
Revoke (deactivate) an API key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key_id: The ID of the API key to revoke
|
||||||
|
user: The user attempting to revoke the key
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successfully revoked
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If key_id is invalid
|
||||||
|
RuntimeError: If key not found
|
||||||
|
PermissionError: If user not authorized to revoke the key
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(key_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid key ID")
|
||||||
|
|
||||||
|
# Get the API key
|
||||||
|
key = await api_key_repository.get_by_id(obj_id)
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError("API key not found")
|
||||||
|
|
||||||
|
# Check if user owns the key or is an admin
|
||||||
|
if key.user_id != user.id and not user.is_admin:
|
||||||
|
raise PermissionError("Not authorized to revoke this API key")
|
||||||
|
|
||||||
|
# Deactivate the key
|
||||||
|
result = await api_key_repository.deactivate(obj_id)
|
||||||
|
if not result:
|
||||||
|
raise RuntimeError("Failed to revoke API key")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def verify_user_authentication(self, user: UserModel) -> dict:
|
||||||
|
"""
|
||||||
|
Verify and return user authentication information
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The authenticated user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: User authentication information
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"user_id": str(user.id),
|
||||||
|
"name": user.name,
|
||||||
|
"email": user.email,
|
||||||
|
"team_id": str(user.team_id),
|
||||||
|
"is_admin": user.is_admin
|
||||||
|
}
|
||||||
370
src/services/image_service.py
Normal file
370
src/services/image_service.py
Normal file
@ -0,0 +1,370 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, List, Tuple
|
||||||
|
from fastapi import UploadFile, Request
|
||||||
|
from bson import ObjectId
|
||||||
|
import io
|
||||||
|
|
||||||
|
from src.db.repositories.image_repository import image_repository
|
||||||
|
from src.services.storage import StorageService
|
||||||
|
from src.services.image_processor import ImageProcessor
|
||||||
|
from src.services.embedding_service import EmbeddingService
|
||||||
|
from src.services.pubsub_service import pubsub_service
|
||||||
|
from src.models.image import ImageModel
|
||||||
|
from src.models.user import UserModel
|
||||||
|
from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class ImageService:
|
||||||
|
"""Service class for handling image-related business logic"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.storage_service = StorageService()
|
||||||
|
self.image_processor = ImageProcessor()
|
||||||
|
self.embedding_service = EmbeddingService()
|
||||||
|
|
||||||
|
def _generate_api_download_url(self, request: Request, image_id: str) -> str:
|
||||||
|
"""Generate API download URL for an image"""
|
||||||
|
base_url = str(request.base_url).rstrip('/')
|
||||||
|
return f"{base_url}/api/v1/images/{image_id}/download"
|
||||||
|
|
||||||
|
async def upload_image(
|
||||||
|
self,
|
||||||
|
file: UploadFile,
|
||||||
|
user: UserModel,
|
||||||
|
request: Request,
|
||||||
|
description: Optional[str] = None,
|
||||||
|
collection_id: Optional[str] = None
|
||||||
|
) -> ImageResponse:
|
||||||
|
"""
|
||||||
|
Upload a new image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file: The uploaded file
|
||||||
|
user: The user uploading the image
|
||||||
|
request: The FastAPI request object for URL generation
|
||||||
|
description: Optional description for the image
|
||||||
|
collection_id: Optional collection ID to associate with the image
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageResponse: The created image metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If file validation fails
|
||||||
|
RuntimeError: If upload fails
|
||||||
|
"""
|
||||||
|
# Validate file type
|
||||||
|
if not file.content_type or not file.content_type.startswith('image/'):
|
||||||
|
raise ValueError("File must be an image")
|
||||||
|
|
||||||
|
# Validate file size (10MB limit)
|
||||||
|
max_size = 10 * 1024 * 1024 # 10MB
|
||||||
|
content = await file.read()
|
||||||
|
if len(content) > max_size:
|
||||||
|
raise ValueError("File size exceeds 10MB limit")
|
||||||
|
|
||||||
|
# Reset file pointer
|
||||||
|
await file.seek(0)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Upload to storage
|
||||||
|
storage_path, content_type, file_size, metadata = await self.storage_service.upload_file(
|
||||||
|
file, str(user.team_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create image record
|
||||||
|
image = ImageModel(
|
||||||
|
filename=file.filename,
|
||||||
|
original_filename=file.filename,
|
||||||
|
file_size=file_size,
|
||||||
|
content_type=content_type,
|
||||||
|
storage_path=storage_path,
|
||||||
|
public_url=None, # Will be set after we have the image ID
|
||||||
|
team_id=user.team_id,
|
||||||
|
uploader_id=user.id,
|
||||||
|
description=description,
|
||||||
|
metadata=metadata,
|
||||||
|
collection_id=ObjectId(collection_id) if collection_id else None
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save to database
|
||||||
|
created_image = await image_repository.create(image)
|
||||||
|
|
||||||
|
# Generate API download URL now that we have the image ID
|
||||||
|
api_download_url = self._generate_api_download_url(request, str(created_image.id))
|
||||||
|
|
||||||
|
# Update the image with the API download URL
|
||||||
|
await image_repository.update(created_image.id, {"public_url": api_download_url})
|
||||||
|
created_image.public_url = api_download_url
|
||||||
|
|
||||||
|
# Publish image processing task to Pub/Sub
|
||||||
|
try:
|
||||||
|
task_published = await pubsub_service.publish_image_processing_task(
|
||||||
|
image_id=str(created_image.id),
|
||||||
|
storage_path=storage_path,
|
||||||
|
team_id=str(user.team_id)
|
||||||
|
)
|
||||||
|
if not task_published:
|
||||||
|
logger.warning(f"Failed to publish processing task for image {created_image.id}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to publish image processing task: {e}")
|
||||||
|
|
||||||
|
# Convert to response
|
||||||
|
return self._convert_to_response(created_image, request)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error uploading image: {e}")
|
||||||
|
raise RuntimeError("Failed to upload image")
|
||||||
|
|
||||||
|
async def list_images(
|
||||||
|
self,
|
||||||
|
user: UserModel,
|
||||||
|
request: Request,
|
||||||
|
skip: int = 0,
|
||||||
|
limit: int = 50,
|
||||||
|
collection_id: Optional[str] = None
|
||||||
|
) -> ImageListResponse:
|
||||||
|
"""
|
||||||
|
List images for the user's team or all images if user is admin
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user: The requesting user
|
||||||
|
request: The FastAPI request object for URL generation
|
||||||
|
skip: Number of records to skip for pagination
|
||||||
|
limit: Maximum number of records to return
|
||||||
|
collection_id: Optional filter by collection ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageListResponse: List of images with pagination metadata
|
||||||
|
"""
|
||||||
|
# Check if user is admin - if so, get all images across all teams
|
||||||
|
if user.is_admin:
|
||||||
|
images = await image_repository.get_all_with_pagination(
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||||
|
)
|
||||||
|
total = await image_repository.count_all(
|
||||||
|
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Regular users only see images from their team
|
||||||
|
images = await image_repository.get_by_team(
|
||||||
|
user.team_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||||
|
)
|
||||||
|
total = await image_repository.count_by_team(
|
||||||
|
user.team_id,
|
||||||
|
collection_id=ObjectId(collection_id) if collection_id else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Convert to response
|
||||||
|
response_images = [self._convert_to_response(image, request) for image in images]
|
||||||
|
|
||||||
|
return ImageListResponse(images=response_images, total=total, skip=skip, limit=limit)
|
||||||
|
|
||||||
|
async def get_image(self, image_id: str, user: UserModel, request: Request) -> ImageResponse:
|
||||||
|
"""
|
||||||
|
Get image metadata by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: The image ID to retrieve
|
||||||
|
user: The requesting user
|
||||||
|
request: The FastAPI request object for URL generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageResponse: The image metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If image_id is invalid
|
||||||
|
RuntimeError: If image not found
|
||||||
|
PermissionError: If user not authorized to access the image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(image_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid image ID")
|
||||||
|
|
||||||
|
# Get image
|
||||||
|
image = await image_repository.get_by_id(obj_id)
|
||||||
|
if not image:
|
||||||
|
raise RuntimeError("Image not found")
|
||||||
|
|
||||||
|
# Check team access (admins can access any image)
|
||||||
|
if not user.is_admin and image.team_id != user.team_id:
|
||||||
|
raise PermissionError("Not authorized to access this image")
|
||||||
|
|
||||||
|
return self._convert_to_response(image, request, include_last_accessed=True)
|
||||||
|
|
||||||
|
async def download_image(self, image_id: str, user: UserModel) -> Tuple[bytes, str, str]:
|
||||||
|
"""
|
||||||
|
Download image file
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: The image ID to download
|
||||||
|
user: The requesting user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[bytes, str, str]: File content, content type, and filename
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If image_id is invalid
|
||||||
|
RuntimeError: If image not found or file not found in storage
|
||||||
|
PermissionError: If user not authorized to access the image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(image_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid image ID")
|
||||||
|
|
||||||
|
# Get image
|
||||||
|
image = await image_repository.get_by_id(obj_id)
|
||||||
|
if not image:
|
||||||
|
raise RuntimeError("Image not found")
|
||||||
|
|
||||||
|
# Check team access (admins can access any image)
|
||||||
|
if not user.is_admin and image.team_id != user.team_id:
|
||||||
|
raise PermissionError("Not authorized to access this image")
|
||||||
|
|
||||||
|
# Get file from storage
|
||||||
|
file_content = self.storage_service.get_file(image.storage_path)
|
||||||
|
if not file_content:
|
||||||
|
raise RuntimeError("Image file not found in storage")
|
||||||
|
|
||||||
|
# Update last accessed
|
||||||
|
await image_repository.update_last_accessed(obj_id)
|
||||||
|
|
||||||
|
return file_content, image.content_type, image.original_filename
|
||||||
|
|
||||||
|
async def update_image(
|
||||||
|
self,
|
||||||
|
image_id: str,
|
||||||
|
image_data: ImageUpdate,
|
||||||
|
user: UserModel,
|
||||||
|
request: Request
|
||||||
|
) -> ImageResponse:
|
||||||
|
"""
|
||||||
|
Update image metadata
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: The image ID to update
|
||||||
|
image_data: The update data
|
||||||
|
user: The requesting user
|
||||||
|
request: The FastAPI request object for URL generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageResponse: The updated image metadata
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If image_id is invalid
|
||||||
|
RuntimeError: If image not found or update fails
|
||||||
|
PermissionError: If user not authorized to update the image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(image_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid image ID")
|
||||||
|
|
||||||
|
# Get image
|
||||||
|
image = await image_repository.get_by_id(obj_id)
|
||||||
|
if not image:
|
||||||
|
raise RuntimeError("Image not found")
|
||||||
|
|
||||||
|
# Check team access (admins can update any image)
|
||||||
|
if not user.is_admin and image.team_id != user.team_id:
|
||||||
|
raise PermissionError("Not authorized to update this image")
|
||||||
|
|
||||||
|
# Update image
|
||||||
|
update_data = image_data.dict(exclude_unset=True)
|
||||||
|
if not update_data:
|
||||||
|
# No fields to update
|
||||||
|
return self._convert_to_response(image, request)
|
||||||
|
|
||||||
|
updated_image = await image_repository.update(obj_id, update_data)
|
||||||
|
if not updated_image:
|
||||||
|
raise RuntimeError("Failed to update image")
|
||||||
|
|
||||||
|
return self._convert_to_response(updated_image, request)
|
||||||
|
|
||||||
|
async def delete_image(self, image_id: str, user: UserModel) -> bool:
|
||||||
|
"""
|
||||||
|
Delete an image
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_id: The image ID to delete
|
||||||
|
user: The requesting user
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successfully deleted
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If image_id is invalid
|
||||||
|
RuntimeError: If image not found or deletion fails
|
||||||
|
PermissionError: If user not authorized to delete the image
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(image_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid image ID")
|
||||||
|
|
||||||
|
# Get image
|
||||||
|
image = await image_repository.get_by_id(obj_id)
|
||||||
|
if not image:
|
||||||
|
raise RuntimeError("Image not found")
|
||||||
|
|
||||||
|
# Check team access (admins can delete any image)
|
||||||
|
if not user.is_admin and image.team_id != user.team_id:
|
||||||
|
raise PermissionError("Not authorized to delete this image")
|
||||||
|
|
||||||
|
# Delete from storage
|
||||||
|
try:
|
||||||
|
self.storage_service.delete_file(image.storage_path)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete file from storage: {e}")
|
||||||
|
|
||||||
|
# Delete from vector database if it has embeddings
|
||||||
|
if image.has_embedding and image.embedding_id:
|
||||||
|
try:
|
||||||
|
await self.embedding_service.delete_embedding(image.embedding_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to delete embedding: {e}")
|
||||||
|
|
||||||
|
# Delete from database
|
||||||
|
result = await image_repository.delete(obj_id)
|
||||||
|
if not result:
|
||||||
|
raise RuntimeError("Failed to delete image")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _convert_to_response(
|
||||||
|
self,
|
||||||
|
image: ImageModel,
|
||||||
|
request: Request,
|
||||||
|
include_last_accessed: bool = False
|
||||||
|
) -> ImageResponse:
|
||||||
|
"""Convert ImageModel to ImageResponse"""
|
||||||
|
api_download_url = self._generate_api_download_url(request, str(image.id))
|
||||||
|
|
||||||
|
response_data = {
|
||||||
|
"id": str(image.id),
|
||||||
|
"filename": image.filename,
|
||||||
|
"original_filename": image.original_filename,
|
||||||
|
"file_size": image.file_size,
|
||||||
|
"content_type": image.content_type,
|
||||||
|
"storage_path": image.storage_path,
|
||||||
|
"public_url": api_download_url,
|
||||||
|
"team_id": str(image.team_id),
|
||||||
|
"uploader_id": str(image.uploader_id),
|
||||||
|
"upload_date": image.upload_date,
|
||||||
|
"description": image.description,
|
||||||
|
"metadata": image.metadata,
|
||||||
|
"has_embedding": image.has_embedding,
|
||||||
|
"collection_id": str(image.collection_id) if image.collection_id else None
|
||||||
|
}
|
||||||
|
|
||||||
|
if include_last_accessed:
|
||||||
|
response_data["last_accessed"] = image.last_accessed
|
||||||
|
|
||||||
|
return ImageResponse(**response_data)
|
||||||
271
src/services/search_service.py
Normal file
271
src/services/search_service.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
from fastapi import Request
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.services.vector_db import VectorDatabaseService
|
||||||
|
from src.services.embedding_service import EmbeddingService
|
||||||
|
from src.db.repositories.image_repository import image_repository
|
||||||
|
from src.models.user import UserModel
|
||||||
|
from src.schemas.image import ImageResponse
|
||||||
|
from src.schemas.search import SearchResponse, SearchRequest
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class SearchService:
|
||||||
|
"""Service class for handling search-related business logic"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.vector_db_service = None
|
||||||
|
self.embedding_service = EmbeddingService()
|
||||||
|
|
||||||
|
def _get_vector_db_service(self):
|
||||||
|
"""Get or create the vector database service instance"""
|
||||||
|
if self.vector_db_service is None:
|
||||||
|
logger.info("Initializing VectorDatabaseService...")
|
||||||
|
self.vector_db_service = VectorDatabaseService()
|
||||||
|
return self.vector_db_service
|
||||||
|
|
||||||
|
def _generate_api_download_url(self, request: Request, image_id: str) -> str:
|
||||||
|
"""Generate API download URL for an image"""
|
||||||
|
base_url = str(request.base_url).rstrip('/')
|
||||||
|
return f"{base_url}/api/v1/images/{image_id}/download"
|
||||||
|
|
||||||
|
async def search_images(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
user: UserModel,
|
||||||
|
request: Request,
|
||||||
|
limit: int = 10,
|
||||||
|
similarity_threshold: float = 0.65,
|
||||||
|
collection_id: Optional[str] = None
|
||||||
|
) -> SearchResponse:
|
||||||
|
"""
|
||||||
|
Search for images using semantic similarity
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: The search query text
|
||||||
|
user: The requesting user
|
||||||
|
request: The FastAPI request object for URL generation
|
||||||
|
limit: Number of results to return
|
||||||
|
similarity_threshold: Similarity threshold for filtering results
|
||||||
|
collection_id: Optional filter by collection ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SearchResponse: Search results with similarity scores
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If query embedding generation fails
|
||||||
|
RuntimeError: If search fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Generate embedding for the search query
|
||||||
|
query_embedding = await self.embedding_service.generate_text_embedding(query)
|
||||||
|
if not query_embedding:
|
||||||
|
raise ValueError("Failed to generate search embedding")
|
||||||
|
|
||||||
|
# Search in vector database
|
||||||
|
search_results = self._get_vector_db_service().search_similar_images(
|
||||||
|
query_vector=query_embedding,
|
||||||
|
limit=limit,
|
||||||
|
similarity_threshold=similarity_threshold,
|
||||||
|
filter_conditions={"team_id": str(user.team_id)} if user.team_id else None
|
||||||
|
)
|
||||||
|
|
||||||
|
if not search_results:
|
||||||
|
return SearchResponse(
|
||||||
|
query=query,
|
||||||
|
results=[],
|
||||||
|
total=0,
|
||||||
|
limit=limit,
|
||||||
|
similarity_threshold=similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get image IDs and scores from search results
|
||||||
|
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||||
|
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
||||||
|
|
||||||
|
# Get image metadata from database
|
||||||
|
images = await image_repository.get_by_ids(image_ids)
|
||||||
|
|
||||||
|
# Filter by collection if specified
|
||||||
|
filtered_images = self._filter_images_by_collection(images, collection_id)
|
||||||
|
|
||||||
|
# Convert to response format with similarity scores
|
||||||
|
results = self._convert_images_to_search_results(filtered_images, scores, request)
|
||||||
|
|
||||||
|
# Sort by similarity score (highest first)
|
||||||
|
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||||
|
|
||||||
|
return SearchResponse(
|
||||||
|
query=query,
|
||||||
|
results=results,
|
||||||
|
total=len(results),
|
||||||
|
limit=limit,
|
||||||
|
similarity_threshold=similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error searching images: {e}")
|
||||||
|
raise RuntimeError("Search failed")
|
||||||
|
|
||||||
|
async def search_images_advanced(
|
||||||
|
self,
|
||||||
|
search_request: SearchRequest,
|
||||||
|
user: UserModel,
|
||||||
|
request: Request
|
||||||
|
) -> SearchResponse:
|
||||||
|
"""
|
||||||
|
Advanced search for images with more filtering options
|
||||||
|
|
||||||
|
Args:
|
||||||
|
search_request: The advanced search request with filters
|
||||||
|
user: The requesting user
|
||||||
|
request: The FastAPI request object for URL generation
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
SearchResponse: Search results with similarity scores
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If query embedding generation fails
|
||||||
|
RuntimeError: If search fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Generate embedding for the search query
|
||||||
|
logger.info(f"Generating embedding for query: {search_request.query}")
|
||||||
|
query_embedding = await self.embedding_service.generate_text_embedding(search_request.query)
|
||||||
|
if not query_embedding:
|
||||||
|
logger.error("Failed to generate search embedding")
|
||||||
|
raise ValueError("Failed to generate search embedding")
|
||||||
|
|
||||||
|
logger.info(f"Generated embedding with length: {len(query_embedding)}")
|
||||||
|
|
||||||
|
# Search in vector database
|
||||||
|
logger.info(f"Searching vector database with similarity_threshold: {search_request.similarity_threshold}")
|
||||||
|
search_results = self._get_vector_db_service().search_similar_images(
|
||||||
|
query_vector=query_embedding,
|
||||||
|
limit=search_request.limit,
|
||||||
|
similarity_threshold=search_request.similarity_threshold,
|
||||||
|
filter_conditions={"team_id": str(user.team_id)} if user.team_id else None
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
|
||||||
|
|
||||||
|
if not search_results:
|
||||||
|
logger.info("No search results from vector database, returning empty response")
|
||||||
|
return SearchResponse(
|
||||||
|
query=search_request.query,
|
||||||
|
results=[],
|
||||||
|
total=0,
|
||||||
|
limit=search_request.limit,
|
||||||
|
similarity_threshold=search_request.similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get image IDs and scores from search results
|
||||||
|
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||||
|
scores = {result['image_id']: result['similarity_score'] for result in search_results if result['image_id']}
|
||||||
|
|
||||||
|
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
|
||||||
|
|
||||||
|
# Get image metadata from database
|
||||||
|
logger.info("Fetching image metadata from database...")
|
||||||
|
images = await image_repository.get_by_ids(image_ids)
|
||||||
|
logger.info(f"Retrieved {len(images)} images from database")
|
||||||
|
|
||||||
|
# Apply advanced filters
|
||||||
|
filtered_images = self._apply_advanced_filters(images, search_request)
|
||||||
|
|
||||||
|
logger.info(f"After filtering: {len(filtered_images)} images remain")
|
||||||
|
|
||||||
|
# Convert to response format with similarity scores
|
||||||
|
results = self._convert_images_to_search_results(filtered_images, scores, request)
|
||||||
|
|
||||||
|
# Sort by similarity score (highest first)
|
||||||
|
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||||
|
|
||||||
|
logger.info(f"Returning {len(results)} results")
|
||||||
|
|
||||||
|
return SearchResponse(
|
||||||
|
query=search_request.query,
|
||||||
|
results=results,
|
||||||
|
total=len(results),
|
||||||
|
limit=search_request.limit,
|
||||||
|
similarity_threshold=search_request.similarity_threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in advanced search: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
raise RuntimeError("Advanced search failed")
|
||||||
|
|
||||||
|
def _filter_images_by_collection(self, images: List, collection_id: Optional[str]) -> List:
|
||||||
|
"""Filter images by collection ID"""
|
||||||
|
if not collection_id:
|
||||||
|
return images
|
||||||
|
|
||||||
|
filtered_images = []
|
||||||
|
for image in images:
|
||||||
|
if str(image.collection_id) == collection_id:
|
||||||
|
filtered_images.append(image)
|
||||||
|
|
||||||
|
return filtered_images
|
||||||
|
|
||||||
|
def _apply_advanced_filters(self, images: List, search_request: SearchRequest) -> List:
|
||||||
|
"""Apply advanced filters to the image list"""
|
||||||
|
filtered_images = []
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
# Check collection filter
|
||||||
|
if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check date range filter
|
||||||
|
if search_request.date_from and image.upload_date < search_request.date_from:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if search_request.date_to and image.upload_date > search_request.date_to:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check uploader filter
|
||||||
|
if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
filtered_images.append(image)
|
||||||
|
|
||||||
|
return filtered_images
|
||||||
|
|
||||||
|
def _convert_images_to_search_results(
|
||||||
|
self,
|
||||||
|
images: List,
|
||||||
|
scores: Dict[str, float],
|
||||||
|
request: Request
|
||||||
|
) -> List[ImageResponse]:
|
||||||
|
"""Convert images to search result format with similarity scores"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for image in images:
|
||||||
|
image_id = str(image.id)
|
||||||
|
similarity_score = scores.get(image_id, 0.0)
|
||||||
|
|
||||||
|
result = ImageResponse(
|
||||||
|
id=image_id,
|
||||||
|
filename=image.filename,
|
||||||
|
original_filename=image.original_filename,
|
||||||
|
file_size=image.file_size,
|
||||||
|
content_type=image.content_type,
|
||||||
|
storage_path=image.storage_path,
|
||||||
|
public_url=self._generate_api_download_url(request, image_id),
|
||||||
|
team_id=str(image.team_id),
|
||||||
|
uploader_id=str(image.uploader_id),
|
||||||
|
upload_date=image.upload_date,
|
||||||
|
description=image.description,
|
||||||
|
metadata=image.metadata,
|
||||||
|
has_embedding=image.has_embedding,
|
||||||
|
collection_id=str(image.collection_id) if image.collection_id else None,
|
||||||
|
similarity_score=similarity_score
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
176
src/services/team_service.py
Normal file
176
src/services/team_service.py
Normal file
@ -0,0 +1,176 @@
|
|||||||
|
import logging
|
||||||
|
from typing import List
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.repositories.team_repository import team_repository
|
||||||
|
from src.schemas.team import TeamCreate, TeamUpdate, TeamResponse, TeamListResponse
|
||||||
|
from src.models.team import TeamModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class TeamService:
|
||||||
|
"""Service class for handling team-related business logic"""
|
||||||
|
|
||||||
|
async def create_team(self, team_data: TeamCreate) -> TeamResponse:
|
||||||
|
"""
|
||||||
|
Create a new team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_data: The team creation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TeamResponse: The created team
|
||||||
|
"""
|
||||||
|
# Create team
|
||||||
|
team = TeamModel(
|
||||||
|
name=team_data.name,
|
||||||
|
description=team_data.description
|
||||||
|
)
|
||||||
|
|
||||||
|
created_team = await team_repository.create(team)
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
return TeamResponse(
|
||||||
|
id=str(created_team.id),
|
||||||
|
name=created_team.name,
|
||||||
|
description=created_team.description,
|
||||||
|
created_at=created_team.created_at,
|
||||||
|
updated_at=created_team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_teams(self) -> TeamListResponse:
|
||||||
|
"""
|
||||||
|
List all teams
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TeamListResponse: List of all teams
|
||||||
|
"""
|
||||||
|
# Get all teams
|
||||||
|
teams = await team_repository.get_all()
|
||||||
|
|
||||||
|
# Convert to response models
|
||||||
|
response_teams = []
|
||||||
|
for team in teams:
|
||||||
|
response_teams.append(TeamResponse(
|
||||||
|
id=str(team.id),
|
||||||
|
name=team.name,
|
||||||
|
description=team.description,
|
||||||
|
created_at=team.created_at,
|
||||||
|
updated_at=team.updated_at
|
||||||
|
))
|
||||||
|
|
||||||
|
return TeamListResponse(teams=response_teams, total=len(response_teams))
|
||||||
|
|
||||||
|
async def get_team(self, team_id: str) -> TeamResponse:
|
||||||
|
"""
|
||||||
|
Get a team by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: The team ID to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TeamResponse: The team data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If team_id is invalid
|
||||||
|
RuntimeError: If team not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(team_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid team ID")
|
||||||
|
|
||||||
|
# Get the team
|
||||||
|
team = await team_repository.get_by_id(obj_id)
|
||||||
|
if not team:
|
||||||
|
raise RuntimeError("Team not found")
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
return TeamResponse(
|
||||||
|
id=str(team.id),
|
||||||
|
name=team.name,
|
||||||
|
description=team.description,
|
||||||
|
created_at=team.created_at,
|
||||||
|
updated_at=team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_team(self, team_id: str, team_data: TeamUpdate) -> TeamResponse:
|
||||||
|
"""
|
||||||
|
Update a team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: The team ID to update
|
||||||
|
team_data: The update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TeamResponse: The updated team
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If team_id is invalid
|
||||||
|
RuntimeError: If team not found or update fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(team_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid team ID")
|
||||||
|
|
||||||
|
# Get the team
|
||||||
|
team = await team_repository.get_by_id(obj_id)
|
||||||
|
if not team:
|
||||||
|
raise RuntimeError("Team not found")
|
||||||
|
|
||||||
|
# Update the team
|
||||||
|
update_data = team_data.dict(exclude_unset=True)
|
||||||
|
if not update_data:
|
||||||
|
# No fields to update
|
||||||
|
return TeamResponse(
|
||||||
|
id=str(team.id),
|
||||||
|
name=team.name,
|
||||||
|
description=team.description,
|
||||||
|
created_at=team.created_at,
|
||||||
|
updated_at=team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_team = await team_repository.update(obj_id, update_data)
|
||||||
|
if not updated_team:
|
||||||
|
raise RuntimeError("Failed to update team")
|
||||||
|
|
||||||
|
# Convert to response model
|
||||||
|
return TeamResponse(
|
||||||
|
id=str(updated_team.id),
|
||||||
|
name=updated_team.name,
|
||||||
|
description=updated_team.description,
|
||||||
|
created_at=updated_team.created_at,
|
||||||
|
updated_at=updated_team.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_team(self, team_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete a team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: The team ID to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successfully deleted
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If team_id is invalid
|
||||||
|
RuntimeError: If team not found or deletion fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(team_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid team ID")
|
||||||
|
|
||||||
|
# Get the team
|
||||||
|
team = await team_repository.get_by_id(obj_id)
|
||||||
|
if not team:
|
||||||
|
raise RuntimeError("Team not found")
|
||||||
|
|
||||||
|
# Delete the team
|
||||||
|
success = await team_repository.delete(obj_id)
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("Failed to delete team")
|
||||||
|
|
||||||
|
return True
|
||||||
310
src/services/user_service.py
Normal file
310
src/services/user_service.py
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
import logging
|
||||||
|
from typing import Optional, List
|
||||||
|
from bson import ObjectId
|
||||||
|
|
||||||
|
from src.db.repositories.user_repository import user_repository
|
||||||
|
from src.db.repositories.team_repository import team_repository
|
||||||
|
from src.models.user import UserModel
|
||||||
|
from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class UserService:
|
||||||
|
"""Service class for handling user-related business logic"""
|
||||||
|
|
||||||
|
async def get_user_by_id(self, user_id: str) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Get user information by user ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserResponse: The user data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user_id is invalid
|
||||||
|
RuntimeError: If user not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(user_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID")
|
||||||
|
|
||||||
|
user = await user_repository.get_by_id(obj_id)
|
||||||
|
if not user:
|
||||||
|
raise RuntimeError("User not found")
|
||||||
|
|
||||||
|
return UserResponse(
|
||||||
|
id=str(user.id),
|
||||||
|
name=user.name,
|
||||||
|
email=user.email,
|
||||||
|
team_id=str(user.team_id),
|
||||||
|
is_admin=user.is_admin,
|
||||||
|
is_active=user.is_active,
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_user_by_id(self, user_id: str, user_data: UserUpdate) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Update user information by user ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to update
|
||||||
|
user_data: The update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserResponse: The updated user data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user_id is invalid
|
||||||
|
RuntimeError: If user not found or update fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(user_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID")
|
||||||
|
|
||||||
|
user = await user_repository.get_by_id(obj_id)
|
||||||
|
if not user:
|
||||||
|
raise RuntimeError("User not found")
|
||||||
|
|
||||||
|
# Update user
|
||||||
|
update_data = user_data.dict(exclude_unset=True)
|
||||||
|
if not update_data:
|
||||||
|
# No fields to update
|
||||||
|
return UserResponse(
|
||||||
|
id=str(user.id),
|
||||||
|
name=user.name,
|
||||||
|
email=user.email,
|
||||||
|
team_id=str(user.team_id),
|
||||||
|
is_admin=user.is_admin,
|
||||||
|
is_active=user.is_active,
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_user = await user_repository.update(user.id, update_data)
|
||||||
|
if not updated_user:
|
||||||
|
raise RuntimeError("Failed to update user")
|
||||||
|
|
||||||
|
return UserResponse(
|
||||||
|
id=str(updated_user.id),
|
||||||
|
name=updated_user.name,
|
||||||
|
email=updated_user.email,
|
||||||
|
team_id=str(updated_user.team_id),
|
||||||
|
is_admin=updated_user.is_admin,
|
||||||
|
is_active=updated_user.is_active,
|
||||||
|
created_at=updated_user.created_at,
|
||||||
|
updated_at=updated_user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def create_user(self, user_data: UserCreate) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Create a new user
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_data: The user creation data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserResponse: The created user
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If validation fails
|
||||||
|
RuntimeError: If team not found or user creation fails
|
||||||
|
"""
|
||||||
|
# Check if user with email already exists
|
||||||
|
existing_user = await user_repository.get_by_email(user_data.email)
|
||||||
|
if existing_user:
|
||||||
|
raise ValueError("User with this email already exists")
|
||||||
|
|
||||||
|
# Validate team exists if specified
|
||||||
|
if user_data.team_id:
|
||||||
|
team = await team_repository.get_by_id(ObjectId(user_data.team_id))
|
||||||
|
if not team:
|
||||||
|
raise RuntimeError("Team not found")
|
||||||
|
team_id = user_data.team_id
|
||||||
|
else:
|
||||||
|
raise ValueError("Team ID is required")
|
||||||
|
|
||||||
|
# Create user
|
||||||
|
user = UserModel(
|
||||||
|
name=user_data.name,
|
||||||
|
email=user_data.email,
|
||||||
|
team_id=ObjectId(team_id),
|
||||||
|
is_admin=user_data.is_admin or False,
|
||||||
|
is_active=True
|
||||||
|
)
|
||||||
|
|
||||||
|
created_user = await user_repository.create(user)
|
||||||
|
|
||||||
|
return UserResponse(
|
||||||
|
id=str(created_user.id),
|
||||||
|
name=created_user.name,
|
||||||
|
email=created_user.email,
|
||||||
|
team_id=str(created_user.team_id),
|
||||||
|
is_admin=created_user.is_admin,
|
||||||
|
is_active=created_user.is_active,
|
||||||
|
created_at=created_user.created_at,
|
||||||
|
updated_at=created_user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_users(self, team_id: Optional[str] = None) -> UserListResponse:
|
||||||
|
"""
|
||||||
|
List users, optionally filtered by team
|
||||||
|
|
||||||
|
Args:
|
||||||
|
team_id: Optional team ID to filter by
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserListResponse: List of users
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If team_id is invalid
|
||||||
|
"""
|
||||||
|
# Get users
|
||||||
|
if team_id:
|
||||||
|
try:
|
||||||
|
filter_team_id = ObjectId(team_id)
|
||||||
|
users = await user_repository.get_by_team(filter_team_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid team ID")
|
||||||
|
else:
|
||||||
|
users = await user_repository.get_all()
|
||||||
|
|
||||||
|
# Convert to response
|
||||||
|
response_users = []
|
||||||
|
for user in users:
|
||||||
|
response_users.append(UserResponse(
|
||||||
|
id=str(user.id),
|
||||||
|
name=user.name,
|
||||||
|
email=user.email,
|
||||||
|
team_id=str(user.team_id),
|
||||||
|
is_admin=user.is_admin,
|
||||||
|
is_active=user.is_active,
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at
|
||||||
|
))
|
||||||
|
|
||||||
|
return UserListResponse(users=response_users, total=len(response_users))
|
||||||
|
|
||||||
|
async def get_user(self, user_id: str) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Get user by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserResponse: The user data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user_id is invalid
|
||||||
|
RuntimeError: If user not found
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(user_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID")
|
||||||
|
|
||||||
|
user = await user_repository.get_by_id(obj_id)
|
||||||
|
if not user:
|
||||||
|
raise RuntimeError("User not found")
|
||||||
|
|
||||||
|
return UserResponse(
|
||||||
|
id=str(user.id),
|
||||||
|
name=user.name,
|
||||||
|
email=user.email,
|
||||||
|
team_id=str(user.team_id),
|
||||||
|
is_admin=user.is_admin,
|
||||||
|
is_active=user.is_active,
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_user(self, user_id: str, user_data: UserUpdate) -> UserResponse:
|
||||||
|
"""
|
||||||
|
Update user by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to update
|
||||||
|
user_data: The update data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
UserResponse: The updated user data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user_id is invalid
|
||||||
|
RuntimeError: If user not found or update fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(user_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID")
|
||||||
|
|
||||||
|
# Check if user exists
|
||||||
|
user = await user_repository.get_by_id(obj_id)
|
||||||
|
if not user:
|
||||||
|
raise RuntimeError("User not found")
|
||||||
|
|
||||||
|
# Update user
|
||||||
|
update_data = user_data.dict(exclude_unset=True)
|
||||||
|
if not update_data:
|
||||||
|
# No fields to update
|
||||||
|
return UserResponse(
|
||||||
|
id=str(user.id),
|
||||||
|
name=user.name,
|
||||||
|
email=user.email,
|
||||||
|
team_id=str(user.team_id),
|
||||||
|
is_admin=user.is_admin,
|
||||||
|
is_active=user.is_active,
|
||||||
|
created_at=user.created_at,
|
||||||
|
updated_at=user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
updated_user = await user_repository.update(obj_id, update_data)
|
||||||
|
if not updated_user:
|
||||||
|
raise RuntimeError("Failed to update user")
|
||||||
|
|
||||||
|
return UserResponse(
|
||||||
|
id=str(updated_user.id),
|
||||||
|
name=updated_user.name,
|
||||||
|
email=updated_user.email,
|
||||||
|
team_id=str(updated_user.team_id),
|
||||||
|
is_admin=updated_user.is_admin,
|
||||||
|
is_active=updated_user.is_active,
|
||||||
|
created_at=updated_user.created_at,
|
||||||
|
updated_at=updated_user.updated_at
|
||||||
|
)
|
||||||
|
|
||||||
|
async def delete_user(self, user_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Delete user by ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: The user ID to delete
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if successfully deleted
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If user_id is invalid
|
||||||
|
RuntimeError: If user not found or deletion fails
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
obj_id = ObjectId(user_id)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("Invalid user ID")
|
||||||
|
|
||||||
|
# Check if user exists
|
||||||
|
user = await user_repository.get_by_id(obj_id)
|
||||||
|
if not user:
|
||||||
|
raise RuntimeError("User not found")
|
||||||
|
|
||||||
|
# Delete user
|
||||||
|
success = await user_repository.delete(obj_id)
|
||||||
|
if not success:
|
||||||
|
raise RuntimeError("Failed to delete user")
|
||||||
|
|
||||||
|
return True
|
||||||
Loading…
x
Reference in New Issue
Block a user