fix image search
This commit is contained in:
parent
9f2dd0dfc3
commit
46fd8e6e5e
@ -677,6 +677,7 @@ This modular architecture provides several benefits:
|
||||
- [ ] Move all auth logic to auth module
|
||||
- [ ] Remove bootstrap endpoint
|
||||
- [ ] Move cloud function code to src folder and reuse code with embedding service
|
||||
- [ ] Thumbnail generation
|
||||
|
||||
### Pagination Status ✅
|
||||
- **✅ Images API**: Fully implemented with `skip`, `limit`, `total` parameters
|
||||
|
||||
@ -177,11 +177,11 @@ class ApiClient {
|
||||
}
|
||||
|
||||
// Search API
|
||||
async searchImages(query, similarityThreshold = 0.7, maxResults = 20) {
|
||||
async searchImages(query, similarityThreshold, maxResults = 20) {
|
||||
const searchData = {
|
||||
query,
|
||||
similarity_threshold: similarityThreshold,
|
||||
max_results: maxResults
|
||||
threshold: similarityThreshold,
|
||||
limit: maxResults
|
||||
};
|
||||
|
||||
return this.makeRequest('POST', '/search', searchData);
|
||||
|
||||
@ -31,7 +31,7 @@ async def simple_search_test():
|
||||
|
||||
# Test 1: Generate text embedding
|
||||
logger.info("=== Generating Text Embedding ===")
|
||||
search_query = "blank"
|
||||
search_query = "rectangle"
|
||||
text_embedding = await embedding_service.generate_text_embedding(search_query)
|
||||
|
||||
if text_embedding:
|
||||
@ -72,6 +72,25 @@ async def simple_search_test():
|
||||
|
||||
logger.info(f"Total vectors in collection: {len(all_results)}")
|
||||
|
||||
# Test 4: With team filtering (like the API does)
|
||||
logger.info("\n=== Testing Team Filtering ===")
|
||||
test_team_id = "68330a29472a0704d2f77063" # From server logs
|
||||
filtered_results = vector_db.search_similar_images(
|
||||
query_vector=text_embedding,
|
||||
limit=50,
|
||||
score_threshold=0.0,
|
||||
filter_conditions={"team_id": test_team_id}
|
||||
)
|
||||
|
||||
logger.info(f"Results with team filter ({test_team_id}): {len(filtered_results)}")
|
||||
|
||||
# Show metadata for all results to see team_ids
|
||||
logger.info("\n=== Checking Team IDs in Vector DB ===")
|
||||
for i, result in enumerate(all_results):
|
||||
metadata = result.get('metadata', {})
|
||||
team_id = metadata.get('team_id', 'N/A')
|
||||
logger.info(f" {i+1}. Image ID: {result['image_id']} | Team ID: {team_id}")
|
||||
|
||||
# Show some stats
|
||||
if all_results:
|
||||
scores = [r['score'] for r in all_results]
|
||||
|
||||
@ -151,11 +151,16 @@ async def search_images_advanced(
|
||||
|
||||
try:
|
||||
# Generate embedding for the search query
|
||||
logger.info(f"Generating embedding for query: {search_request.query}")
|
||||
query_embedding = await embedding_service.generate_text_embedding(search_request.query)
|
||||
if not query_embedding:
|
||||
logger.error("Failed to generate search embedding")
|
||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||
|
||||
logger.info(f"Generated embedding with length: {len(query_embedding)}")
|
||||
|
||||
# Search in vector database
|
||||
logger.info(f"Searching vector database with threshold: {search_request.threshold}")
|
||||
search_results = get_vector_db_service().search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
limit=search_request.limit,
|
||||
@ -163,7 +168,10 @@ async def search_images_advanced(
|
||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||
)
|
||||
|
||||
logger.info(f"Vector search returned {len(search_results) if search_results else 0} results")
|
||||
|
||||
if not search_results:
|
||||
logger.info("No search results from vector database, returning empty response")
|
||||
return SearchResponse(
|
||||
query=search_request.query,
|
||||
results=[],
|
||||
@ -176,8 +184,12 @@ async def search_images_advanced(
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||
|
||||
logger.info(f"Extracted {len(image_ids)} image IDs: {image_ids}")
|
||||
|
||||
# Get image metadata from database
|
||||
logger.info("Fetching image metadata from database...")
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
logger.info(f"Retrieved {len(images)} images from database")
|
||||
|
||||
# Apply filters
|
||||
filtered_images = []
|
||||
@ -199,6 +211,8 @@ async def search_images_advanced(
|
||||
|
||||
filtered_images.append(image)
|
||||
|
||||
logger.info(f"After filtering: {len(filtered_images)} images remain")
|
||||
|
||||
# Convert to response format with similarity scores
|
||||
results = []
|
||||
for image in filtered_images:
|
||||
@ -226,6 +240,8 @@ async def search_images_advanced(
|
||||
# Sort by similarity score (highest first)
|
||||
results.sort(key=lambda x: x.similarity_score or 0, reverse=True)
|
||||
|
||||
logger.info(f"Returning {len(results)} results")
|
||||
|
||||
return SearchResponse(
|
||||
query=search_request.query,
|
||||
results=results,
|
||||
@ -236,4 +252,6 @@ async def search_images_advanced(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in advanced search: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail="Advanced search failed")
|
||||
|
||||
@ -137,6 +137,42 @@ class FirestoreImageRepository(FirestoreRepository[ImageModel]):
|
||||
logger.error(f"Error getting images by uploader ID: {e}")
|
||||
raise
|
||||
|
||||
async def get_by_ids(self, image_ids: List[str]) -> List[ImageModel]:
|
||||
"""
|
||||
Get images by a list of IDs
|
||||
|
||||
Args:
|
||||
image_ids: List of image IDs (as strings)
|
||||
|
||||
Returns:
|
||||
List of images matching the IDs
|
||||
"""
|
||||
try:
|
||||
if not image_ids:
|
||||
return []
|
||||
|
||||
# Convert string IDs to ObjectIds
|
||||
object_ids = []
|
||||
for image_id in image_ids:
|
||||
try:
|
||||
object_ids.append(ObjectId(image_id))
|
||||
except Exception as e:
|
||||
logger.warning(f"Invalid ObjectId format: {image_id}, skipping")
|
||||
continue
|
||||
|
||||
if not object_ids:
|
||||
return []
|
||||
|
||||
# Get all images and filter by IDs
|
||||
all_images = await self.get_all()
|
||||
matching_images = [image for image in all_images if image.id in object_ids]
|
||||
|
||||
logger.info(f"Found {len(matching_images)} images out of {len(image_ids)} requested IDs")
|
||||
return matching_images
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting images by IDs: {e}")
|
||||
raise
|
||||
|
||||
async def get_all_with_pagination(
|
||||
self,
|
||||
|
||||
48
test_threshold_fix.py
Normal file
48
test_threshold_fix.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify that similarity threshold is properly handled
|
||||
"""
|
||||
|
||||
import json
|
||||
from src.schemas.search import SearchRequest
|
||||
|
||||
def test_threshold_handling():
|
||||
"""Test that threshold values are properly handled in the schema"""
|
||||
|
||||
# Test with threshold = 0
|
||||
test_data_zero = {
|
||||
"query": "test query",
|
||||
"threshold": 0.0,
|
||||
"limit": 10
|
||||
}
|
||||
|
||||
request_zero = SearchRequest(**test_data_zero)
|
||||
print(f"Threshold 0.0 test: {request_zero.threshold}")
|
||||
assert request_zero.threshold == 0.0, f"Expected 0.0, got {request_zero.threshold}"
|
||||
|
||||
# Test with threshold = 0.5
|
||||
test_data_half = {
|
||||
"query": "test query",
|
||||
"threshold": 0.5,
|
||||
"limit": 10
|
||||
}
|
||||
|
||||
request_half = SearchRequest(**test_data_half)
|
||||
print(f"Threshold 0.5 test: {request_half.threshold}")
|
||||
assert request_half.threshold == 0.5, f"Expected 0.5, got {request_half.threshold}"
|
||||
|
||||
# Test with threshold = 1.0
|
||||
test_data_one = {
|
||||
"query": "test query",
|
||||
"threshold": 1.0,
|
||||
"limit": 10
|
||||
}
|
||||
|
||||
request_one = SearchRequest(**test_data_one)
|
||||
print(f"Threshold 1.0 test: {request_one.threshold}")
|
||||
assert request_one.threshold == 1.0, f"Expected 1.0, got {request_one.threshold}"
|
||||
|
||||
print("All threshold tests passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_threshold_handling()
|
||||
Loading…
x
Reference in New Issue
Block a user