fix image search

This commit is contained in:
johnpccd 2025-05-25 16:52:38 +02:00
parent 9f2dd0dfc3
commit 46fd8e6e5e
6 changed files with 126 additions and 4 deletions

View File

@ -677,6 +677,7 @@ This modular architecture provides several benefits:
- [ ] Move all auth logic to auth module
- [ ] Remove bootstrap endpoint
- [ ] Move cloud function code to src folder and reuse code with embedding service
- [ ] Thumbnail generation
### Pagination Status ✅
- **✅ Images API**: Fully implemented with `skip`, `limit`, `total` parameters

View File

@ -177,11 +177,11 @@ class ApiClient {
}
// Search API
async searchImages(query, similarityThreshold = 0.7, maxResults = 20) {
async searchImages(query, similarityThreshold, maxResults = 20) {
const searchData = {
query,
similarity_threshold: similarityThreshold,
max_results: maxResults
threshold: similarityThreshold,
limit: maxResults
};
return this.makeRequest('POST', '/search', searchData);

View File

@ -31,7 +31,7 @@ async def simple_search_test():
# Test 1: Generate text embedding
logger.info("=== Generating Text Embedding ===")
search_query = "blank"
search_query = "rectangle"
text_embedding = await embedding_service.generate_text_embedding(search_query)
if text_embedding:
@ -72,6 +72,25 @@ async def simple_search_test():
logger.info(f"Total vectors in collection: {len(all_results)}")
# Test 4: With team filtering (like the API does)
logger.info("\n=== Testing Team Filtering ===")
test_team_id = "68330a29472a0704d2f77063" # From server logs
filtered_results = vector_db.search_similar_images(
query_vector=text_embedding,
limit=50,
score_threshold=0.0,
filter_conditions={"team_id": test_team_id}
)
logger.info(f"Results with team filter ({test_team_id}): {len(filtered_results)}")
# Show metadata for all results to see team_ids
logger.info("\n=== Checking Team IDs in Vector DB ===")
for i, result in enumerate(all_results):
metadata = result.get('metadata', {})
team_id = metadata.get('team_id', 'N/A')
logger.info(f" {i+1}. Image ID: {result['image_id']} | Team ID: {team_id}")
# Show some stats
if all_results:
scores = [r['score'] for r in all_results]

View File

@ -151,11 +151,16 @@ async def search_images_advanced(
try:
# Generate embedding for the search query
logger.info(f"Generating embedding for query: {search_request.query}")
query_embedding = await embedding_service.generate_text_embedding(search_request.query)
if not query_embedding:
logger.error("Failed to generate search embedding")
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
logger.info(f"Generated embedding with length: {len(query_embedding)}")
# Search in vector database
logger.info(f"Searching vector database with threshold: {search_request.threshold}")
search_results = get_vector_db_service().search_similar_images(
query_vector=query_embedding,
limit=search_request.limit,
@ -163,7 +168,10 @@ async def search_images_advanced(
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
)
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
if not search_results:
logger.info("No search results from vector database, returning empty response")
return SearchResponse(
query=search_request.query,
results=[],
@ -176,8 +184,12 @@ async def search_images_advanced(
image_ids = [result['image_id'] for result in search_results if result['image_id']]
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
# Get image metadata from database
logger.info("Fetching image metadata from database...")
images = await image_repository.get_by_ids(image_ids)
logger.info(f"Retrieved {len(images)} images from database")
# Apply filters
filtered_images = []
@ -199,6 +211,8 @@ async def search_images_advanced(
filtered_images.append(image)
logger.info(f"After filtering: {len(filtered_images)} images remain")
# Convert to response format with similarity scores
results = []
for image in filtered_images:
@ -226,6 +240,8 @@ async def search_images_advanced(
# Sort by similarity score (highest first)
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
logger.info(f"Returning {len(results)} results")
return SearchResponse(
query=search_request.query,
results=results,
@ -236,4 +252,6 @@ async def search_images_advanced(
except Exception as e:
logger.error(f"Error in advanced search: {e}")
import traceback
logger.error(f"Traceback: {traceback.format_exc()}")
raise HTTPException(status_code=500, detail="Advanced search failed")

View File

@ -137,6 +137,42 @@ class FirestoreImageRepository(FirestoreRepository[ImageModel]):
logger.error(f"Error getting images by uploader ID: {e}")
raise
async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
    """
    Get images by a list of IDs.

    Each string ID is converted with ``ObjectId``; IDs that cannot be
    converted are logged at warning level and skipped. The lookup itself
    loads every image through ``self.get_all()`` and keeps those whose
    ``id`` is among the converted IDs, so results come back in the order
    ``get_all()`` yields them, not in the order of ``image_ids``.

    Args:
        image_ids: List of image IDs (as strings)

    Returns:
        List of images matching the IDs. Empty when ``image_ids`` is empty
        or none of the IDs could be converted.

    Raises:
        Exception: any error raised while fetching or filtering is logged
            and re-raised unchanged.
    """
    try:
        if not image_ids:
            return []

        # Collect the wanted IDs in a set so the membership test below is
        # O(1) per image instead of a linear scan of a list.
        # NOTE(review): relies on ObjectId being hashable - confirm for the
        # ObjectId type imported by this module.
        wanted_ids = set()
        for image_id in image_ids:
            try:
                wanted_ids.add(ObjectId(image_id))
            except Exception:
                # Best effort: one malformed ID must not fail the whole lookup.
                logger.warning(f"Invalid ObjectId format: {image_id}, skipping")

        if not wanted_ids:
            return []

        # Get all images and filter by IDs.
        # NOTE(review): this reads the whole collection; a batched lookup by
        # document ID would avoid that if the base repository offers one.
        all_images = await self.get_all()
        matching_images = [image for image in all_images if image.id in wanted_ids]

        logger.info(f"Found {len(matching_images)} images out of {len(image_ids)} requested IDs")
        return matching_images
    except Exception as e:
        logger.error(f"Error getting images by IDs: {e}")
        raise
async def get_all_with_pagination(
self,

48
test_threshold_fix.py Normal file
View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
"""
Test script to verify that similarity threshold is properly handled
"""
import json
from src.schemas.search import SearchRequest
def test_threshold_handling():
    """Test that threshold values are properly handled in the schema.

    Builds a ``SearchRequest`` for each threshold in 0.0, 0.5 and 1.0 and
    checks that the value read back from ``request.threshold`` equals the
    value passed in.

    Raises:
        AssertionError: if a threshold read back from the request differs
            from the one supplied.
    """
    # The lower bound, a mid value and the upper bound. 0.0 is the only
    # falsy value of the three, so it is checked explicitly.
    thresholds = (0.0, 0.5, 1.0)

    for threshold in thresholds:
        test_data = {
            "query": "test query",
            "threshold": threshold,
            "limit": 10,
        }

        request = SearchRequest(**test_data)
        print(f"Threshold {threshold} test: {request.threshold}")
        assert request.threshold == threshold, (
            f"Expected {threshold}, got {request.threshold}"
        )

    print("All threshold tests passed!")


# Allow running the check directly as a script.
if __name__ == "__main__":
    test_threshold_handling()