2025-05-24 12:39:42 +02:00

362 lines
14 KiB
Python

import logging
from typing import Optional, List
from fastapi import APIRouter, Depends, Query, Request, HTTPException
from src.api.v1.auth import get_current_user
from src.db.repositories.image_repository import image_repository
from src.services.vector_store import VectorStoreService
from src.services.embedding_service import EmbeddingService
from src.models.user import UserModel
from src.schemas.image import ImageResponse
from src.schemas.search import SearchResponse, SearchRequest
from src.utils.logging import log_request
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/search", tags=["Search"])

# Module-wide service instances shared by every search endpoint below.
# They are constructed once at import time, vector store first.
vector_store_service = VectorStoreService()
embedding_service = EmbeddingService()
@router.get("", response_model=SearchResponse)
async def search_images(
    request: Request,
    q: str = Query(..., description="Search query"),
    limit: int = Query(10, ge=1, le=50, description="Number of results to return"),
    threshold: float = Query(0.7, ge=0.0, le=1.0, description="Similarity threshold"),
    collection_id: Optional[str] = Query(None, description="Filter by collection ID"),
    tags: Optional[str] = Query(None, description="Filter by tags (comma-separated)"),
    current_user: UserModel = Depends(get_current_user)
):
    """
    Search for images using semantic similarity.

    The query text is embedded, the vector store is searched with the
    caller's team id, and the matching image records are loaded and
    optionally filtered by collection and tags.

    Note: the collection/tag filters run *after* the vector search has been
    capped at ``limit`` hits, so fewer than ``limit`` results may come back.

    Args:
        request: Incoming request (used only for request logging).
        q: Free-text search query.
        limit: Maximum number of vector-search hits to request.
        threshold: Similarity threshold passed to the vector store.
        collection_id: If given, keep only images whose collection id matches.
        tags: Comma-separated tag list; an image is kept if it carries at
            least one of them. Blank entries are ignored, and a value made
            only of blanks applies no tag filter.
        current_user: Authenticated user whose team id scopes the search.

    Returns:
        SearchResponse whose results are sorted by descending similarity score.

    Raises:
        HTTPException: 400 if the query embedding could not be generated,
            500 on any other unexpected error.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "query": q,
            "limit": limit,
            "threshold": threshold
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )
    # Parse the tag filter once instead of re-splitting it for every image.
    tag_filter = {tag.strip() for tag in tags.split(',') if tag.strip()} if tags else set()
    try:
        # Generate embedding for the search query
        query_embedding = await embedding_service.generate_text_embedding(q)
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        # Search in vector database
        search_results = await vector_store_service.search_similar(
            query_embedding,
            limit=limit,
            threshold=threshold,
            team_id=str(current_user.team_id)
        )
        if not search_results:
            return SearchResponse(
                query=q,
                results=[],
                total=0,
                limit=limit,
                threshold=threshold
            )

        # Get image IDs and scores from search results
        image_ids = [result['id'] for result in search_results]
        scores = {result['id']: result['score'] for result in search_results}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        results = []
        for image in images:
            if collection_id and str(image.collection_id) != collection_id:
                continue
            # A missing/None tag list is treated as "no tags".
            if tag_filter and tag_filter.isdisjoint(image.tags or ()):
                continue
            image_id = str(image.id)
            results.append(ImageResponse(
                id=image_id,
                filename=image.filename,
                original_filename=image.original_filename,
                file_size=image.file_size,
                content_type=image.content_type,
                storage_path=image.storage_path,
                team_id=str(image.team_id),
                uploader_id=str(image.uploader_id),
                upload_date=image.upload_date,
                description=image.description,
                tags=image.tags,
                metadata=image.metadata,
                has_embedding=image.has_embedding,
                collection_id=str(image.collection_id) if image.collection_id else None,
                similarity_score=scores.get(image_id, 0.0)
            ))

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
        return SearchResponse(
            query=q,
            results=results,
            total=len(results),
            limit=limit,
            threshold=threshold
        )
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 above) instead of
        # masking them as a generic 500.
        raise
    except Exception as e:
        logger.exception("Error searching images: %s", e)
        raise HTTPException(status_code=500, detail="Search failed") from e
@router.post("", response_model=SearchResponse)
async def search_images_advanced(
    search_request: SearchRequest,
    request: Request,
    current_user: UserModel = Depends(get_current_user)
):
    """
    Advanced search for images with more options.

    Embeds ``search_request.query``, searches the vector store with the
    caller's team id, then filters the loaded image records by collection,
    tags (any-match), upload date range and uploader.

    Note: filters run *after* the vector search has been capped at
    ``search_request.limit`` hits, so fewer results than the limit may be
    returned.

    Args:
        search_request: Query text plus limit, threshold and optional filters.
        request: Incoming request (used only for request logging).
        current_user: Authenticated user whose team id scopes the search.

    Returns:
        SearchResponse whose results are sorted by descending similarity score.

    Raises:
        HTTPException: 400 if the query embedding could not be generated,
            500 on any other unexpected error.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "search_request": search_request.dict()
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )
    try:
        # Generate embedding for the search query
        query_embedding = await embedding_service.generate_text_embedding(search_request.query)
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        # Search in vector database
        search_results = await vector_store_service.search_similar(
            query_embedding,
            limit=search_request.limit,
            threshold=search_request.threshold,
            team_id=str(current_user.team_id)
        )
        if not search_results:
            return SearchResponse(
                query=search_request.query,
                results=[],
                total=0,
                limit=search_request.limit,
                threshold=search_request.threshold
            )

        # Get image IDs and scores from search results
        image_ids = [result['id'] for result in search_results]
        scores = {result['id']: result['score'] for result in search_results}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Build the tag lookup once rather than per image.
        tag_filter = set(search_request.tags or ())

        results = []
        for image in images:
            # Check collection filter
            if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
                continue
            # Check tags filter (any-match); a missing/None tag list is "no tags".
            if tag_filter and tag_filter.isdisjoint(image.tags or ()):
                continue
            # Check date range filter
            if search_request.date_from and image.upload_date < search_request.date_from:
                continue
            if search_request.date_to and image.upload_date > search_request.date_to:
                continue
            # Check uploader filter
            if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
                continue

            image_id = str(image.id)
            results.append(ImageResponse(
                id=image_id,
                filename=image.filename,
                original_filename=image.original_filename,
                file_size=image.file_size,
                content_type=image.content_type,
                storage_path=image.storage_path,
                team_id=str(image.team_id),
                uploader_id=str(image.uploader_id),
                upload_date=image.upload_date,
                description=image.description,
                tags=image.tags,
                metadata=image.metadata,
                has_embedding=image.has_embedding,
                collection_id=str(image.collection_id) if image.collection_id else None,
                similarity_score=scores.get(image_id, 0.0)
            ))

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
        return SearchResponse(
            query=search_request.query,
            results=results,
            total=len(results),
            limit=search_request.limit,
            threshold=search_request.threshold
        )
    except HTTPException:
        # Keep deliberate client errors (e.g. the 400 above) instead of
        # masking them as a generic 500.
        raise
    except Exception as e:
        logger.exception("Error in advanced search: %s", e)
        raise HTTPException(status_code=500, detail="Advanced search failed") from e
@router.get("/similar/{image_id}", response_model=SearchResponse)
async def find_similar_images(
    image_id: str,
    request: Request,
    limit: int = Query(10, ge=1, le=50, description="Number of similar images to return"),
    threshold: float = Query(0.7, ge=0.0, le=1.0, description="Similarity threshold"),
    current_user: UserModel = Depends(get_current_user)
):
    """
    Find images similar to a given image.

    Loads the reference image, checks it belongs to the caller's team, fetches
    its stored embedding and runs a vector search with it (team-scoped),
    excluding the reference image itself from the results.

    Args:
        image_id: ObjectId (hex string) of the reference image.
        request: Incoming request (used only for request logging).
        limit: Maximum number of similar images to return.
        threshold: Similarity threshold passed to the vector store.
        current_user: Authenticated user whose team id scopes access and search.

    Returns:
        SearchResponse whose results are sorted by descending similarity score.

    Raises:
        HTTPException: 400 for an invalid id, a missing or unavailable
            embedding; 403 if the image belongs to another team; 404 if the
            image does not exist; 500 on any other unexpected error.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "image_id": image_id,
            "limit": limit,
            "threshold": threshold
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )
    try:
        from bson import ObjectId
        from bson.errors import InvalidId

        # Get the reference image
        try:
            obj_id = ObjectId(image_id)
        except (InvalidId, TypeError):
            raise HTTPException(status_code=400, detail="Invalid image ID") from None

        reference_image = await image_repository.get_by_id(obj_id)
        if not reference_image:
            raise HTTPException(status_code=404, detail="Reference image not found")

        # Check team access. Compare string forms so an ObjectId-vs-str type
        # mismatch between the two models cannot wrongly deny access.
        if str(reference_image.team_id) != str(current_user.team_id):
            raise HTTPException(status_code=403, detail="Not authorized to access this image")

        # Check if the image has embeddings
        if not reference_image.has_embedding or not reference_image.embedding_id:
            raise HTTPException(status_code=400, detail="Reference image does not have embeddings")

        # Get the embedding for the reference image
        reference_embedding = await vector_store_service.get_embedding(reference_image.embedding_id)
        if not reference_embedding:
            raise HTTPException(status_code=400, detail="Failed to get reference image embedding")

        # Search for similar images
        search_results = await vector_store_service.search_similar(
            reference_embedding,
            limit=limit + 1,  # +1 to account for the reference image itself
            threshold=threshold,
            team_id=str(current_user.team_id)
        )

        # Remove the reference image from results. Compare against the
        # canonical hex form of the parsed id, so e.g. an upper-case id in the
        # URL still excludes the reference image. Guard against a None result.
        reference_id = str(obj_id)
        search_results = [
            result for result in (search_results or [])
            if str(result['id']) != reference_id
        ][:limit]

        if not search_results:
            return SearchResponse(
                query=f"Similar to image {image_id}",
                results=[],
                total=0,
                limit=limit,
                threshold=threshold
            )

        # Get image IDs and scores from search results
        image_ids = [result['id'] for result in search_results]
        scores = {result['id']: result['score'] for result in search_results}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Convert to response format with similarity scores
        results = []
        for image in images:
            image_id_str = str(image.id)
            results.append(ImageResponse(
                id=image_id_str,
                filename=image.filename,
                original_filename=image.original_filename,
                file_size=image.file_size,
                content_type=image.content_type,
                storage_path=image.storage_path,
                team_id=str(image.team_id),
                uploader_id=str(image.uploader_id),
                upload_date=image.upload_date,
                description=image.description,
                tags=image.tags,
                metadata=image.metadata,
                has_embedding=image.has_embedding,
                collection_id=str(image.collection_id) if image.collection_id else None,
                similarity_score=scores.get(image_id_str, 0.0)
            ))

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
        return SearchResponse(
            query=f"Similar to image {image_id}",
            results=results,
            total=len(results),
            limit=limit,
            threshold=threshold
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error finding similar images: %s", e)
        raise HTTPException(status_code=500, detail="Similar image search failed") from e