2025-05-25 16:52:38 +02:00

258 lines
10 KiB
Python

import logging
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, Depends, Query, Request, HTTPException
from src.auth.security import get_current_user
from src.services.vector_db import VectorDatabaseService
from src.services.embedding_service import EmbeddingService
from src.db.repositories.image_repository import image_repository
from src.db.repositories.team_repository import team_repository
from src.models.user import UserModel
from src.schemas.image import ImageResponse
from src.schemas.search import SearchResponse, SearchRequest
from src.utils.logging import log_request
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Router for the search endpoints: routes are registered under the "/search"
# prefix and tagged "Search".
router = APIRouter(tags=["Search"], prefix="/search")
# Initialize services. VectorDatabaseService instantiation is delayed: the
# global starts as None and is created on the first call to
# get_vector_db_service().
vector_db_service = None
# The embedding service, by contrast, is instantiated when this module is imported.
embedding_service = EmbeddingService()
def get_vector_db_service():
    """Return the shared VectorDatabaseService, creating it on first use.

    The instance is stored in the module-level ``vector_db_service`` global,
    so every later call returns the same object.
    """
    global vector_db_service
    # Fast path: the service has already been built by an earlier call.
    if vector_db_service is not None:
        return vector_db_service
    logger.info("Initializing VectorDatabaseService...")
    vector_db_service = VectorDatabaseService()
    return vector_db_service
@router.get("", response_model=SearchResponse)
async def search_images(
    request: Request,
    q: str = Query(..., description="Search query"),
    limit: int = Query(10, ge=1, le=50, description="Number of results to return"),
    threshold: float = Query(0.65, ge=0.0, le=1.0, description="Similarity threshold"),
    collection_id: Optional[str] = Query(None, description="Filter by collection ID"),
    current_user: UserModel = Depends(get_current_user)
):
    """
    Search for images using semantic similarity.

    The query text is turned into an embedding, the vector database is asked
    for the closest images (restricted to the caller's team when the user has
    a ``team_id``), and the matching image records are loaded from the image
    repository and returned ordered by similarity score, highest first.

    The optional ``collection_id`` filter is applied *after* the vector
    search, so fewer than ``limit`` results may be returned even when more
    matching images exist.

    Raises:
        HTTPException: 400 if no embedding could be generated for the query,
            500 for any other unexpected failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "query": q,
            "limit": limit,
            "threshold": threshold
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )
    try:
        # Generate embedding for the search query
        query_embedding = await embedding_service.generate_text_embedding(q)
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        # NOTE(review): a user without a team_id searches with no filter at
        # all -- confirm that an unscoped search is intended in that case.
        team_filter = {"team_id": str(current_user.team_id)} if current_user.team_id else None

        # Search in vector database
        search_results = get_vector_db_service().search_similar_images(
            query_vector=query_embedding,
            limit=limit,
            score_threshold=threshold,
            filter_conditions=team_filter
        )
        if not search_results:
            return SearchResponse(
                query=q,
                results=[],
                total=0,
                limit=limit,
                threshold=threshold
            )

        # Keep only hits that actually carry an image ID.
        hits = [hit for hit in search_results if hit['image_id']]
        image_ids = [hit['image_id'] for hit in hits]
        # Scores are looked up below by str(image.id), so key the map by the
        # string form as well; otherwise non-string IDs would never match and
        # every result would silently get a score of 0.0.
        scores = {str(hit['image_id']): hit['score'] for hit in hits}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Apply the collection filter and convert to the response format.
        results = [
            ImageResponse(
                id=str(image.id),
                filename=image.filename,
                original_filename=image.original_filename,
                file_size=image.file_size,
                content_type=image.content_type,
                storage_path=image.storage_path,
                team_id=str(image.team_id),
                uploader_id=str(image.uploader_id),
                upload_date=image.upload_date,
                description=image.description,
                metadata=image.metadata,
                has_embedding=image.has_embedding,
                collection_id=str(image.collection_id) if image.collection_id else None,
                similarity_score=scores.get(str(image.id), 0.0)
            )
            for image in images
            if not collection_id or str(image.collection_id) == collection_id
        ]

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
        return SearchResponse(
            query=q,
            results=results,
            total=len(results),
            limit=limit,
            threshold=threshold
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewritten into a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("Error searching images: %s", e)
        raise HTTPException(status_code=500, detail="Search failed") from e
@router.post("", response_model=SearchResponse)
async def search_images_advanced(
    search_request: SearchRequest,
    request: Request,
    current_user: UserModel = Depends(get_current_user)
):
    """
    Advanced search for images with more options.

    Works like the GET search: the query text is embedded, the vector
    database is searched (restricted to the caller's team when the user has a
    ``team_id``), and the matching image records are loaded and returned
    ordered by similarity score, highest first.

    In addition to the collection filter, the request body may narrow results
    by ``date_from`` / ``date_to`` (compared against each image's
    ``upload_date``) and by ``uploader_id``. All of these filters are applied
    *after* the vector search, so fewer than ``search_request.limit`` results
    may be returned even when more matching images exist.

    Raises:
        HTTPException: 400 if no embedding could be generated for the query,
            500 for any other unexpected failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "search_request": search_request.dict()
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )
    try:
        # Generate embedding for the search query
        logger.info("Generating embedding for query: %s", search_request.query)
        query_embedding = await embedding_service.generate_text_embedding(search_request.query)
        if not query_embedding:
            logger.error("Failed to generate search embedding")
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")
        logger.info("Generated embedding with length: %s", len(query_embedding))

        # NOTE(review): a user without a team_id searches with no filter at
        # all -- confirm that an unscoped search is intended in that case.
        team_filter = {"team_id": str(current_user.team_id)} if current_user.team_id else None

        # Search in vector database
        logger.info("Searching vector database with threshold: %s", search_request.threshold)
        search_results = get_vector_db_service().search_similar_images(
            query_vector=query_embedding,
            limit=search_request.limit,
            score_threshold=search_request.threshold,
            filter_conditions=team_filter
        )
        logger.info("Vector search returned %s results", len(search_results) if search_results else 0)
        if not search_results:
            logger.info("No search results from vector database, returning empty response")
            return SearchResponse(
                query=search_request.query,
                results=[],
                total=0,
                limit=search_request.limit,
                threshold=search_request.threshold
            )

        # Keep only hits that actually carry an image ID.
        hits = [hit for hit in search_results if hit['image_id']]
        image_ids = [hit['image_id'] for hit in hits]
        # Scores are looked up below by str(image.id), so key the map by the
        # string form as well; otherwise non-string IDs would never match and
        # every result would silently get a score of 0.0.
        scores = {str(hit['image_id']): hit['score'] for hit in hits}
        logger.info("Extracted %s image IDs: %s", len(image_ids), image_ids)

        # Get image metadata from database
        logger.info("Fetching image metadata from database...")
        images = await image_repository.get_by_ids(image_ids)
        logger.info("Retrieved %s images from database", len(images))

        # Apply filters
        filtered_images = []
        for image in images:
            # Check collection filter
            if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
                continue
            # Check date range filter
            if search_request.date_from and image.upload_date < search_request.date_from:
                continue
            if search_request.date_to and image.upload_date > search_request.date_to:
                continue
            # Check uploader filter
            if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
                continue
            filtered_images.append(image)
        logger.info("After filtering: %s images remain", len(filtered_images))

        # Convert to response format with similarity scores
        results = [
            ImageResponse(
                id=str(image.id),
                filename=image.filename,
                original_filename=image.original_filename,
                file_size=image.file_size,
                content_type=image.content_type,
                storage_path=image.storage_path,
                team_id=str(image.team_id),
                uploader_id=str(image.uploader_id),
                upload_date=image.upload_date,
                description=image.description,
                metadata=image.metadata,
                has_embedding=image.has_embedding,
                collection_id=str(image.collection_id) if image.collection_id else None,
                similarity_score=scores.get(str(image.id), 0.0)
            )
            for image in filtered_images
        ]

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
        logger.info("Returning %s results", len(results))
        return SearchResponse(
            query=search_request.query,
            results=results,
            total=len(results),
            limit=search_request.limit,
            threshold=search_request.threshold
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being rewritten into a generic 500.
        raise
    except Exception as e:
        # logger.exception records the traceback, replacing the manual
        # traceback.format_exc() call.
        logger.exception("Error in advanced search: %s", e)
        raise HTTPException(status_code=500, detail="Advanced search failed") from e