258 lines
10 KiB
Python
258 lines
10 KiB
Python
import logging
|
|
from typing import Optional, List, Dict, Any
|
|
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
|
|
|
from src.auth.security import get_current_user
|
|
from src.services.vector_db import VectorDatabaseService
|
|
from src.services.embedding_service import EmbeddingService
|
|
from src.db.repositories.image_repository import image_repository
|
|
from src.db.repositories.team_repository import team_repository
|
|
from src.models.user import UserModel
|
|
from src.schemas.image import ImageResponse
|
|
from src.schemas.search import SearchResponse, SearchRequest
|
|
from src.utils.logging import log_request
|
|
|
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Routes declared on this router get the "/search" prefix and the "Search" tag.
router = APIRouter(tags=["Search"], prefix="/search")

# Initialize services - delay VectorDatabaseService instantiation
# vector_db_service stays None here; get_vector_db_service() constructs it on
# first use. embedding_service, by contrast, is constructed at import time.
vector_db_service = None
embedding_service = EmbeddingService()
|
|
|
|
def get_vector_db_service():
    """Return the module-wide VectorDatabaseService, constructing it on first call."""
    global vector_db_service

    # Already built by an earlier call: hand back the cached instance.
    if vector_db_service is not None:
        return vector_db_service

    # First call: build the service and cache it in the module global.
    logger.info("Initializing VectorDatabaseService...")
    vector_db_service = VectorDatabaseService()
    return vector_db_service
|
|
|
|
@router.get("", response_model=SearchResponse)
async def search_images(
    request: Request,
    q: str = Query(..., description="Search query"),
    limit: int = Query(10, ge=1, le=50, description="Number of results to return"),
    threshold: float = Query(0.65, ge=0.0, le=1.0, description="Similarity threshold"),
    collection_id: Optional[str] = Query(None, description="Filter by collection ID"),
    current_user: UserModel = Depends(get_current_user)
):
    """
    Search for images using semantic similarity

    The query text is embedded, matched against the vector database (restricted
    to the caller's team when the user has one), and the matching image records
    are returned ordered by similarity score, highest first. The optional
    collection filter is applied after the vector search, so fewer than
    ``limit`` results may be returned; ``total`` is the number of results
    actually returned.

    Responds with 400 when no embedding could be generated for the query and
    with 500 on any other unexpected failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "query": q,
            "limit": limit,
            "threshold": threshold
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        # Generate embedding for the search query
        query_embedding = await embedding_service.generate_text_embedding(q)
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        # Search in vector database, scoped to the caller's team when available
        search_results = get_vector_db_service().search_similar_images(
            query_vector=query_embedding,
            limit=limit,
            score_threshold=threshold,
            filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
        )

        if not search_results:
            return SearchResponse(
                query=q,
                results=[],
                total=0,
                limit=limit,
                threshold=threshold
            )

        # Get image IDs and scores from search results (entries without an ID are skipped)
        image_ids = [result['image_id'] for result in search_results if result['image_id']]
        # Keys are normalised with str() because the lookup below uses str(image.id);
        # this is a no-op when the vector store already returns string IDs.
        scores = {
            str(result['image_id']): result['score']
            for result in search_results
            if result['image_id']
        }

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Filter by collection if specified
        filtered_images = [
            image for image in images
            if not collection_id or str(image.collection_id) == collection_id
        ]

        # Convert to response format with similarity scores
        results = []
        for image in filtered_images:
            image_id = str(image.id)
            results.append(
                ImageResponse(
                    id=image_id,
                    filename=image.filename,
                    original_filename=image.original_filename,
                    file_size=image.file_size,
                    content_type=image.content_type,
                    storage_path=image.storage_path,
                    team_id=str(image.team_id),
                    uploader_id=str(image.uploader_id),
                    upload_date=image.upload_date,
                    description=image.description,
                    metadata=image.metadata,
                    has_embedding=image.has_embedding,
                    collection_id=str(image.collection_id) if image.collection_id else None,
                    # Images with no recorded score fall back to 0.0
                    similarity_score=scores.get(image_id, 0.0)
                )
            )

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)

        return SearchResponse(
            query=q,
            results=results,
            total=len(results),
            limit=limit,
            threshold=threshold
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being masked as a 500 by the handler below.
        raise
    except Exception as e:
        # logger.exception records the traceback along with the message
        logger.exception(f"Error searching images: {e}")
        raise HTTPException(status_code=500, detail="Search failed") from e
|
|
|
|
@router.post("", response_model=SearchResponse)
async def search_images_advanced(
    search_request: SearchRequest,
    request: Request,
    current_user: UserModel = Depends(get_current_user)
):
    """
    Advanced search for images with more options

    The request query is embedded and matched against the vector database
    (restricted to the caller's team when the user has one). The matching image
    records are then narrowed by the optional collection, upload-date range and
    uploader filters from the request body, and returned ordered by similarity
    score, highest first. Filters are applied after the vector search, so fewer
    than ``limit`` results may be returned; ``total`` is the number of results
    actually returned.

    Responds with 400 when no embedding could be generated for the query and
    with 500 on any other unexpected failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "search_request": search_request.dict()
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        # Generate embedding for the search query
        logger.info(f"Generating embedding for query: {search_request.query}")
        query_embedding = await embedding_service.generate_text_embedding(search_request.query)
        if not query_embedding:
            logger.error("Failed to generate search embedding")
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        logger.info(f"Generated embedding with length: {len(query_embedding)}")

        # Search in vector database, scoped to the caller's team when available
        logger.info(f"Searching vector database with threshold: {search_request.threshold}")
        search_results = get_vector_db_service().search_similar_images(
            query_vector=query_embedding,
            limit=search_request.limit,
            score_threshold=search_request.threshold,
            filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
        )

        logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")

        if not search_results:
            logger.info("No search results from vector database, returning empty response")
            return SearchResponse(
                query=search_request.query,
                results=[],
                total=0,
                limit=search_request.limit,
                threshold=search_request.threshold
            )

        # Get image IDs and scores from search results (entries without an ID are skipped)
        image_ids = [result['image_id'] for result in search_results if result['image_id']]
        # Keys are normalised with str() because the lookup below uses str(image.id);
        # this is a no-op when the vector store already returns string IDs.
        scores = {
            str(result['image_id']): result['score']
            for result in search_results
            if result['image_id']
        }

        logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")

        # Get image metadata from database
        logger.info("Fetching image metadata from database...")
        images = await image_repository.get_by_ids(image_ids)
        logger.info(f"Retrieved {len(images)} images from database")

        # Apply filters; each one is skipped when the request leaves it unset
        filtered_images = []
        for image in images:
            # Check collection filter
            if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
                continue

            # Check date range filter (both bounds are inclusive)
            if search_request.date_from and image.upload_date < search_request.date_from:
                continue
            if search_request.date_to and image.upload_date > search_request.date_to:
                continue

            # Check uploader filter
            if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
                continue

            filtered_images.append(image)

        logger.info(f"After filtering: {len(filtered_images)} images remain")

        # Convert to response format with similarity scores
        results = []
        for image in filtered_images:
            image_id = str(image.id)
            results.append(
                ImageResponse(
                    id=image_id,
                    filename=image.filename,
                    original_filename=image.original_filename,
                    file_size=image.file_size,
                    content_type=image.content_type,
                    storage_path=image.storage_path,
                    team_id=str(image.team_id),
                    uploader_id=str(image.uploader_id),
                    upload_date=image.upload_date,
                    description=image.description,
                    metadata=image.metadata,
                    has_embedding=image.has_embedding,
                    collection_id=str(image.collection_id) if image.collection_id else None,
                    # Images with no recorded score fall back to 0.0
                    similarity_score=scores.get(image_id, 0.0)
                )
            )

        # Sort by similarity score (highest first)
        results.sort(key=lambda x: x.similarity_score or 0, reverse=True)

        logger.info(f"Returning {len(results)} results")

        return SearchResponse(
            query=search_request.query,
            results=results,
            total=len(results),
            limit=search_request.limit,
            threshold=search_request.threshold
        )

    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must reach the client
        # unchanged instead of being masked as a 500 by the handler below.
        raise
    except Exception as e:
        # logger.exception appends the traceback, replacing the manual
        # traceback.format_exc() logging.
        logger.exception(f"Error in advanced search: {e}")
        raise HTTPException(status_code=500, detail="Advanced search failed") from e
|