commit d8aec1e6b4
parent 1010ed8d4e

    cp
@@ -1,13 +1,398 @@
 import logging
+from typing import Optional, List
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, Query, Request, Response
+from fastapi.responses import StreamingResponse
+from bson import ObjectId
+import io
 
 from src.api.v1.auth import get_current_user
+from src.db.repositories.image_repository import image_repository
+from src.services.storage import StorageService
+from src.services.image_processor import ImageProcessor
+from src.services.embedding_service import EmbeddingService
+from src.models.image import ImageModel
+from src.models.user import UserModel
+from src.schemas.image import ImageResponse, ImageListResponse, ImageCreate, ImageUpdate
+from src.utils.logging import log_request
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["Images"], prefix="/images")
 
-@router.get("")
-async def list_images(current_user = Depends(get_current_user)):
-    """List images (placeholder endpoint)"""
-    return {"message": "Images listing functionality to be implemented"}
+# Initialize services
+storage_service = StorageService()
+image_processor = ImageProcessor()
+embedding_service = EmbeddingService()
+
+
+@router.post("", response_model=ImageResponse, status_code=201)
+async def upload_image(
+    request: Request,
+    file: UploadFile = File(...),
+    description: Optional[str] = None,
+    tags: Optional[str] = None,
+    collection_id: Optional[str] = None,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Upload a new image
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "filename": file.filename},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Validate file type
+    if not file.content_type or not file.content_type.startswith('image/'):
+        raise HTTPException(status_code=400, detail="File must be an image")
+
+    # Validate file size (10MB limit)
+    max_size = 10 * 1024 * 1024  # 10MB
+    content = await file.read()
+    if len(content) > max_size:
+        raise HTTPException(status_code=400, detail="File size exceeds 10MB limit")
+
+    # Reset file pointer
+    await file.seek(0)
+
+    try:
+        # Upload to storage
+        storage_path, content_type, file_size, metadata = await storage_service.upload_file(
+            file, str(current_user.team_id)
+        )
+
+        # Process tags
+        tag_list = []
+        if tags:
+            tag_list = [tag.strip() for tag in tags.split(',') if tag.strip()]
+
+        # Create image record
+        image = ImageModel(
+            filename=file.filename,
+            original_filename=file.filename,
+            file_size=file_size,
+            content_type=content_type,
+            storage_path=storage_path,
+            team_id=current_user.team_id,
+            uploader_id=current_user.id,
+            description=description,
+            tags=tag_list,
+            metadata=metadata,
+            collection_id=ObjectId(collection_id) if collection_id else None
+        )
+
+        # Save to database
+        created_image = await image_repository.create(image)
+
+        # Start async processing for embeddings (in background)
+        try:
+            await embedding_service.process_image_async(str(created_image.id), storage_path)
+        except Exception as e:
+            logger.warning(f"Failed to start embedding processing: {e}")
+
+        # Convert to response
+        response = ImageResponse(
+            id=str(created_image.id),
+            filename=created_image.filename,
+            original_filename=created_image.original_filename,
+            file_size=created_image.file_size,
+            content_type=created_image.content_type,
+            storage_path=created_image.storage_path,
+            team_id=str(created_image.team_id),
+            uploader_id=str(created_image.uploader_id),
+            upload_date=created_image.upload_date,
+            description=created_image.description,
+            tags=created_image.tags,
+            metadata=created_image.metadata,
+            has_embedding=created_image.has_embedding,
+            collection_id=str(created_image.collection_id) if created_image.collection_id else None
+        )
+
+        return response
+
+    except Exception as e:
+        logger.error(f"Error uploading image: {e}")
+        raise HTTPException(status_code=500, detail="Failed to upload image")
+
+
+@router.get("", response_model=ImageListResponse)
+async def list_images(
+    request: Request,
+    skip: int = Query(0, ge=0),
+    limit: int = Query(50, ge=1, le=100),
+    collection_id: Optional[str] = None,
+    tags: Optional[str] = None,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    List images for the current user's team
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "skip": skip, "limit": limit},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Parse tags filter
+    tag_filter = []
+    if tags:
+        tag_filter = [tag.strip() for tag in tags.split(',') if tag.strip()]
+
+    # Get images
+    images = await image_repository.get_by_team(
+        current_user.team_id,
+        skip=skip,
+        limit=limit,
+        collection_id=ObjectId(collection_id) if collection_id else None,
+        tags=tag_filter
+    )
+
+    # Get total count
+    total = await image_repository.count_by_team(
+        current_user.team_id,
+        collection_id=ObjectId(collection_id) if collection_id else None,
+        tags=tag_filter
+    )
+
+    # Convert to response
+    response_images = []
+    for image in images:
+        response_images.append(ImageResponse(
+            id=str(image.id),
+            filename=image.filename,
+            original_filename=image.original_filename,
+            file_size=image.file_size,
+            content_type=image.content_type,
+            storage_path=image.storage_path,
+            team_id=str(image.team_id),
+            uploader_id=str(image.uploader_id),
+            upload_date=image.upload_date,
+            description=image.description,
+            tags=image.tags,
+            metadata=image.metadata,
+            has_embedding=image.has_embedding,
+            collection_id=str(image.collection_id) if image.collection_id else None
+        ))
+
+    return ImageListResponse(images=response_images, total=total, skip=skip, limit=limit)
+
+
+@router.get("/{image_id}", response_model=ImageResponse)
+async def get_image(
+    image_id: str,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Get image metadata by ID
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "image_id": image_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        obj_id = ObjectId(image_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid image ID")
+
+    # Get image
+    image = await image_repository.get_by_id(obj_id)
+    if not image:
+        raise HTTPException(status_code=404, detail="Image not found")
+
+    # Check team access
+    if image.team_id != current_user.team_id:
+        raise HTTPException(status_code=403, detail="Not authorized to access this image")
+
+    # Update last accessed
+    await image_repository.update_last_accessed(obj_id)
+
+    # Convert to response
+    response = ImageResponse(
+        id=str(image.id),
+        filename=image.filename,
+        original_filename=image.original_filename,
+        file_size=image.file_size,
+        content_type=image.content_type,
+        storage_path=image.storage_path,
+        team_id=str(image.team_id),
+        uploader_id=str(image.uploader_id),
+        upload_date=image.upload_date,
+        description=image.description,
+        tags=image.tags,
+        metadata=image.metadata,
+        has_embedding=image.has_embedding,
+        collection_id=str(image.collection_id) if image.collection_id else None,
+        last_accessed=image.last_accessed
+    )
+
+    return response
+
+
+@router.get("/{image_id}/download")
+async def download_image(
+    image_id: str,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Download image file
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "image_id": image_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        obj_id = ObjectId(image_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid image ID")
+
+    # Get image
+    image = await image_repository.get_by_id(obj_id)
+    if not image:
+        raise HTTPException(status_code=404, detail="Image not found")
+
+    # Check team access
+    if image.team_id != current_user.team_id:
+        raise HTTPException(status_code=403, detail="Not authorized to access this image")
+
+    # Get file from storage
+    file_content = storage_service.get_file(image.storage_path)
+    if not file_content:
+        raise HTTPException(status_code=404, detail="Image file not found in storage")
+
+    # Update last accessed
+    await image_repository.update_last_accessed(obj_id)
+
+    # Return file as streaming response
+    return StreamingResponse(
+        io.BytesIO(file_content),
+        media_type=image.content_type,
+        headers={"Content-Disposition": f"attachment; filename={image.original_filename}"}
+    )
+
+
+@router.put("/{image_id}", response_model=ImageResponse)
+async def update_image(
+    image_id: str,
+    image_data: ImageUpdate,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Update image metadata
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "image_id": image_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        obj_id = ObjectId(image_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid image ID")
+
+    # Get image
+    image = await image_repository.get_by_id(obj_id)
+    if not image:
+        raise HTTPException(status_code=404, detail="Image not found")
+
+    # Check team access
+    if image.team_id != current_user.team_id:
+        raise HTTPException(status_code=403, detail="Not authorized to update this image")
+
+    # Update image
+    update_data = image_data.dict(exclude_unset=True)
+    if not update_data:
+        # No fields to update
+        response = ImageResponse(
+            id=str(image.id),
+            filename=image.filename,
+            original_filename=image.original_filename,
+            file_size=image.file_size,
+            content_type=image.content_type,
+            storage_path=image.storage_path,
+            team_id=str(image.team_id),
+            uploader_id=str(image.uploader_id),
+            upload_date=image.upload_date,
+            description=image.description,
+            tags=image.tags,
+            metadata=image.metadata,
+            has_embedding=image.has_embedding,
+            collection_id=str(image.collection_id) if image.collection_id else None
+        )
+        return response
+
+    updated_image = await image_repository.update(obj_id, update_data)
+    if not updated_image:
+        raise HTTPException(status_code=500, detail="Failed to update image")
+
+    # Convert to response
+    response = ImageResponse(
+        id=str(updated_image.id),
+        filename=updated_image.filename,
+        original_filename=updated_image.original_filename,
+        file_size=updated_image.file_size,
+        content_type=updated_image.content_type,
+        storage_path=updated_image.storage_path,
+        team_id=str(updated_image.team_id),
+        uploader_id=str(updated_image.uploader_id),
+        upload_date=updated_image.upload_date,
+        description=updated_image.description,
+        tags=updated_image.tags,
+        metadata=updated_image.metadata,
+        has_embedding=updated_image.has_embedding,
+        collection_id=str(updated_image.collection_id) if updated_image.collection_id else None
+    )
+
+    return response
+
+
+@router.delete("/{image_id}", status_code=204)
+async def delete_image(
+    image_id: str,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Delete an image
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "image_id": image_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        obj_id = ObjectId(image_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid image ID")
+
+    # Get image
+    image = await image_repository.get_by_id(obj_id)
+    if not image:
+        raise HTTPException(status_code=404, detail="Image not found")
+
+    # Check team access
+    if image.team_id != current_user.team_id:
+        raise HTTPException(status_code=403, detail="Not authorized to delete this image")
+
+    # Delete from storage
+    try:
+        storage_service.delete_file(image.storage_path)
+    except Exception as e:
+        logger.warning(f"Failed to delete file from storage: {e}")
+
+    # Delete from vector database if it has embeddings
+    if image.has_embedding and image.embedding_id:
+        try:
+            await embedding_service.delete_embedding(image.embedding_id)
+        except Exception as e:
+            logger.warning(f"Failed to delete embedding: {e}")
+
+    # Delete from database
+    result = await image_repository.delete(obj_id)
+    if not result:
+        raise HTTPException(status_code=500, detail="Failed to delete image")
+
+    return None
@@ -1,16 +1,362 @@
 import logging
+from typing import Optional, List
 
-from fastapi import APIRouter, Depends, Query
+from fastapi import APIRouter, Depends, Query, Request, HTTPException
 
 from src.api.v1.auth import get_current_user
+from src.db.repositories.image_repository import image_repository
+from src.services.vector_store import VectorStoreService
+from src.services.embedding_service import EmbeddingService
+from src.models.user import UserModel
+from src.schemas.image import ImageResponse
+from src.schemas.search import SearchResponse, SearchRequest
+from src.utils.logging import log_request
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["Search"], prefix="/search")
 
-@router.get("")
+# Initialize services
+vector_store_service = VectorStoreService()
+embedding_service = EmbeddingService()
+
+
+@router.get("", response_model=SearchResponse)
 async def search_images(
+    request: Request,
     q: str = Query(..., description="Search query"),
-    current_user = Depends(get_current_user)
+    limit: int = Query(10, ge=1, le=50, description="Number of results to return"),
+    threshold: float = Query(0.7, ge=0.0, le=1.0, description="Similarity threshold"),
+    collection_id: Optional[str] = Query(None, description="Filter by collection ID"),
+    tags: Optional[str] = Query(None, description="Filter by tags (comma-separated)"),
+    current_user: UserModel = Depends(get_current_user)
 ):
-    """Search for images (placeholder endpoint)"""
-    return {"message": "Search functionality to be implemented", "query": q}
+    """
+    Search for images using semantic similarity
+    """
+    log_request(
+        {
+            "path": request.url.path,
+            "method": request.method,
+            "query": q,
+            "limit": limit,
+            "threshold": threshold
+        },
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        # Generate embedding for the search query
+        query_embedding = await embedding_service.generate_text_embedding(q)
+        if not query_embedding:
+            raise HTTPException(status_code=400, detail="Failed to generate search embedding")
+
+        # Search in vector database
+        search_results = await vector_store_service.search_similar(
+            query_embedding,
+            limit=limit,
+            threshold=threshold,
+            team_id=str(current_user.team_id)
+        )
+
+        if not search_results:
+            return SearchResponse(
+                query=q,
+                results=[],
+                total=0,
+                limit=limit,
+                threshold=threshold
+            )
+
+        # Get image IDs and scores from search results
+        image_ids = [result['id'] for result in search_results]
+        scores = {result['id']: result['score'] for result in search_results}
+
+        # Get image metadata from database
+        images = await image_repository.get_by_ids(image_ids)
+
+        # Filter by collection and tags if specified
+        filtered_images = []
+        for image in images:
+            # Check collection filter
+            if collection_id and str(image.collection_id) != collection_id:
+                continue
+
+            # Check tags filter
+            if tags:
+                tag_filter = [tag.strip() for tag in tags.split(',') if tag.strip()]
+                if not any(tag in image.tags for tag in tag_filter):
+                    continue
+
+            filtered_images.append(image)
+
+        # Convert to response format with similarity scores
+        results = []
+        for image in filtered_images:
+            image_id = str(image.id)
+            similarity_score = scores.get(image_id, 0.0)
+
+            result = ImageResponse(
+                id=image_id,
+                filename=image.filename,
+                original_filename=image.original_filename,
+                file_size=image.file_size,
+                content_type=image.content_type,
+                storage_path=image.storage_path,
+                team_id=str(image.team_id),
+                uploader_id=str(image.uploader_id),
+                upload_date=image.upload_date,
+                description=image.description,
+                tags=image.tags,
+                metadata=image.metadata,
+                has_embedding=image.has_embedding,
+                collection_id=str(image.collection_id) if image.collection_id else None,
+                similarity_score=similarity_score
+            )
+            results.append(result)
+
+        # Sort by similarity score (highest first)
+        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
+
+        return SearchResponse(
+            query=q,
+            results=results,
+            total=len(results),
+            limit=limit,
+            threshold=threshold
+        )
+
+    except Exception as e:
+        logger.error(f"Error searching images: {e}")
+        raise HTTPException(status_code=500, detail="Search failed")
+
+
+@router.post("", response_model=SearchResponse)
+async def search_images_advanced(
+    search_request: SearchRequest,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Advanced search for images with more options
+    """
+    log_request(
+        {
+            "path": request.url.path,
+            "method": request.method,
+            "search_request": search_request.dict()
+        },
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        # Generate embedding for the search query
+        query_embedding = await embedding_service.generate_text_embedding(search_request.query)
+        if not query_embedding:
+            raise HTTPException(status_code=400, detail="Failed to generate search embedding")
+
+        # Search in vector database
+        search_results = await vector_store_service.search_similar(
+            query_embedding,
+            limit=search_request.limit,
+            threshold=search_request.threshold,
+            team_id=str(current_user.team_id)
+        )
+
+        if not search_results:
+            return SearchResponse(
+                query=search_request.query,
+                results=[],
+                total=0,
+                limit=search_request.limit,
+                threshold=search_request.threshold
+            )
+
+        # Get image IDs and scores from search results
+        image_ids = [result['id'] for result in search_results]
+        scores = {result['id']: result['score'] for result in search_results}
+
+        # Get image metadata from database
+        images = await image_repository.get_by_ids(image_ids)
+
+        # Apply filters
+        filtered_images = []
+        for image in images:
+            # Check collection filter
+            if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
+                continue
+
+            # Check tags filter
+            if search_request.tags:
+                if not any(tag in image.tags for tag in search_request.tags):
+                    continue
+
+            # Check date range filter
+            if search_request.date_from and image.upload_date < search_request.date_from:
+                continue
+
+            if search_request.date_to and image.upload_date > search_request.date_to:
+                continue
+
+            # Check uploader filter
+            if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
+                continue
+
+            filtered_images.append(image)
+
+        # Convert to response format with similarity scores
+        results = []
+        for image in filtered_images:
+            image_id = str(image.id)
+            similarity_score = scores.get(image_id, 0.0)
+
+            result = ImageResponse(
+                id=image_id,
+                filename=image.filename,
+                original_filename=image.original_filename,
+                file_size=image.file_size,
+                content_type=image.content_type,
+                storage_path=image.storage_path,
+                team_id=str(image.team_id),
+                uploader_id=str(image.uploader_id),
+                upload_date=image.upload_date,
+                description=image.description,
+                tags=image.tags,
+                metadata=image.metadata,
+                has_embedding=image.has_embedding,
+                collection_id=str(image.collection_id) if image.collection_id else None,
+                similarity_score=similarity_score
+            )
+            results.append(result)
+
+        # Sort by similarity score (highest first)
+        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
+
+        return SearchResponse(
+            query=search_request.query,
+            results=results,
+            total=len(results),
+            limit=search_request.limit,
+            threshold=search_request.threshold
+        )
+
+    except Exception as e:
+        logger.error(f"Error in advanced search: {e}")
+        raise HTTPException(status_code=500, detail="Advanced search failed")
+
+
+@router.get("/similar/{image_id}", response_model=SearchResponse)
+async def find_similar_images(
+    image_id: str,
+    request: Request,
+    limit: int = Query(10, ge=1, le=50, description="Number of similar images to return"),
+    threshold: float = Query(0.7, ge=0.0, le=1.0, description="Similarity threshold"),
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Find images similar to a given image
+    """
+    log_request(
+        {
+            "path": request.url.path,
+            "method": request.method,
+            "image_id": image_id,
+            "limit": limit,
+            "threshold": threshold
+        },
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        from bson import ObjectId
+
+        # Get the reference image
+        try:
+            obj_id = ObjectId(image_id)
+        except:
+            raise HTTPException(status_code=400, detail="Invalid image ID")
+
+        reference_image = await image_repository.get_by_id(obj_id)
+        if not reference_image:
+            raise HTTPException(status_code=404, detail="Reference image not found")
+
+        # Check team access
+        if reference_image.team_id != current_user.team_id:
+            raise HTTPException(status_code=403, detail="Not authorized to access this image")
+
+        # Check if the image has embeddings
+        if not reference_image.has_embedding or not reference_image.embedding_id:
+            raise HTTPException(status_code=400, detail="Reference image does not have embeddings")
+
+        # Get the embedding for the reference image
+        reference_embedding = await vector_store_service.get_embedding(reference_image.embedding_id)
+        if not reference_embedding:
+            raise HTTPException(status_code=400, detail="Failed to get reference image embedding")
+
+        # Search for similar images
+        search_results = await vector_store_service.search_similar(
+            reference_embedding,
+            limit=limit + 1,  # +1 to account for the reference image itself
+            threshold=threshold,
+            team_id=str(current_user.team_id)
+        )
+
+        # Remove the reference image from results
+        search_results = [result for result in search_results if result['id'] != image_id][:limit]
+
+        if not search_results:
+            return SearchResponse(
+                query=f"Similar to image {image_id}",
+                results=[],
+                total=0,
+                limit=limit,
+                threshold=threshold
+            )
+
+        # Get image IDs and scores from search results
+        image_ids = [result['id'] for result in search_results]
+        scores = {result['id']: result['score'] for result in search_results}
+
+        # Get image metadata from database
+        images = await image_repository.get_by_ids(image_ids)
+
+        # Convert to response format with similarity scores
+        results = []
+        for image in images:
+            image_id_str = str(image.id)
+            similarity_score = scores.get(image_id_str, 0.0)
+
+            result = ImageResponse(
+                id=image_id_str,
+                filename=image.filename,
+                original_filename=image.original_filename,
+                file_size=image.file_size,
+                content_type=image.content_type,
+                storage_path=image.storage_path,
+                team_id=str(image.team_id),
+                uploader_id=str(image.uploader_id),
+                upload_date=image.upload_date,
+                description=image.description,
+                tags=image.tags,
+                metadata=image.metadata,
+                has_embedding=image.has_embedding,
+                collection_id=str(image.collection_id) if image.collection_id else None,
+                similarity_score=similarity_score
+            )
+            results.append(result)
+
+        # Sort by similarity score (highest first)
+        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
+
+        return SearchResponse(
+            query=f"Similar to image {image_id}",
+            results=results,
+            total=len(results),
+            limit=limit,
+            threshold=threshold
+        )
+
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.error(f"Error finding similar images: {e}")
+        raise HTTPException(status_code=500, detail="Similar image search failed")
@@ -1,13 +1,342 @@
 import logging
+from typing import Optional
 
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException, Request
+from bson import ObjectId
 
 from src.api.v1.auth import get_current_user
+from src.db.repositories.user_repository import user_repository
+from src.db.repositories.team_repository import team_repository
+from src.models.user import UserModel
+from src.schemas.user import UserResponse, UserListResponse, UserCreate, UserUpdate
+from src.utils.logging import log_request
 
 logger = logging.getLogger(__name__)
 
 router = APIRouter(tags=["Users"], prefix="/users")
 
-@router.get("/me")
-async def read_users_me(current_user = Depends(get_current_user)):
+@router.get("/me", response_model=UserResponse)
+async def read_users_me(
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
     """Get current user information"""
-    return current_user
+    log_request(
+        {"path": request.url.path, "method": request.method},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    response = UserResponse(
+        id=str(current_user.id),
+        name=current_user.name,
+        email=current_user.email,
+        team_id=str(current_user.team_id),
+        is_admin=current_user.is_admin,
+        is_active=current_user.is_active,
+        created_at=current_user.created_at,
+        updated_at=current_user.updated_at
+    )
+
+    return response
+
+
+@router.put("/me", response_model=UserResponse)
+async def update_current_user(
+    user_data: UserUpdate,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """Update current user information"""
+    log_request(
+        {"path": request.url.path, "method": request.method, "user_data": user_data.dict()},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Update user
+    update_data = user_data.dict(exclude_unset=True)
+    if not update_data:
+        # No fields to update
+        response = UserResponse(
+            id=str(current_user.id),
+            name=current_user.name,
+            email=current_user.email,
+            team_id=str(current_user.team_id),
+            is_admin=current_user.is_admin,
+            is_active=current_user.is_active,
+            created_at=current_user.created_at,
+            updated_at=current_user.updated_at
+        )
+        return response
+
+    updated_user = await user_repository.update(current_user.id, update_data)
+    if not updated_user:
+        raise HTTPException(status_code=500, detail="Failed to update user")
+
+    response = UserResponse(
+        id=str(updated_user.id),
+        name=updated_user.name,
+        email=updated_user.email,
+        team_id=str(updated_user.team_id),
+        is_admin=updated_user.is_admin,
+        is_active=updated_user.is_active,
+        created_at=updated_user.created_at,
+        updated_at=updated_user.updated_at
+    )
+
+    return response
+
+
+@router.post("", response_model=UserResponse, status_code=201)
+async def create_user(
+    user_data: UserCreate,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Create a new user
+
+    This endpoint requires admin privileges
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "user_data": user_data.dict()},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Only admins can create users
+    if not current_user.is_admin:
+        raise HTTPException(status_code=403, detail="Only admins can create users")
+
+    # Check if user with email already exists
+    existing_user = await user_repository.get_by_email(user_data.email)
+    if existing_user:
+        raise HTTPException(status_code=400, detail="User with this email already exists")
+
+    # Validate team exists if specified
+    team_id = user_data.team_id or current_user.team_id
+    team = await team_repository.get_by_id(ObjectId(team_id))
+    if not team:
+        raise HTTPException(status_code=400, detail="Team not found")
+
+    # Create user
+    user = UserModel(
+        name=user_data.name,
+        email=user_data.email,
+        team_id=ObjectId(team_id),
+        is_admin=user_data.is_admin or False,
+        is_active=True
+    )
+
+    created_user = await user_repository.create(user)
+
+    response = UserResponse(
+        id=str(created_user.id),
+        name=created_user.name,
+        email=created_user.email,
+        team_id=str(created_user.team_id),
+        is_admin=created_user.is_admin,
+        is_active=created_user.is_active,
+        created_at=created_user.created_at,
+        updated_at=created_user.updated_at
+    )
+
+    return response
+
+
+@router.get("", response_model=UserListResponse)
+async def list_users(
+    request: Request,
+    team_id: Optional[str] = None,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    List users
+
+    Admins can list all users or filter by team.
+    Non-admins can only list users from their own team.
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "team_id": team_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Determine which team to filter by
+    if current_user.is_admin:
+        # Admins can specify team_id or get all users
+        filter_team_id = ObjectId(team_id) if team_id else None
+    else:
+        # Non-admins can only see their own team
+        filter_team_id = current_user.team_id
+
+    # Get users
+    if filter_team_id:
+        users = await user_repository.get_by_team(filter_team_id)
+    else:
+        users = await user_repository.get_all()
+
+    # Convert to response
+    response_users = []
+    for user in users:
+        response_users.append(UserResponse(
+            id=str(user.id),
+            name=user.name,
+            email=user.email,
+            team_id=str(user.team_id),
+            is_admin=user.is_admin,
+            is_active=user.is_active,
+            created_at=user.created_at,
+            updated_at=user.updated_at
+        ))
+
+    return UserListResponse(users=response_users, total=len(response_users))
+
+
+@router.get("/{user_id}", response_model=UserResponse)
+async def get_user(
+    user_id: str,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Get user by ID
+
+    Admins can get any user.
+    Non-admins can only get users from their own team.
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "user_id": user_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    try:
+        obj_id = ObjectId(user_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid user ID")
+
+    # Get user
+    user = await user_repository.get_by_id(obj_id)
+    if not user:
+        raise HTTPException(status_code=404, detail="User not found")
+
+    # Check access permissions
+    if not current_user.is_admin and user.team_id != current_user.team_id:
+        raise HTTPException(status_code=403, detail="Not authorized to access this user")
+
+    response = UserResponse(
+        id=str(user.id),
+        name=user.name,
+        email=user.email,
+        team_id=str(user.team_id),
+        is_admin=user.is_admin,
+        is_active=user.is_active,
+        created_at=user.created_at,
+        updated_at=user.updated_at
+    )
+
+    return response
+
+
+@router.put("/{user_id}", response_model=UserResponse)
+async def update_user(
+    user_id: str,
+    user_data: UserUpdate,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Update user
+
+    This endpoint requires admin privileges
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "user_id": user_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Only admins can update other users
+    if not current_user.is_admin:
+        raise HTTPException(status_code=403, detail="Only admins can update users")
+
+    try:
+        obj_id = ObjectId(user_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid user ID")
+
+    # Get user
+    user = await user_repository.get_by_id(obj_id)
+    if not user:
+        raise HTTPException(status_code=404, detail="User not found")
+
+    # Update user
+    update_data = user_data.dict(exclude_unset=True)
+    if not update_data:
+        # No fields to update
+        response = UserResponse(
+            id=str(user.id),
+            name=user.name,
+            email=user.email,
+            team_id=str(user.team_id),
+            is_admin=user.is_admin,
+            is_active=user.is_active,
+            created_at=user.created_at,
+            updated_at=user.updated_at
+        )
+        return response
+
+    updated_user = await user_repository.update(obj_id, update_data)
+    if not updated_user:
+        raise HTTPException(status_code=500, detail="Failed to update user")
+
+    response = UserResponse(
+        id=str(updated_user.id),
+        name=updated_user.name,
+        email=updated_user.email,
+        team_id=str(updated_user.team_id),
+        is_admin=updated_user.is_admin,
+        is_active=updated_user.is_active,
+        created_at=updated_user.created_at,
+        updated_at=updated_user.updated_at
+    )
+
+    return response
+
+
+@router.delete("/{user_id}", status_code=204)
+async def delete_user(
+    user_id: str,
+    request: Request,
+    current_user: UserModel = Depends(get_current_user)
+):
+    """
+    Delete (deactivate) user
+
+    This endpoint requires admin privileges
+    """
+    log_request(
+        {"path": request.url.path, "method": request.method, "user_id": user_id},
+        user_id=str(current_user.id),
+        team_id=str(current_user.team_id)
+    )
+
+    # Only admins can delete users
+    if not current_user.is_admin:
+        raise HTTPException(status_code=403, detail="Only admins can delete users")
+
+    try:
+        obj_id = ObjectId(user_id)
+    except:
+        raise HTTPException(status_code=400, detail="Invalid user ID")
+
+    # Get user
+    user = await user_repository.get_by_id(obj_id)
+    if not user:
+        raise HTTPException(status_code=404, detail="User not found")
+
+    # Prevent self-deletion
+    if obj_id == current_user.id:
+        raise HTTPException(status_code=400, detail="Cannot delete yourself")
+
+    # Deactivate user instead of hard delete
+    result = await user_repository.update(obj_id, {"is_active": False})
+    if not result:
+        raise HTTPException(status_code=500, detail="Failed to delete user")
+
+    return None
@@ -5,7 +5,7 @@ from pydantic import AnyHttpUrl, field_validator
 class Settings(BaseSettings):
     # Project settings
-    PROJECT_NAME: str = "Image Management API"
+    PROJECT_NAME: str = "SEREACT - Secure Image Management API"
     API_V1_STR: str = "/api/v1"
 
     # Environment
@@ -29,6 +29,13 @@ class Settings(BaseSettings):
     GCS_BUCKET_NAME: str = os.getenv("GCS_BUCKET_NAME", "image-mgmt-bucket")
     GCS_CREDENTIALS_FILE: str = os.getenv("GCS_CREDENTIALS_FILE", "credentials.json")
 
+    # Google Pub/Sub settings
+    PUBSUB_TOPIC: str = os.getenv("PUBSUB_TOPIC", "image-processing-topic")
+    PUBSUB_SUBSCRIPTION: str = os.getenv("PUBSUB_SUBSCRIPTION", "image-processing-subscription")
+
+    # Google Cloud Vision API
+    VISION_API_ENABLED: bool = os.getenv("VISION_API_ENABLED", "true").lower() == "true"
+
     # Security settings
     API_KEY_SECRET: str = os.getenv("API_KEY_SECRET", "super-secret-key-for-development-only")
     API_KEY_EXPIRY_DAYS: int = int(os.getenv("API_KEY_EXPIRY_DAYS", "365"))
@@ -55,21 +55,68 @@ class ImageRepository:
             logger.error(f"Error getting image by ID: {e}")
             raise
 
-    async def get_by_team(self, team_id: ObjectId, limit: int = 100, skip: int = 0) -> List[ImageModel]:
+    async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
         """
-        Get images by team ID with pagination
+        Get images by list of IDs
+
+        Args:
+            image_ids: List of image ID strings
+
+        Returns:
+            List of images
+        """
+        try:
+            # Convert string IDs to ObjectIds
+            object_ids = []
+            for image_id in image_ids:
+                try:
+                    object_ids.append(ObjectId(image_id))
+                except:
+                    logger.warning(f"Invalid ObjectId: {image_id}")
+                    continue
+
+            images = []
+            cursor = self.collection.find({"_id": {"$in": object_ids}})
+            async for document in cursor:
+                images.append(ImageModel(**document))
+            return images
+        except Exception as e:
+            logger.error(f"Error getting images by IDs: {e}")
+            raise
+
+    async def get_by_team(
+        self,
+        team_id: ObjectId,
+        limit: int = 100,
+        skip: int = 0,
+        collection_id: Optional[ObjectId] = None,
+        tags: Optional[List[str]] = None
+    ) -> List[ImageModel]:
+        """
+        Get images by team ID with pagination and filters
 
         Args:
             team_id: Team ID
             limit: Max number of results
             skip: Number of records to skip
+            collection_id: Optional collection filter
+            tags: Optional tags filter
+
         Returns:
             List of images for the team
         """
         try:
+            # Build query
+            query = {"team_id": team_id}
+
+            if collection_id:
+                query["collection_id"] = collection_id
+
+            if tags:
+                query["tags"] = {"$in": tags}
+
             images = []
-            cursor = self.collection.find({"team_id": team_id}).sort("upload_date", -1).skip(skip).limit(limit)
+            cursor = self.collection.find(query).sort("upload_date", -1).skip(skip).limit(limit)
             async for document in cursor:
                 images.append(ImageModel(**document))
             return images
@@ -77,18 +124,34 @@ class ImageRepository:
             logger.error(f"Error getting images by team: {e}")
             raise
 
-    async def count_by_team(self, team_id: ObjectId) -> int:
+    async def count_by_team(
+        self,
+        team_id: ObjectId,
+        collection_id: Optional[ObjectId] = None,
+        tags: Optional[List[str]] = None
+    ) -> int:
         """
-        Count images by team ID
+        Count images by team ID with filters
 
         Args:
             team_id: Team ID
+            collection_id: Optional collection filter
+            tags: Optional tags filter
+
         Returns:
             Number of images for the team
         """
         try:
-            return await self.collection.count_documents({"team_id": team_id})
+            # Build query
+            query = {"team_id": team_id}
+
+            if collection_id:
+                query["collection_id"] = collection_id
+
+            if tags:
+                query["tags"] = {"$in": tags}
+
+            return await self.collection.count_documents(query)
         except Exception as e:
             logger.error(f"Error counting images by team: {e}")
             raise
@@ -21,6 +21,7 @@ class ImageModel(BaseModel):
     description: Optional[str] = None
     tags: List[str] = []
     metadata: Dict[str, Any] = {}
+    collection_id: Optional[PyObjectId] = None
 
     # Fields for image understanding and semantic search
     embedding_id: Optional[str] = None
@@ -7,6 +7,10 @@ class ImageBase(BaseModel):
     description: Optional[str] = Field(None, description="Image description", max_length=500)
     tags: List[str] = Field(default=[], description="Image tags")
 
+
+class ImageCreate(ImageBase):
+    """Schema for creating an image"""
+    collection_id: Optional[str] = Field(None, description="Collection ID to organize images")
+
 class ImageUpload(ImageBase):
     """Schema for uploading an image"""
     # Note: The file itself is handled by FastAPI's UploadFile
@@ -17,6 +21,7 @@ class ImageUpdate(BaseModel):
     description: Optional[str] = Field(None, description="Image description", max_length=500)
     tags: Optional[List[str]] = Field(None, description="Image tags")
     metadata: Optional[Dict[str, Any]] = Field(None, description="Image metadata")
+    collection_id: Optional[str] = Field(None, description="Collection ID to organize images")
 
 class ImageResponse(ImageBase):
     """Schema for image response"""
@@ -25,6 +30,7 @@ class ImageResponse(ImageBase):
     original_filename: str
     file_size: int
     content_type: str
+    storage_path: str
     public_url: Optional[HttpUrl] = None
     team_id: str
     uploader_id: str
@@ -32,6 +38,8 @@ class ImageResponse(ImageBase):
     last_accessed: Optional[datetime] = None
     metadata: Dict[str, Any] = Field(default={})
     has_embedding: bool = False
+    collection_id: Optional[str] = None
+    similarity_score: Optional[float] = Field(None, description="Similarity score for search results")
 
     model_config: ClassVar[dict] = {
         "from_attributes": True,
@@ -42,7 +50,8 @@ class ImageResponse(ImageBase):
             "original_filename": "sunset.jpg",
             "file_size": 1024000,
             "content_type": "image/jpeg",
-            "public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
+            "storage_path": "team123/1234567890abcdef.jpg",
+            "public_url": "https://storage.googleapis.com/bucket/team123/1234567890abcdef.jpg",
             "team_id": "507f1f77bcf86cd799439022",
             "uploader_id": "507f1f77bcf86cd799439033",
             "upload_date": "2023-10-20T10:00:00",
@@ -54,7 +63,9 @@ class ImageResponse(ImageBase):
                 "height": 1080,
                 "location": "Rocky Mountains"
             },
-            "has_embedding": True
+            "has_embedding": True,
+            "collection_id": "507f1f77bcf86cd799439044",
+            "similarity_score": 0.95
         }
     }
 }
@@ -63,9 +74,8 @@ class ImageListResponse(BaseModel):
     """Schema for image list response"""
     images: List[ImageResponse]
     total: int
-    page: int
-    page_size: int
-    total_pages: int
+    skip: int = Field(0, description="Number of items skipped")
+    limit: int = Field(50, description="Number of items per page")
 
     model_config: ClassVar[dict] = {
         "json_schema_extra": {
@@ -77,7 +87,8 @@ class ImageListResponse(BaseModel):
                 "original_filename": "sunset.jpg",
                 "file_size": 1024000,
                 "content_type": "image/jpeg",
-                "public_url": "https://storage.googleapis.com/bucket/1234567890abcdef.jpg",
+                "storage_path": "team123/1234567890abcdef.jpg",
+                "public_url": "https://storage.googleapis.com/bucket/team123/1234567890abcdef.jpg",
                 "team_id": "507f1f77bcf86cd799439022",
                 "uploader_id": "507f1f77bcf86cd799439033",
                 "upload_date": "2023-10-20T10:00:00",
@@ -89,13 +100,13 @@ class ImageListResponse(BaseModel):
                     "height": 1080,
                     "location": "Rocky Mountains"
                 },
-                "has_embedding": True
+                "has_embedding": True,
+                "collection_id": "507f1f77bcf86cd799439044"
                 }
             ],
             "total": 1,
-            "page": 1,
-            "page_size": 10,
-            "total_pages": 1
+            "skip": 0,
+            "limit": 50
         }
     }
 }
|||||||
75
src/schemas/search.py
Normal file
75
src/schemas/search.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
from typing import List, Optional, ClassVar
|
||||||
|
from datetime import datetime
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from src.schemas.image import ImageResponse
|
||||||
|
|
||||||
|
class SearchRequest(BaseModel):
|
||||||
|
"""Schema for advanced search request"""
|
||||||
|
query: str = Field(..., description="Search query", min_length=1)
|
||||||
|
limit: int = Field(10, description="Maximum number of results", ge=1, le=50)
|
||||||
|
threshold: float = Field(0.7, description="Similarity threshold", ge=0.0, le=1.0)
|
||||||
|
collection_id: Optional[str] = Field(None, description="Filter by collection ID")
|
||||||
|
tags: Optional[List[str]] = Field(None, description="Filter by tags")
|
||||||
|
date_from: Optional[datetime] = Field(None, description="Filter images uploaded after this date")
|
||||||
|
date_to: Optional[datetime] = Field(None, description="Filter images uploaded before this date")
|
||||||
|
uploader_id: Optional[str] = Field(None, description="Filter by uploader ID")
|
||||||
|
|
||||||
|
model_config: ClassVar[dict] = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"query": "mountain sunset",
|
||||||
|
"limit": 10,
|
||||||
|
"threshold": 0.7,
|
||||||
|
"collection_id": "507f1f77bcf86cd799439044",
|
||||||
|
"tags": ["nature", "landscape"],
|
||||||
|
"date_from": "2023-01-01T00:00:00",
|
||||||
|
"date_to": "2023-12-31T23:59:59",
|
||||||
|
"uploader_id": "507f1f77bcf86cd799439033"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class SearchResponse(BaseModel):
|
||||||
|
"""Schema for search response"""
|
||||||
|
query: str
|
||||||
|
results: List[ImageResponse]
|
||||||
|
total: int
|
||||||
|
limit: int
|
||||||
|
threshold: float
|
||||||
|
|
||||||
|
model_config: ClassVar[dict] = {
|
||||||
|
"json_schema_extra": {
|
||||||
|
"example": {
|
||||||
|
"query": "mountain sunset",
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"id": "507f1f77bcf86cd799439011",
|
||||||
|
"filename": "1234567890abcdef.jpg",
|
||||||
|
"original_filename": "sunset.jpg",
|
||||||
|
"file_size": 1024000,
|
||||||
|
"content_type": "image/jpeg",
|
||||||
|
"storage_path": "team123/1234567890abcdef.jpg",
|
||||||
|
"public_url": "https://storage.googleapis.com/bucket/team123/1234567890abcdef.jpg",
|
||||||
|
"team_id": "507f1f77bcf86cd799439022",
|
||||||
|
"uploader_id": "507f1f77bcf86cd799439033",
|
||||||
|
"upload_date": "2023-10-20T10:00:00",
|
||||||
|
"last_accessed": "2023-10-21T10:00:00",
|
||||||
|
"description": "Beautiful sunset over the mountains",
|
||||||
|
"tags": ["sunset", "mountains", "nature"],
|
||||||
|
"metadata": {
|
||||||
|
"width": 1920,
|
||||||
|
"height": 1080,
|
||||||
|
"location": "Rocky Mountains"
|
||||||
|
},
|
||||||
|
"has_embedding": True,
|
||||||
|
"collection_id": "507f1f77bcf86cd799439044",
|
||||||
|
"similarity_score": 0.95
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"total": 1,
|
||||||
|
"limit": 10,
|
||||||
|
"threshold": 0.7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
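A quick sanity check of how these schemas behave, assuming they back the request/response bodies of a search endpoint (the endpoint wiring is not part of this file); this is an illustrative sketch using Pydantic v2:

# Illustrative only -- exercises the defaults and optional filters of SearchRequest.
from src.schemas.search import SearchRequest

req = SearchRequest(query="mountain sunset", limit=5, tags=["nature"])
print(req.model_dump(exclude_none=True))
# {'query': 'mountain sunset', 'limit': 5, 'threshold': 0.7, 'tags': ['nature']}

Constraint violations raise automatically: SearchRequest(query="", limit=100) fails both the min_length=1 and the le=50 checks.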
@@ -58,7 +58,7 @@ class EmbeddingService:
             logger.error(f"Error generating image embedding: {e}")
             raise
 
-    def generate_text_embedding(self, text: str) -> List[float]:
+    async def generate_text_embedding(self, text: str) -> Optional[List[float]]:
         """
         Generate embedding for a text query
 
@@ -66,7 +66,7 @@ class EmbeddingService:
             text: Text query
 
         Returns:
-            Text embedding as a list of floats
+            Text embedding as a list of floats, or None if failed
         """
         try:
             self._load_model()
@@ -79,7 +79,55 @@ class EmbeddingService:
             return embedding.tolist()
         except Exception as e:
             logger.error(f"Error generating text embedding: {e}")
-            raise
+            return None
 
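Since generate_text_embedding is now a coroutine that signals failure by returning None instead of raising, call sites have to await it and branch on the result. A hedged sketch of a caller; the embedding_service instance and the HTTP error handling are assumptions, not part of this hunk:

# Hypothetical call site in a search endpoint.
embedding = await embedding_service.generate_text_embedding(search_request.query)
if embedding is None:
    raise HTTPException(status_code=500, detail="Failed to generate query embedding")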
+    async def process_image_async(self, image_id: str, storage_path: str) -> bool:
+        """
+        Process image asynchronously to generate embeddings
+
+        Args:
+            image_id: Image ID
+            storage_path: Path to image in storage
+
+        Returns:
+            True if processing started successfully
+        """
+        try:
+            # In a real implementation, this would:
+            # 1. Publish a message to Pub/Sub queue
+            # 2. Cloud Function would pick up the message
+            # 3. Generate embeddings using Cloud Vision API
+            # 4. Store embeddings in Pinecone
+            # 5. Update image record with embedding info
+
+            logger.info(f"Starting async processing for image {image_id} at {storage_path}")
+
+            # For now, just log that processing would start
+            # In production, this would integrate with Google Pub/Sub
+            return True
+        except Exception as e:
+            logger.error(f"Error starting async image processing: {e}")
+            return False
+
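The numbered comments above describe the intended pipeline, but the method currently only logs. A minimal sketch of step 1 (publishing the processing request) with the google-cloud-pubsub client is below; the project ID, topic name, and payload shape are assumptions, not part of this commit.

# Hypothetical sketch of step 1; project and topic names are placeholders.
import json
from google.cloud import pubsub_v1

publisher = pubsub_v1.PublisherClient()
topic_path = publisher.topic_path("my-gcp-project", "image-processing")

def request_processing(image_id: str, storage_path: str) -> str:
    payload = json.dumps({"image_id": image_id, "storage_path": storage_path}).encode("utf-8")
    future = publisher.publish(topic_path, payload)
    return future.result()  # blocks until the broker acknowledges; returns the message ID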
+    async def delete_embedding(self, embedding_id: str) -> bool:
+        """
+        Delete embedding from vector database
+
+        Args:
+            embedding_id: Embedding ID in vector database
+
+        Returns:
+            True if deletion was successful
+        """
+        try:
+            # In a real implementation, this would delete from Pinecone
+            logger.info(f"Deleting embedding {embedding_id} from vector database")
+
+            # Placeholder - in production this would call Pinecone API
+            return True
+        except Exception as e:
+            logger.error(f"Error deleting embedding: {e}")
+            return False
+
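For reference, the Pinecone call this placeholder alludes to is a single method on an index handle. A sketch assuming the current pinecone Python SDK; the index name and where the client gets constructed are assumptions.

# Hypothetical deletion; "images" is an assumed index name, key elided.
from pinecone import Pinecone

pc = Pinecone(api_key="...")
index = pc.Index("images")
index.delete(ids=["example-embedding-id"])  # removes the vector with that ID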
     def calculate_similarity(self, embedding1: List[float], embedding2: List[float]) -> float:
         """
@@ -102,6 +102,91 @@ class VectorStoreService:
             logger.error(f"Error storing embedding: {e}")
             return None
 
+    async def search_similar(
+        self,
+        query_embedding: List[float],
+        limit: int = 10,
+        threshold: float = 0.7,
+        team_id: str = None
+    ) -> List[Dict[str, Any]]:
+        """
+        Search for similar embeddings
+
+        Args:
+            query_embedding: Query embedding vector
+            limit: Maximum number of results
+            threshold: Similarity threshold (0-1)
+            team_id: Optional team filter
+
+        Returns:
+            List of results with id and score
+        """
+        self.initialize()
+
+        if not self.initialized:
+            logger.warning("Pinecone not initialized, cannot search")
+            return []
+
+        try:
+            # Create filter for team_id if provided
+            filter_dict = None
+            if team_id:
+                filter_dict = {
+                    "team_id": {"$eq": team_id}
+                }
+
+            # Query the index
+            results = self.index.query(
+                vector=query_embedding,
+                filter=filter_dict,
+                top_k=limit,
+                include_metadata=True
+            )
+
+            # Format the results and apply threshold
+            formatted_results = []
+            for match in results.matches:
+                if match.score >= threshold:
+                    formatted_results.append({
+                        "id": match.metadata.get("image_id", match.id),
+                        "score": match.score,
+                        "metadata": match.metadata
+                    })
+
+            return formatted_results
+        except Exception as e:
+            logger.error(f"Error searching similar embeddings: {e}")
+            return []
+
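search_similar assumes each stored vector carries image_id and team_id in its metadata: the $eq filter matches on team_id, and the result id falls back from metadata["image_id"] to the raw vector id. The store side is not shown in this hunk; a hedged sketch of an upsert that would satisfy those expectations (variable names are placeholders):

# Hypothetical upsert shape; metadata field names mirror what search_similar reads back.
self.index.upsert(vectors=[{
    "id": embedding_id,
    "values": embedding,
    "metadata": {"image_id": image_id, "team_id": team_id},
}])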
+    async def get_embedding(self, embedding_id: str) -> Optional[List[float]]:
+        """
+        Get embedding by ID
+
+        Args:
+            embedding_id: Embedding ID
+
+        Returns:
+            Embedding vector or None if not found
+        """
+        self.initialize()
+
+        if not self.initialized:
+            logger.warning("Pinecone not initialized, cannot get embedding")
+            return None
+
+        try:
+            # Fetch the vector
+            result = self.index.fetch(ids=[embedding_id])
+
+            if embedding_id in result.vectors:
+                return result.vectors[embedding_id].values
+            else:
+                logger.warning(f"Embedding not found: {embedding_id}")
+                return None
+        except Exception as e:
+            logger.error(f"Error getting embedding: {e}")
+            return None
+
     def search_by_embedding(self, team_id: str, query_embedding: List[float], limit: int = 10) -> List[Dict[str, Any]]:
         """
         Search for similar images by embedding