362 lines
14 KiB
Python
362 lines
14 KiB
Python
import logging
|
|
from typing import Optional, List
|
|
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
|
|
|
from src.api.v1.auth import get_current_user
|
|
from src.db.repositories.image_repository import image_repository
|
|
from src.services.vector_store import VectorStoreService
|
|
from src.services.embedding_service import EmbeddingService
|
|
from src.models.user import UserModel
|
|
from src.schemas.image import ImageResponse
|
|
from src.schemas.search import SearchResponse, SearchRequest
|
|
from src.utils.logging import log_request
|
|
|
|
# Logger named after this module's import path.
logger = logging.getLogger(__name__)

# Every route declared on this router lives under the "/search" prefix and is
# tagged "Search".
router = APIRouter(prefix="/search", tags=["Search"])

# Service instances are created once at import time and reused by the route
# handlers in this module.
vector_store_service = VectorStoreService()
embedding_service = EmbeddingService()
|
|
|
|
@router.get("", response_model=SearchResponse)
async def search_images(
    request: Request,
    q: str = Query(..., description="Search query"),
    limit: int = Query(10, ge=1, le=50, description="Number of results to return"),
    threshold: float = Query(0.7, ge=0.0, le=1.0, description="Similarity threshold"),
    collection_id: Optional[str] = Query(None, description="Filter by collection ID"),
    tags: Optional[str] = Query(None, description="Filter by tags (comma-separated)"),
    current_user: UserModel = Depends(get_current_user)
):
    """
    Search for images using semantic similarity.

    The query text is turned into an embedding, the vector store is searched
    within the caller's team, and the matching image records are loaded from
    the image repository. The optional collection and tag filters are applied
    *after* the vector search, so the response may contain fewer than
    ``limit`` results.

    Args:
        request: Incoming request; only its path and method are logged.
        q: Free-text search query.
        limit: Maximum number of hits requested from the vector store.
        threshold: Similarity threshold passed to the vector store.
        collection_id: When given, keep only images in this collection.
        tags: Comma-separated tags; an image is kept if it carries at least
            one of them.
        current_user: Authenticated user; their team scopes the search.

    Returns:
        SearchResponse with results sorted by similarity score, highest first.

    Raises:
        HTTPException: 400 if no embedding could be generated for the query;
            500 for any other failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "query": q,
            "limit": limit,
            "threshold": threshold
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        # Generate embedding for the search query
        query_embedding = await embedding_service.generate_text_embedding(q)
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        # Search in vector database
        search_results = await vector_store_service.search_similar(
            query_embedding,
            limit=limit,
            threshold=threshold,
            team_id=str(current_user.team_id)
        )

        if not search_results:
            return SearchResponse(
                query=q,
                results=[],
                total=0,
                limit=limit,
                threshold=threshold
            )

        # Get image IDs and scores from search results
        image_ids = [hit['id'] for hit in search_results]
        scores = {hit['id']: hit['score'] for hit in search_results}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Parse the tag filter once instead of once per image. ``None`` means
        # "no tag filtering"; an empty list (e.g. tags=",") still counts as an
        # active filter that matches nothing, as it did before.
        tag_filter = None
        if tags:
            tag_filter = [tag.strip() for tag in tags.split(',') if tag.strip()]

        # Convert matching images to response format with similarity scores
        results = []
        for image in images:
            # Check collection filter
            if collection_id and str(image.collection_id) != collection_id:
                continue

            # Check tags filter
            if tag_filter is not None and not any(tag in image.tags for tag in tag_filter):
                continue

            image_id = str(image.id)
            results.append(
                ImageResponse(
                    id=image_id,
                    filename=image.filename,
                    original_filename=image.original_filename,
                    file_size=image.file_size,
                    content_type=image.content_type,
                    storage_path=image.storage_path,
                    team_id=str(image.team_id),
                    uploader_id=str(image.uploader_id),
                    upload_date=image.upload_date,
                    description=image.description,
                    tags=image.tags,
                    metadata=image.metadata,
                    has_embedding=image.has_embedding,
                    collection_id=str(image.collection_id) if image.collection_id else None,
                    similarity_score=scores.get(image_id, 0.0)
                )
            )

        # Sort by similarity score (highest first)
        results.sort(key=lambda item: item.similarity_score or 0, reverse=True)

        return SearchResponse(
            query=q,
            results=results,
            total=len(results),
            limit=limit,
            threshold=threshold
        )

    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client
        # unchanged rather than being rewritten as a 500 below.
        raise
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        logger.exception("Error searching images: %s", e)
        raise HTTPException(status_code=500, detail="Search failed") from e
|
|
|
|
@router.post("", response_model=SearchResponse)
async def search_images_advanced(
    search_request: SearchRequest,
    request: Request,
    current_user: UserModel = Depends(get_current_user)
):
    """
    Advanced search for images with more options.

    Works like the GET search but takes its parameters from a request body
    and supports extra filters: collection, tags (any match), upload date
    range and uploader. All filters are applied *after* the vector search,
    so the response may contain fewer than ``search_request.limit`` results.

    Args:
        search_request: Body carrying ``query``, ``limit``, ``threshold`` and
            the optional ``collection_id``, ``tags``, ``date_from``,
            ``date_to`` and ``uploader_id`` filters.
        request: Incoming request; only its path and method are logged.
        current_user: Authenticated user; their team scopes the search.

    Returns:
        SearchResponse with results sorted by similarity score, highest first.

    Raises:
        HTTPException: 400 if no embedding could be generated for the query;
            500 for any other failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "search_request": search_request.dict()
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        # Generate embedding for the search query
        query_embedding = await embedding_service.generate_text_embedding(search_request.query)
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Failed to generate search embedding")

        # Search in vector database
        search_results = await vector_store_service.search_similar(
            query_embedding,
            limit=search_request.limit,
            threshold=search_request.threshold,
            team_id=str(current_user.team_id)
        )

        if not search_results:
            return SearchResponse(
                query=search_request.query,
                results=[],
                total=0,
                limit=search_request.limit,
                threshold=search_request.threshold
            )

        # Get image IDs and scores from search results
        image_ids = [hit['id'] for hit in search_results]
        scores = {hit['id']: hit['score'] for hit in search_results}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Apply filters and convert the survivors to response format
        results = []
        for image in images:
            # Check collection filter
            if search_request.collection_id and str(image.collection_id) != search_request.collection_id:
                continue

            # Check tags filter (image must carry at least one requested tag)
            if search_request.tags and not any(tag in image.tags for tag in search_request.tags):
                continue

            # Check date range filter
            if search_request.date_from and image.upload_date < search_request.date_from:
                continue
            if search_request.date_to and image.upload_date > search_request.date_to:
                continue

            # Check uploader filter
            if search_request.uploader_id and str(image.uploader_id) != search_request.uploader_id:
                continue

            image_id = str(image.id)
            results.append(
                ImageResponse(
                    id=image_id,
                    filename=image.filename,
                    original_filename=image.original_filename,
                    file_size=image.file_size,
                    content_type=image.content_type,
                    storage_path=image.storage_path,
                    team_id=str(image.team_id),
                    uploader_id=str(image.uploader_id),
                    upload_date=image.upload_date,
                    description=image.description,
                    tags=image.tags,
                    metadata=image.metadata,
                    has_embedding=image.has_embedding,
                    collection_id=str(image.collection_id) if image.collection_id else None,
                    similarity_score=scores.get(image_id, 0.0)
                )
            )

        # Sort by similarity score (highest first)
        results.sort(key=lambda item: item.similarity_score or 0, reverse=True)

        return SearchResponse(
            query=search_request.query,
            results=results,
            total=len(results),
            limit=search_request.limit,
            threshold=search_request.threshold
        )

    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) must reach the client
        # unchanged rather than being rewritten as a 500 below.
        raise
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        logger.exception("Error in advanced search: %s", e)
        raise HTTPException(status_code=500, detail="Advanced search failed") from e
|
|
|
|
@router.get("/similar/{image_id}", response_model=SearchResponse)
async def find_similar_images(
    image_id: str,
    request: Request,
    limit: int = Query(10, ge=1, le=50, description="Number of similar images to return"),
    threshold: float = Query(0.7, ge=0.0, le=1.0, description="Similarity threshold"),
    current_user: UserModel = Depends(get_current_user)
):
    """
    Find images similar to a given image.

    Loads the reference image, checks that it belongs to the caller's team
    and has a stored embedding, then searches the vector store with that
    embedding. The reference image itself is removed from the results.

    Args:
        image_id: ID of the reference image; must parse as a BSON ObjectId.
        request: Incoming request; only its path and method are logged.
        limit: Maximum number of similar images to return.
        threshold: Similarity threshold passed to the vector store.
        current_user: Authenticated user; their team scopes the search.

    Returns:
        SearchResponse with results sorted by similarity score, highest first.

    Raises:
        HTTPException: 400 for an invalid ID, a reference image without
            embeddings, or an embedding that cannot be fetched; 404 if the
            reference image does not exist; 403 if it belongs to another
            team; 500 for any other failure.
    """
    log_request(
        {
            "path": request.url.path,
            "method": request.method,
            "image_id": image_id,
            "limit": limit,
            "threshold": threshold
        },
        user_id=str(current_user.id),
        team_id=str(current_user.team_id)
    )

    try:
        from bson import ObjectId
        from bson.errors import InvalidId

        # Get the reference image. Only the errors ObjectId raises for bad
        # input are caught; a bare ``except`` would also swallow task
        # cancellation and interpreter-exit signals.
        try:
            obj_id = ObjectId(image_id)
        except (InvalidId, TypeError) as exc:
            raise HTTPException(status_code=400, detail="Invalid image ID") from exc

        reference_image = await image_repository.get_by_id(obj_id)
        if not reference_image:
            raise HTTPException(status_code=404, detail="Reference image not found")

        # Check team access
        if reference_image.team_id != current_user.team_id:
            raise HTTPException(status_code=403, detail="Not authorized to access this image")

        # Check if the image has embeddings
        if not reference_image.has_embedding or not reference_image.embedding_id:
            raise HTTPException(status_code=400, detail="Reference image does not have embeddings")

        # Get the embedding for the reference image
        reference_embedding = await vector_store_service.get_embedding(reference_image.embedding_id)
        if not reference_embedding:
            raise HTTPException(status_code=400, detail="Failed to get reference image embedding")

        # Search for similar images
        search_results = await vector_store_service.search_similar(
            reference_embedding,
            limit=limit + 1,  # +1 to account for the reference image itself
            threshold=threshold,
            team_id=str(current_user.team_id)
        )

        # Remove the reference image from results
        search_results = [hit for hit in search_results if hit['id'] != image_id][:limit]

        query_label = f"Similar to image {image_id}"

        if not search_results:
            return SearchResponse(
                query=query_label,
                results=[],
                total=0,
                limit=limit,
                threshold=threshold
            )

        # Get image IDs and scores from search results
        image_ids = [hit['id'] for hit in search_results]
        scores = {hit['id']: hit['score'] for hit in search_results}

        # Get image metadata from database
        images = await image_repository.get_by_ids(image_ids)

        # Convert to response format with similarity scores
        results = []
        for image in images:
            image_id_str = str(image.id)
            results.append(
                ImageResponse(
                    id=image_id_str,
                    filename=image.filename,
                    original_filename=image.original_filename,
                    file_size=image.file_size,
                    content_type=image.content_type,
                    storage_path=image.storage_path,
                    team_id=str(image.team_id),
                    uploader_id=str(image.uploader_id),
                    upload_date=image.upload_date,
                    description=image.description,
                    tags=image.tags,
                    metadata=image.metadata,
                    has_embedding=image.has_embedding,
                    collection_id=str(image.collection_id) if image.collection_id else None,
                    similarity_score=scores.get(image_id_str, 0.0)
                )
            )

        # Sort by similarity score (highest first)
        results.sort(key=lambda item: item.similarity_score or 0, reverse=True)

        return SearchResponse(
            query=query_label,
            results=results,
            total=len(results),
            limit=limit,
            threshold=threshold
        )

    except HTTPException:
        # Pass deliberate HTTP errors (400/403/404 above) through unchanged.
        raise
    except Exception as e:
        # logger.exception also records the traceback for diagnosis.
        logger.exception("Error finding similar images: %s", e)
        raise HTTPException(status_code=500, detail="Similar image search failed") from e