fix image search
This commit is contained in:
parent
9f2dd0dfc3
commit
46fd8e6e5e
@ -677,6 +677,7 @@ This modular architecture provides several benefits:
|
|||||||
- [ ] Move all auth logic to auth module
|
- [ ] Move all auth logic to auth module
|
||||||
- [ ] Remove bootstrap endpoint
|
- [ ] Remove bootstrap endpoint
|
||||||
- [ ] Move cloud function code to src folder and reuse code with embedding service
|
- [ ] Move cloud function code to src folder and reuse code with embedding service
|
||||||
|
- [ ] Thumbnail generation
|
||||||
|
|
||||||
### Pagination Status ✅
|
### Pagination Status ✅
|
||||||
- **✅ Images API**: Fully implemented with `skip`, `limit`, `total` parameters
|
- **✅ Images API**: Fully implemented with `skip`, `limit`, `total` parameters
|
||||||
|
|||||||
@ -177,11 +177,11 @@ class ApiClient {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Search API
|
// Search API
|
||||||
async searchImages(query, similarityThreshold = 0.7, maxResults = 20) {
|
async searchImages(query, similarityThreshold, maxResults = 20) {
|
||||||
const searchData = {
|
const searchData = {
|
||||||
query,
|
query,
|
||||||
similarity_threshold: similarityThreshold,
|
threshold: similarityThreshold,
|
||||||
max_results: maxResults
|
limit: maxResults
|
||||||
};
|
};
|
||||||
|
|
||||||
return this.makeRequest('POST', '/search', searchData);
|
return this.makeRequest('POST', '/search', searchData);
|
||||||
|
|||||||
@ -31,7 +31,7 @@ async def simple_search_test():
|
|||||||
|
|
||||||
# Test 1: Generate text embedding
|
# Test 1: Generate text embedding
|
||||||
logger.info("=== Generating Text Embedding ===")
|
logger.info("=== Generating Text Embedding ===")
|
||||||
search_query = "blank"
|
search_query = "rectangle"
|
||||||
text_embedding = await embedding_service.generate_text_embedding(search_query)
|
text_embedding = await embedding_service.generate_text_embedding(search_query)
|
||||||
|
|
||||||
if text_embedding:
|
if text_embedding:
|
||||||
@ -72,6 +72,25 @@ async def simple_search_test():
|
|||||||
|
|
||||||
logger.info(f"Total vectors in collection: {len(all_results)}")
|
logger.info(f"Total vectors in collection: {len(all_results)}")
|
||||||
|
|
||||||
|
# Test 4: With team filtering (like the API does)
|
||||||
|
logger.info("\n=== Testing Team Filtering ===")
|
||||||
|
test_team_id = "68330a29472a0704d2f77063" # From server logs
|
||||||
|
filtered_results = vector_db.search_similar_images(
|
||||||
|
query_vector=text_embedding,
|
||||||
|
limit=50,
|
||||||
|
score_threshold=0.0,
|
||||||
|
filter_conditions={"team_id": test_team_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Results with team filter ({test_team_id}): {len(filtered_results)}")
|
||||||
|
|
||||||
|
# Show metadata for all results to see team_ids
|
||||||
|
logger.info("\n=== Checking Team IDs in Vector DB ===")
|
||||||
|
for i, result in enumerate(all_results):
|
||||||
|
metadata = result.get('metadata', {})
|
||||||
|
team_id = metadata.get('team_id', 'N/A')
|
||||||
|
logger.info(f" {i+1}. Image ID: {result['image_id']} | Team ID: {team_id}")
|
||||||
|
|
||||||
# Show some stats
|
# Show some stats
|
||||||
if all_results:
|
if all_results:
|
||||||
scores = [r['score'] for r in all_results]
|
scores = [r['score'] for r in all_results]
|
||||||
|
|||||||
@ -151,11 +151,16 @@ async def search_images_advanced(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Generate embedding for the search query
|
# Generate embedding for the search query
|
||||||
|
logger.info(f"Generating embedding for query: {search_request.query}")
|
||||||
query_embedding = await embedding_service.generate_text_embedding(search_request.query)
|
query_embedding = await embedding_service.generate_text_embedding(search_request.query)
|
||||||
if not query_embedding:
|
if not query_embedding:
|
||||||
|
logger.error("Failed to generate search embedding")
|
||||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||||
|
|
||||||
|
logger.info(f"Generated embedding with length: {len(query_embedding)}")
|
||||||
|
|
||||||
# Search in vector database
|
# Search in vector database
|
||||||
|
logger.info(f"Searching vector database with threshold: {search_request.threshold}")
|
||||||
search_results = get_vector_db_service().search_similar_images(
|
search_results = get_vector_db_service().search_similar_images(
|
||||||
query_vector=query_embedding,
|
query_vector=query_embedding,
|
||||||
limit=search_request.limit,
|
limit=search_request.limit,
|
||||||
@ -163,7 +168,10 @@ async def search_images_advanced(
|
|||||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
|
||||||
|
|
||||||
if not search_results:
|
if not search_results:
|
||||||
|
logger.info("No search results from vector database, returning empty response")
|
||||||
return SearchResponse(
|
return SearchResponse(
|
||||||
query=search_request.query,
|
query=search_request.query,
|
||||||
results=[],
|
results=[],
|
||||||
@ -176,8 +184,12 @@ async def search_images_advanced(
|
|||||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||||
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||||
|
|
||||||
|
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
|
||||||
|
|
||||||
# Get image metadata from database
|
# Get image metadata from database
|
||||||
|
logger.info("Fetching image metadata from database...")
|
||||||
images = await image_repository.get_by_ids(image_ids)
|
images = await image_repository.get_by_ids(image_ids)
|
||||||
|
logger.info(f"Retrieved {len(images)} images from database")
|
||||||
|
|
||||||
# Apply filters
|
# Apply filters
|
||||||
filtered_images = []
|
filtered_images = []
|
||||||
@ -199,6 +211,8 @@ async def search_images_advanced(
|
|||||||
|
|
||||||
filtered_images.append(image)
|
filtered_images.append(image)
|
||||||
|
|
||||||
|
logger.info(f"After filtering: {len(filtered_images)} images remain")
|
||||||
|
|
||||||
# Convert to response format with similarity scores
|
# Convert to response format with similarity scores
|
||||||
results = []
|
results = []
|
||||||
for image in filtered_images:
|
for image in filtered_images:
|
||||||
@ -226,6 +240,8 @@ async def search_images_advanced(
|
|||||||
# Sort by similarity score (highest first)
|
# Sort by similarity score (highest first)
|
||||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||||
|
|
||||||
|
logger.info(f"Returning {len(results)} results")
|
||||||
|
|
||||||
return SearchResponse(
|
return SearchResponse(
|
||||||
query=search_request.query,
|
query=search_request.query,
|
||||||
results=results,
|
results=results,
|
||||||
@ -236,4 +252,6 @@ async def search_images_advanced(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in advanced search: {e}")
|
logger.error(f"Error in advanced search: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
raise HTTPException(status_code=500, detail="Advanced search failed")
|
raise HTTPException(status_code=500, detail="Advanced search failed")
|
||||||
|
|||||||
@ -137,6 +137,42 @@ class FirestoreImageRepository(FirestoreRepository[ImageModel]):
|
|||||||
logger.error(f"Error getting images by uploader ID: {e}")
|
logger.error(f"Error getting images by uploader ID: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
|
||||||
|
"""
|
||||||
|
Get images by a list of IDs
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_ids: List of image IDs (as strings)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of images matching the IDs
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if not image_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Convert string IDs to ObjectIds
|
||||||
|
object_ids = []
|
||||||
|
for image_id in image_ids:
|
||||||
|
try:
|
||||||
|
object_ids.append(ObjectId(image_id))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Invalid ObjectId format: {image_id}, skipping")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not object_ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get all images and filter by IDs
|
||||||
|
all_images = await self.get_all()
|
||||||
|
matching_images = [image for image in all_images if image.id in object_ids]
|
||||||
|
|
||||||
|
logger.info(f"Found {len(matching_images)} images out of {len(image_ids)} requested IDs")
|
||||||
|
return matching_images
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting images by IDs: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
async def get_all_with_pagination(
|
async def get_all_with_pagination(
|
||||||
self,
|
self,
|
||||||
|
|||||||
48
test_threshold_fix.py
Normal file
48
test_threshold_fix.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test script to verify that similarity threshold is properly handled
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
from src.schemas.search import SearchRequest
|
||||||
|
|
||||||
|
def test_threshold_handling():
|
||||||
|
"""Test that threshold values are properly handled in the schema"""
|
||||||
|
|
||||||
|
# Test with threshold = 0
|
||||||
|
test_data_zero = {
|
||||||
|
"query": "test query",
|
||||||
|
"threshold": 0.0,
|
||||||
|
"limit": 10
|
||||||
|
}
|
||||||
|
|
||||||
|
request_zero = SearchRequest(**test_data_zero)
|
||||||
|
print(f"Threshold 0.0 test: {request_zero.threshold}")
|
||||||
|
assert request_zero.threshold == 0.0, f"Expected 0.0, got {request_zero.threshold}"
|
||||||
|
|
||||||
|
# Test with threshold = 0.5
|
||||||
|
test_data_half = {
|
||||||
|
"query": "test query",
|
||||||
|
"threshold": 0.5,
|
||||||
|
"limit": 10
|
||||||
|
}
|
||||||
|
|
||||||
|
request_half = SearchRequest(**test_data_half)
|
||||||
|
print(f"Threshold 0.5 test: {request_half.threshold}")
|
||||||
|
assert request_half.threshold == 0.5, f"Expected 0.5, got {request_half.threshold}"
|
||||||
|
|
||||||
|
# Test with threshold = 1.0
|
||||||
|
test_data_one = {
|
||||||
|
"query": "test query",
|
||||||
|
"threshold": 1.0,
|
||||||
|
"limit": 10
|
||||||
|
}
|
||||||
|
|
||||||
|
request_one = SearchRequest(**test_data_one)
|
||||||
|
print(f"Threshold 1.0 test: {request_one.threshold}")
|
||||||
|
assert request_one.threshold == 1.0, f"Expected 1.0, got {request_one.threshold}"
|
||||||
|
|
||||||
|
print("All threshold tests passed!")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_threshold_handling()
|
||||||
Loading…
x
Reference in New Issue
Block a user