diff --git a/README.md b/README.md index 21de490..42986a2 100644 --- a/README.md +++ b/README.md @@ -677,6 +677,7 @@ This modular architecture provides several benefits: - [ ] Move all auth logic to auth module - [ ] Remove bootstrap endpoint - [ ] Move cloud function code to src folder and reuse code with embedding service +- [ ] Thumbnail generation ### Pagination Status ✅ - **✅ Images API**: Fully implemented with `skip`, `limit`, `total` parameters diff --git a/client/js/api.js b/client/js/api.js index 6a67f66..03b228e 100644 --- a/client/js/api.js +++ b/client/js/api.js @@ -177,11 +177,11 @@ class ApiClient { } // Search API - async searchImages(query, similarityThreshold = 0.7, maxResults = 20) { + async searchImages(query, similarityThreshold, maxResults = 20) { const searchData = { query, - similarity_threshold: similarityThreshold, - max_results: maxResults + threshold: similarityThreshold, + limit: maxResults }; return this.makeRequest('POST', '/search', searchData); diff --git a/simple_search_test.py b/simple_search_test.py index 70b638c..eafc082 100644 --- a/simple_search_test.py +++ b/simple_search_test.py @@ -31,7 +31,7 @@ async def simple_search_test(): # Test 1: Generate text embedding logger.info("=== Generating Text Embedding ===") - search_query = "blank" + search_query = "rectangle" text_embedding = await embedding_service.generate_text_embedding(search_query) if text_embedding: @@ -72,6 +72,25 @@ async def simple_search_test(): logger.info(f"Total vectors in collection: {len(all_results)}") + # Test 4: With team filtering (like the API does) + logger.info("\n=== Testing Team Filtering ===") + test_team_id = "68330a29472a0704d2f77063" # From server logs + filtered_results = vector_db.search_similar_images( + query_vector=text_embedding, + limit=50, + score_threshold=0.0, + filter_conditions={"team_id": test_team_id} + ) + + logger.info(f"Results with team filter ({test_team_id}): {len(filtered_results)}") + + # Show metadata for all results to see team_ids + logger.info("\n=== Checking Team IDs in Vector DB ===") + for i, result in enumerate(all_results): + metadata = result.get('metadata', {}) + team_id = metadata.get('team_id', 'N/A') + logger.info(f" {i+1}. Image ID: {result['image_id']} | Team ID: {team_id}") + # Show some stats if all_results: scores = [r['score'] for r in all_results] diff --git a/src/api/v1/search.py b/src/api/v1/search.py index f1499c7..cdc5b84 100644 --- a/src/api/v1/search.py +++ b/src/api/v1/search.py @@ -151,11 +151,16 @@ async def search_images_advanced( try: # Generate embedding for the search query + logger.info(f"Generating embedding for query: {search_request.query}") query_embedding = await embedding_service.generate_text_embedding(search_request.query) if not query_embedding: + logger.error("Failed to generate search embedding") raise HTTPException(status_code=400, detail="Failed to generate search embedding") + logger.info(f"Generated embedding with length: {len(query_embedding)}") + # Search in vector database + logger.info(f"Searching vector database with threshold: {search_request.threshold}") search_results = get_vector_db_service().search_similar_images( query_vector=query_embedding, limit=search_request.limit, @@ -163,7 +168,10 @@ async def search_images_advanced( filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None ) + logger.info(f"Vector search returned {len(search_results) if search_results else 0} results") + if not search_results: + logger.info("No search results from vector database, returning empty response") return SearchResponse( query=search_request.query, results=[], @@ -176,8 +184,12 @@ async def search_images_advanced( image_ids = [result['image_id'] for result in search_results if result['image_id']] scores = {result['image_id']: result['score'] for result in search_results if result['image_id']} + logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}") + # Get image metadata from database + logger.info("Fetching image metadata from database...") images = await image_repository.get_by_ids(image_ids) + logger.info(f"Retrieved {len(images)} images from database") # Apply filters filtered_images = [] @@ -199,6 +211,8 @@ async def search_images_advanced( filtered_images.append(image) + logger.info(f"After filtering: {len(filtered_images)} images remain") + # Convert to response format with similarity scores results = [] for image in filtered_images: @@ -226,6 +240,8 @@ async def search_images_advanced( # Sort by similarity score (highest first) results.sort(key=lambda x: x.similarity_score or 0, reverse=True) + logger.info(f"Returning {len(results)} results") + return SearchResponse( query=search_request.query, results=results, @@ -236,4 +252,6 @@ async def search_images_advanced( except Exception as e: logger.error(f"Error in advanced search: {e}") + import traceback + logger.error(f"Traceback: {traceback.format_exc()}") raise HTTPException(status_code=500, detail="Advanced search failed") diff --git a/src/db/repositories/firestore_image_repository.py b/src/db/repositories/firestore_image_repository.py index 95e479c..ec5de31 100644 --- a/src/db/repositories/firestore_image_repository.py +++ b/src/db/repositories/firestore_image_repository.py @@ -137,6 +137,42 @@ class FirestoreImageRepository(FirestoreRepository[ImageModel]): logger.error(f"Error getting images by uploader ID: {e}") raise + async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]: + """ + Get images by a list of IDs + + Args: + image_ids: List of image IDs (as strings) + + Returns: + List of images matching the IDs + """ + try: + if not image_ids: + return [] + + # Convert string IDs to ObjectIds + object_ids = [] + for image_id in image_ids: + try: + object_ids.append(ObjectId(image_id)) + except Exception as e: + logger.warning(f"Invalid ObjectId format: {image_id}, skipping") + continue + + if not object_ids: + return [] + + # Get all images and filter by IDs + all_images = await self.get_all() + matching_images = [image for image in all_images if image.id in object_ids] + + logger.info(f"Found {len(matching_images)} images out of {len(image_ids)} requested IDs") + return matching_images + + except Exception as e: + logger.error(f"Error getting images by IDs: {e}") + raise async def get_all_with_pagination( self, diff --git a/test_threshold_fix.py b/test_threshold_fix.py new file mode 100644 index 0000000..506b9bc --- /dev/null +++ b/test_threshold_fix.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +""" +Test script to verify that similarity threshold is properly handled +""" + +import json +from src.schemas.search import SearchRequest + +def test_threshold_handling(): + """Test that threshold values are properly handled in the schema""" + + # Test with threshold = 0 + test_data_zero = { + "query": "test query", + "threshold": 0.0, + "limit": 10 + } + + request_zero = SearchRequest(**test_data_zero) + print(f"Threshold 0.0 test: {request_zero.threshold}") + assert request_zero.threshold == 0.0, f"Expected 0.0, got {request_zero.threshold}" + + # Test with threshold = 0.5 + test_data_half = { + "query": "test query", + "threshold": 0.5, + "limit": 10 + } + + request_half = SearchRequest(**test_data_half) + print(f"Threshold 0.5 test: {request_half.threshold}") + assert request_half.threshold == 0.5, f"Expected 0.5, got {request_half.threshold}" + + # Test with threshold = 1.0 + test_data_one = { + "query": "test query", + "threshold": 1.0, + "limit": 10 + } + + request_one = SearchRequest(**test_data_one) + print(f"Threshold 1.0 test: {request_one.threshold}") + assert request_one.threshold == 1.0, f"Expected 1.0, got {request_one.threshold}" + + print("All threshold tests passed!") + +if __name__ == "__main__": + test_threshold_handling() \ No newline at end of file