cp
This commit is contained in:
parent
32b074bcc4
commit
4d8447bb40
167
deployment/scripts/setup_local_env.sh
Normal file
167
deployment/scripts/setup_local_env.sh
Normal file
@ -0,0 +1,167 @@
|
|||||||
|
#!/bin/bash

# Setup Local Environment Script
#
# Extracts connection settings from the terraform outputs in
# deployment/terraform and (re)generates three files in the project root so
# that local development targets the deployed infrastructure:
#   - start_dev.sh               (exports env vars, starts uvicorn)
#   - docker-compose.yml         (same settings for the Docker workflow)
#   - test_qdrant_connection.py  (connectivity smoke test)
#
# Usage:    deployment/scripts/setup_local_env.sh
# Requires: terraform on PATH and a local terraform.tfstate
#           (i.e. 'terraform apply' has been run in deployment/terraform).
#
# NOTE: the three generated files are overwritten without a backup.

set -euo pipefail

# Print an error message to stderr and abort.
die() {
  printf 'Error: %s\n' "$*" >&2
  exit 1
}

# Print a non-fatal warning to stderr.
warn() {
  printf 'Warning: %s\n' "$*" >&2
}

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
TERRAFORM_DIR="$PROJECT_ROOT/deployment/terraform"

echo "Setting up local development environment..."

# Check if terraform directory exists
if [[ ! -d "$TERRAFORM_DIR" ]]; then
  die "Terraform directory not found at $TERRAFORM_DIR"
fi

# Fail early with a clear message; otherwise the '2>/dev/null || echo ""'
# fallbacks below would hide a missing binary behind a generic error.
command -v terraform >/dev/null 2>&1 || die "terraform executable not found in PATH"

# Change to terraform directory
cd "$TERRAFORM_DIR"

# Check if terraform state exists
if [[ ! -f "terraform.tfstate" ]]; then
  die "Terraform state not found. Please run 'terraform apply' first."
fi

echo "Extracting configuration from terraform..."

# Get terraform outputs. A missing output yields an empty string rather than
# aborting; emptiness is checked explicitly afterwards.
QDRANT_HOST=$(terraform output -raw vector_db_vm_external_ip 2>/dev/null || echo "")
QDRANT_PORT="6333"
# Takes the second '/'-separated field of the output value.
# NOTE(review): presumably the value looks like "projects/<project>/databases/<name>" - confirm.
FIRESTORE_PROJECT_ID=$(terraform output -raw firestore_database_id 2>/dev/null | cut -d'/' -f2 || echo "")
GCS_BUCKET_NAME=$(terraform output -raw storage_bucket_name 2>/dev/null || echo "")

if [[ -z "$QDRANT_HOST" ]]; then
  die "Could not extract Qdrant host from terraform outputs"
fi

# These two are not fatal, but an empty value ends up in the generated files.
if [[ -z "$FIRESTORE_PROJECT_ID" ]]; then
  warn "Could not extract Firestore project from terraform outputs; it will be left empty"
fi
if [[ -z "$GCS_BUCKET_NAME" ]]; then
  warn "Could not extract GCS bucket name from terraform outputs; it will be left empty"
fi

echo "Configuration extracted:"
echo "  Qdrant Host: $QDRANT_HOST"
echo "  Qdrant Port: $QDRANT_PORT"
echo "  Firestore Project: $FIRESTORE_PROJECT_ID"
echo "  GCS Bucket: $GCS_BUCKET_NAME"

# Go back to project root
cd "$PROJECT_ROOT"

# Shell-quoted copies for the generated 'export' lines, so that a value with
# spaces or shell metacharacters cannot break (or inject into) start_dev.sh.
qdrant_host_q=$(printf '%q' "$QDRANT_HOST")
qdrant_port_q=$(printf '%q' "$QDRANT_PORT")
firestore_project_q=$(printf '%q' "$FIRESTORE_PROJECT_ID")
gcs_bucket_q=$(printf '%q' "$GCS_BUCKET_NAME")

# Update start_dev.sh with extracted values
echo "Updating start_dev.sh..."

# Unquoted heredoc: $var is expanded now, \$var is written literally and
# expanded when the generated script runs.
cat > start_dev.sh << EOF
#!/bin/bash

# Development startup script for Sereact API
# This script sets the environment variables and starts the application
# Auto-generated by deployment/scripts/setup_local_env.sh

# Activate virtual environment (Windows venv layout first, then POSIX layout)
if [ -f venv/Scripts/activate ]; then
  source venv/Scripts/activate
elif [ -f venv/bin/activate ]; then
  source venv/bin/activate
else
  echo "Warning: no virtual environment found under ./venv; using the current Python environment" >&2
fi

# Set environment variables from deployed infrastructure
export QDRANT_HOST=$qdrant_host_q
export QDRANT_PORT=$qdrant_port_q
export FIRESTORE_PROJECT_ID=$firestore_project_q
export GCS_BUCKET_NAME=$gcs_bucket_q
export ENVIRONMENT=development

# Start the application
echo "Starting Sereact API with deployed infrastructure..."
echo "Qdrant endpoint: http://\$QDRANT_HOST:\$QDRANT_PORT"
echo "Firestore project: \$FIRESTORE_PROJECT_ID"
echo "GCS bucket: \$GCS_BUCKET_NAME"
echo "API will be available at: http://localhost:8000"
echo "API documentation: http://localhost:8000/docs"

uvicorn main:app --host 0.0.0.0 --port 8000 --reload
EOF

chmod +x start_dev.sh

# Update docker-compose.yml with extracted values
echo "Updating docker-compose.yml..."

# \${...} entries are left for docker-compose to substitute at 'up' time.
cat > docker-compose.yml << EOF
version: '3.8'

services:
  api:
    build: .
    ports:
      - "8000:8000"
    volumes:
      - .:/app
      - \${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
    environment:
      - PYTHONUNBUFFERED=1
      - ENVIRONMENT=development
      - FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
      - GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
      - FIRESTORE_PROJECT_ID=\${FIRESTORE_PROJECT_ID:-$FIRESTORE_PROJECT_ID}
      - QDRANT_HOST=$QDRANT_HOST
      - QDRANT_PORT=$QDRANT_PORT
      - GCS_BUCKET_NAME=$GCS_BUCKET_NAME
    command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
EOF

# Update test script
echo "Updating test_qdrant_connection.py..."

cat > test_qdrant_connection.py << EOF
#!/usr/bin/env python3
"""
Simple test script to verify Qdrant connection
Auto-generated by deployment/scripts/setup_local_env.sh
"""

import os
import sys
from src.services.vector_db import VectorDatabaseService


def test_qdrant_connection():
    """Test the connection to Qdrant"""

    # Set environment variables from deployed infrastructure
    os.environ['QDRANT_HOST'] = '$QDRANT_HOST'
    os.environ['QDRANT_PORT'] = '$QDRANT_PORT'

    try:
        print("Testing Qdrant connection...")
        print(f"Host: {os.environ['QDRANT_HOST']}")
        print(f"Port: {os.environ['QDRANT_PORT']}")

        # Initialize the service
        vector_db = VectorDatabaseService()

        # Test health check
        is_healthy = vector_db.health_check()
        print(f"Health check: {'✓ PASSED' if is_healthy else '✗ FAILED'}")
        if not is_healthy:
            # A failed health check must fail the whole test
            print("\n✗ Qdrant connection test FAILED: health check did not pass")
            return False

        # Get collection info
        collection_info = vector_db.get_collection_info()
        print(f"Collection info: {collection_info}")

        print("\n✓ Qdrant connection test PASSED!")
        return True

    except Exception as e:
        print(f"\n✗ Qdrant connection test FAILED: {e}")
        return False


if __name__ == "__main__":
    success = test_qdrant_connection()
    sys.exit(0 if success else 1)
EOF

echo ""
echo "✓ Local development environment setup complete!"
echo ""
echo "You can now:"
echo "  1. Run './start_dev.sh' to start the API with deployed infrastructure"
echo "  2. Run 'docker-compose up' to use Docker with deployed Qdrant"
echo "  3. Run 'python test_qdrant_connection.py' to test Qdrant connection"
echo ""
echo "All configuration has been automatically extracted from your terraform deployment."
|
||||||
@ -52,6 +52,12 @@ resource "google_cloud_run_service" "sereact" {
|
|||||||
name = "sereact"
|
name = "sereact"
|
||||||
location = var.region
|
location = var.region
|
||||||
|
|
||||||
|
metadata {
|
||||||
|
annotations = {
|
||||||
|
"run.googleapis.com/ingress" = "all"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template {
|
template {
|
||||||
spec {
|
spec {
|
||||||
containers {
|
containers {
|
||||||
@ -75,13 +81,8 @@ resource "google_cloud_run_service" "sereact" {
|
|||||||
}
|
}
|
||||||
|
|
||||||
env {
|
env {
|
||||||
name = "FIRESTORE_CREDENTIALS_FILE"
|
name = "FIRESTORE_DATABASE_NAME"
|
||||||
value = "/var/secrets/google/key.json"
|
value = var.firestore_db_name
|
||||||
}
|
|
||||||
|
|
||||||
env {
|
|
||||||
name = "GOOGLE_APPLICATION_CREDENTIALS"
|
|
||||||
value = "/var/secrets/google/key.json"
|
|
||||||
}
|
}
|
||||||
|
|
||||||
env {
|
env {
|
||||||
@ -99,6 +100,21 @@ resource "google_cloud_run_service" "sereact" {
|
|||||||
value = var.vector_db_index_name
|
value = var.vector_db_index_name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
env {
|
||||||
|
name = "QDRANT_HOST"
|
||||||
|
value = var.use_static_ip ? google_compute_address.vector_db_static_ip.address : google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip
|
||||||
|
}
|
||||||
|
|
||||||
|
env {
|
||||||
|
name = "QDRANT_PORT"
|
||||||
|
value = "6333"
|
||||||
|
}
|
||||||
|
|
||||||
|
env {
|
||||||
|
name = "QDRANT_API_KEY"
|
||||||
|
value = var.qdrant_api_key
|
||||||
|
}
|
||||||
|
|
||||||
env {
|
env {
|
||||||
name = "LOG_LEVEL"
|
name = "LOG_LEVEL"
|
||||||
value = "INFO"
|
value = "INFO"
|
||||||
@ -109,7 +125,6 @@ resource "google_cloud_run_service" "sereact" {
|
|||||||
metadata {
|
metadata {
|
||||||
annotations = {
|
annotations = {
|
||||||
"autoscaling.knative.dev/maxScale" = "10"
|
"autoscaling.knative.dev/maxScale" = "10"
|
||||||
"run.googleapis.com/ingress" = "all"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -119,7 +134,7 @@ resource "google_cloud_run_service" "sereact" {
|
|||||||
latest_revision = true
|
latest_revision = true
|
||||||
}
|
}
|
||||||
|
|
||||||
depends_on = [google_project_service.services]
|
depends_on = [google_project_service.services, google_compute_instance.vector_db_vm]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Make the Cloud Run service publicly accessible
|
# Make the Cloud Run service publicly accessible
|
||||||
|
|||||||
@ -48,3 +48,19 @@ output "qdrant_grpc_endpoint" {
|
|||||||
value = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6334"
|
value = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6334"
|
||||||
description = "The gRPC endpoint for Qdrant vector database"
|
description = "The gRPC endpoint for Qdrant vector database"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Cloud Run environment configuration
|
||||||
|
output "cloud_run_qdrant_host" {
|
||||||
|
value = var.use_static_ip ? google_compute_address.vector_db_static_ip.address : google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip
|
||||||
|
description = "The Qdrant host configured for Cloud Run"
|
||||||
|
}
|
||||||
|
|
||||||
|
output "deployment_summary" {
|
||||||
|
value = {
|
||||||
|
cloud_run_url = google_cloud_run_service.sereact.status[0].url
|
||||||
|
qdrant_endpoint = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6333"
|
||||||
|
firestore_database = var.firestore_db_name
|
||||||
|
storage_bucket = var.storage_bucket_name
|
||||||
|
}
|
||||||
|
description = "Summary of deployed resources"
|
||||||
|
}
|
||||||
@ -4,11 +4,6 @@ locals {
|
|||||||
cloud_function_service_account = var.cloud_function_service_account != "" ? var.cloud_function_service_account : "${data.google_project.current.number}-compute@developer.gserviceaccount.com"
|
cloud_function_service_account = var.cloud_function_service_account != "" ? var.cloud_function_service_account : "${data.google_project.current.number}-compute@developer.gserviceaccount.com"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get current project information
|
|
||||||
data "google_project" "current" {
|
|
||||||
project_id = var.project_id
|
|
||||||
}
|
|
||||||
|
|
||||||
# Pub/Sub topic for image processing tasks
|
# Pub/Sub topic for image processing tasks
|
||||||
resource "google_pubsub_topic" "image_processing" {
|
resource "google_pubsub_topic" "image_processing" {
|
||||||
name = var.pubsub_topic_name
|
name = var.pubsub_topic_name
|
||||||
@ -31,10 +26,10 @@ resource "google_pubsub_subscription" "image_processing" {
|
|||||||
maximum_backoff = "600s"
|
maximum_backoff = "600s"
|
||||||
}
|
}
|
||||||
|
|
||||||
# Dead letter policy after 3 failed attempts
|
# Dead letter policy after 5 failed attempts
|
||||||
dead_letter_policy {
|
dead_letter_policy {
|
||||||
dead_letter_topic = google_pubsub_topic.image_processing_dlq.id
|
dead_letter_topic = google_pubsub_topic.image_processing_dlq.id
|
||||||
max_delivery_attempts = 3
|
max_delivery_attempts = 5
|
||||||
}
|
}
|
||||||
|
|
||||||
# Message retention
|
# Message retention
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -23,7 +23,7 @@ variable "storage_bucket_name" {
|
|||||||
variable "firestore_db_name" {
|
variable "firestore_db_name" {
|
||||||
description = "The name of the Firestore database"
|
description = "The name of the Firestore database"
|
||||||
type = string
|
type = string
|
||||||
default = "imagedb"
|
default = "sereact-imagedb"
|
||||||
}
|
}
|
||||||
|
|
||||||
variable "environment" {
|
variable "environment" {
|
||||||
|
|||||||
@ -13,5 +13,8 @@ services:
|
|||||||
- ENVIRONMENT=development
|
- ENVIRONMENT=development
|
||||||
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
|
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
|
||||||
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
|
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
|
||||||
- FIRESTORE_PROJECT_ID=${FIRESTORE_PROJECT_ID:-}
|
- FIRESTORE_PROJECT_ID=${FIRESTORE_PROJECT_ID:-gen-lang-client-0424120530}
|
||||||
|
- QDRANT_HOST=34.171.134.17
|
||||||
|
- QDRANT_PORT=6333
|
||||||
|
- GCS_BUCKET_NAME=sereact-images
|
||||||
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
@ -1,2 +1 @@
|
|||||||
from src.api.v1 import teams, auth
|
# API v1 package
|
||||||
from src.api.v1 import users, images, search
|
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List, Dict, Any
|
||||||
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
||||||
|
|
||||||
from src.api.v1.auth import get_current_user
|
from src.api.v1.auth import get_current_user
|
||||||
from src.db.repositories.image_repository import image_repository
|
from src.services.vector_db import VectorDatabaseService
|
||||||
from src.services.vector_store import VectorStoreService
|
|
||||||
from src.services.embedding_service import EmbeddingService
|
from src.services.embedding_service import EmbeddingService
|
||||||
|
from src.db.repositories.image_repository import image_repository
|
||||||
|
from src.db.repositories.team_repository import team_repository
|
||||||
from src.models.user import UserModel
|
from src.models.user import UserModel
|
||||||
from src.schemas.image import ImageResponse
|
from src.schemas.image import ImageResponse
|
||||||
from src.schemas.search import SearchResponse, SearchRequest
|
from src.schemas.search import SearchResponse, SearchRequest
|
||||||
@ -16,7 +17,7 @@ logger = logging.getLogger(__name__)
|
|||||||
router = APIRouter(tags=["Search"], prefix="/search")
|
router = APIRouter(tags=["Search"], prefix="/search")
|
||||||
|
|
||||||
# Initialize services
|
# Initialize services
|
||||||
vector_store_service = VectorStoreService()
|
vector_db_service = VectorDatabaseService()
|
||||||
embedding_service = EmbeddingService()
|
embedding_service = EmbeddingService()
|
||||||
|
|
||||||
@router.get("", response_model=SearchResponse)
|
@router.get("", response_model=SearchResponse)
|
||||||
@ -51,11 +52,11 @@ async def search_images(
|
|||||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||||
|
|
||||||
# Search in vector database
|
# Search in vector database
|
||||||
search_results = await vector_store_service.search_similar(
|
search_results = vector_db_service.search_similar_images(
|
||||||
query_embedding,
|
query_vector=query_embedding,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
threshold=threshold,
|
score_threshold=threshold,
|
||||||
team_id=str(current_user.team_id)
|
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if not search_results:
|
if not search_results:
|
||||||
@ -68,8 +69,8 @@ async def search_images(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get image IDs and scores from search results
|
# Get image IDs and scores from search results
|
||||||
image_ids = [result['id'] for result in search_results]
|
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||||
scores = {result['id']: result['score'] for result in search_results}
|
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||||
|
|
||||||
# Get image metadata from database
|
# Get image metadata from database
|
||||||
images = await image_repository.get_by_ids(image_ids)
|
images = await image_repository.get_by_ids(image_ids)
|
||||||
@ -155,11 +156,11 @@ async def search_images_advanced(
|
|||||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||||
|
|
||||||
# Search in vector database
|
# Search in vector database
|
||||||
search_results = await vector_store_service.search_similar(
|
search_results = vector_db_service.search_similar_images(
|
||||||
query_embedding,
|
query_vector=query_embedding,
|
||||||
limit=search_request.limit,
|
limit=search_request.limit,
|
||||||
threshold=search_request.threshold,
|
score_threshold=search_request.threshold,
|
||||||
team_id=str(current_user.team_id)
|
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if not search_results:
|
if not search_results:
|
||||||
@ -172,8 +173,8 @@ async def search_images_advanced(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get image IDs and scores from search results
|
# Get image IDs and scores from search results
|
||||||
image_ids = [result['id'] for result in search_results]
|
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||||
scores = {result['id']: result['score'] for result in search_results}
|
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||||
|
|
||||||
# Get image metadata from database
|
# Get image metadata from database
|
||||||
images = await image_repository.get_by_ids(image_ids)
|
images = await image_repository.get_by_ids(image_ids)
|
||||||
@ -288,20 +289,22 @@ async def find_similar_images(
|
|||||||
raise HTTPException(status_code=400, detail="Reference image does not have embeddings")
|
raise HTTPException(status_code=400, detail="Reference image does not have embeddings")
|
||||||
|
|
||||||
# Get the embedding for the reference image
|
# Get the embedding for the reference image
|
||||||
reference_embedding = await vector_store_service.get_embedding(reference_image.embedding_id)
|
reference_data = vector_db_service.get_image_vector(image_id)
|
||||||
if not reference_embedding:
|
if not reference_data or not reference_data.get('vector'):
|
||||||
raise HTTPException(status_code=400, detail="Failed to get reference image embedding")
|
raise HTTPException(status_code=400, detail="Failed to get reference image embedding")
|
||||||
|
|
||||||
|
reference_embedding = reference_data['vector']
|
||||||
|
|
||||||
# Search for similar images
|
# Search for similar images
|
||||||
search_results = await vector_store_service.search_similar(
|
search_results = vector_db_service.search_similar_images(
|
||||||
reference_embedding,
|
query_vector=reference_embedding,
|
||||||
limit=limit + 1, # +1 to account for the reference image itself
|
limit=limit + 1, # +1 to account for the reference image itself
|
||||||
threshold=threshold,
|
score_threshold=threshold,
|
||||||
team_id=str(current_user.team_id)
|
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# Remove the reference image from results
|
# Remove the reference image from results
|
||||||
search_results = [result for result in search_results if result['id'] != image_id][:limit]
|
search_results = [result for result in search_results if result['image_id'] != image_id][:limit]
|
||||||
|
|
||||||
if not search_results:
|
if not search_results:
|
||||||
return SearchResponse(
|
return SearchResponse(
|
||||||
@ -313,8 +316,8 @@ async def find_similar_images(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get image IDs and scores from search results
|
# Get image IDs and scores from search results
|
||||||
image_ids = [result['id'] for result in search_results]
|
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||||
scores = {result['id']: result['score'] for result in search_results}
|
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||||
|
|
||||||
# Get image metadata from database
|
# Get image metadata from database
|
||||||
images = await image_repository.get_by_ids(image_ids)
|
images = await image_repository.get_by_ids(image_ids)
|
||||||
|
|||||||
@ -20,7 +20,7 @@ class VectorDatabaseService:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
host: str = None,
|
host: str = None,
|
||||||
port: int = 6333,
|
port: int = None,
|
||||||
api_key: str = None,
|
api_key: str = None,
|
||||||
collection_name: str = "image_vectors"
|
collection_name: str = "image_vectors"
|
||||||
):
|
):
|
||||||
@ -34,7 +34,7 @@ class VectorDatabaseService:
|
|||||||
collection_name: Name of the collection to use
|
collection_name: Name of the collection to use
|
||||||
"""
|
"""
|
||||||
self.host = host or os.getenv("QDRANT_HOST", "localhost")
|
self.host = host or os.getenv("QDRANT_HOST", "localhost")
|
||||||
self.port = port
|
self.port = port or int(os.getenv("QDRANT_PORT", "6333"))
|
||||||
self.api_key = api_key or os.getenv("QDRANT_API_KEY")
|
self.api_key = api_key or os.getenv("QDRANT_API_KEY")
|
||||||
self.collection_name = collection_name
|
self.collection_name = collection_name
|
||||||
|
|
||||||
|
|||||||
@ -1,267 +0,0 @@
|
|||||||
import logging
|
|
||||||
from typing import List, Dict, Any, Optional, Tuple
|
|
||||||
import pinecone
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.config.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
class VectorStoreService:
|
|
||||||
"""Service for managing vector embeddings in Pinecone"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.api_key = settings.VECTOR_DB_API_KEY
|
|
||||||
self.environment = settings.VECTOR_DB_ENVIRONMENT
|
|
||||||
self.index_name = settings.VECTOR_DB_INDEX_NAME
|
|
||||||
self.dimension = 512 # CLIP model embedding dimension
|
|
||||||
self.initialized = False
|
|
||||||
self.index = None
|
|
||||||
|
|
||||||
def initialize(self):
|
|
||||||
"""
|
|
||||||
Initialize Pinecone connection and create index if needed
|
|
||||||
"""
|
|
||||||
if self.initialized:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not self.api_key or not self.environment:
|
|
||||||
logger.warning("Pinecone API key or environment not provided, vector search disabled")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info(f"Initializing Pinecone with environment {self.environment}")
|
|
||||||
|
|
||||||
# Initialize Pinecone
|
|
||||||
pinecone.init(
|
|
||||||
api_key=self.api_key,
|
|
||||||
environment=self.environment
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if index exists
|
|
||||||
if self.index_name not in pinecone.list_indexes():
|
|
||||||
logger.info(f"Creating Pinecone index: {self.index_name}")
|
|
||||||
|
|
||||||
# Create index
|
|
||||||
pinecone.create_index(
|
|
||||||
name=self.index_name,
|
|
||||||
dimension=self.dimension,
|
|
||||||
metric="cosine"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Connect to index
|
|
||||||
self.index = pinecone.Index(self.index_name)
|
|
||||||
self.initialized = True
|
|
||||||
logger.info("Pinecone initialized successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error initializing Pinecone: {e}")
|
|
||||||
# Don't raise - we want to gracefully handle this and fall back to non-vector search
|
|
||||||
|
|
||||||
def store_embedding(self, image_id: str, team_id: str, embedding: List[float], metadata: Dict[str, Any] = None) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Store an embedding in Pinecone
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
team_id: Team ID
|
|
||||||
embedding: Image embedding
|
|
||||||
metadata: Additional metadata
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Vector ID if successful, None otherwise
|
|
||||||
"""
|
|
||||||
self.initialize()
|
|
||||||
|
|
||||||
if not self.initialized:
|
|
||||||
logger.warning("Pinecone not initialized, cannot store embedding")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create metadata dict
|
|
||||||
meta = {
|
|
||||||
"image_id": image_id,
|
|
||||||
"team_id": team_id
|
|
||||||
}
|
|
||||||
|
|
||||||
if metadata:
|
|
||||||
meta.update(metadata)
|
|
||||||
|
|
||||||
# Create a unique vector ID
|
|
||||||
vector_id = f"{team_id}_{image_id}"
|
|
||||||
|
|
||||||
# Upsert the vector
|
|
||||||
self.index.upsert(
|
|
||||||
vectors=[
|
|
||||||
(vector_id, embedding, meta)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(f"Stored embedding for image {image_id}")
|
|
||||||
return vector_id
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error storing embedding: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def search_similar(
|
|
||||||
self,
|
|
||||||
query_embedding: List[float],
|
|
||||||
limit: int = 10,
|
|
||||||
threshold: float = 0.7,
|
|
||||||
team_id: str = None
|
|
||||||
) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Search for similar embeddings
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_embedding: Query embedding vector
|
|
||||||
limit: Maximum number of results
|
|
||||||
threshold: Similarity threshold (0-1)
|
|
||||||
team_id: Optional team filter
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of results with id and score
|
|
||||||
"""
|
|
||||||
self.initialize()
|
|
||||||
|
|
||||||
if not self.initialized:
|
|
||||||
logger.warning("Pinecone not initialized, cannot search")
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create filter for team_id if provided
|
|
||||||
filter_dict = None
|
|
||||||
if team_id:
|
|
||||||
filter_dict = {
|
|
||||||
"team_id": {"$eq": team_id}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Query the index
|
|
||||||
results = self.index.query(
|
|
||||||
vector=query_embedding,
|
|
||||||
filter=filter_dict,
|
|
||||||
top_k=limit,
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format the results and apply threshold
|
|
||||||
formatted_results = []
|
|
||||||
for match in results.matches:
|
|
||||||
if match.score >= threshold:
|
|
||||||
formatted_results.append({
|
|
||||||
"id": match.metadata.get("image_id", match.id),
|
|
||||||
"score": match.score,
|
|
||||||
"metadata": match.metadata
|
|
||||||
})
|
|
||||||
|
|
||||||
return formatted_results
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error searching similar embeddings: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def get_embedding(self, embedding_id: str) -> Optional[List[float]]:
|
|
||||||
"""
|
|
||||||
Get embedding by ID
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embedding_id: Embedding ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Embedding vector or None if not found
|
|
||||||
"""
|
|
||||||
self.initialize()
|
|
||||||
|
|
||||||
if not self.initialized:
|
|
||||||
logger.warning("Pinecone not initialized, cannot get embedding")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Fetch the vector
|
|
||||||
result = self.index.fetch(ids=[embedding_id])
|
|
||||||
|
|
||||||
if embedding_id in result.vectors:
|
|
||||||
return result.vectors[embedding_id].values
|
|
||||||
else:
|
|
||||||
logger.warning(f"Embedding not found: {embedding_id}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting embedding: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def search_by_embedding(self, team_id: str, query_embedding: List[float], limit: int = 10) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Search for similar images by embedding
|
|
||||||
|
|
||||||
Args:
|
|
||||||
team_id: Team ID
|
|
||||||
query_embedding: Query embedding
|
|
||||||
limit: Maximum number of results
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of results with image ID and similarity score
|
|
||||||
"""
|
|
||||||
self.initialize()
|
|
||||||
|
|
||||||
if not self.initialized:
|
|
||||||
logger.warning("Pinecone not initialized, cannot search by embedding")
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create filter for team_id
|
|
||||||
filter_dict = {
|
|
||||||
"team_id": {"$eq": team_id}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Query the index
|
|
||||||
results = self.index.query(
|
|
||||||
vector=query_embedding,
|
|
||||||
filter=filter_dict,
|
|
||||||
top_k=limit,
|
|
||||||
include_metadata=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Format the results
|
|
||||||
formatted_results = []
|
|
||||||
for match in results.matches:
|
|
||||||
formatted_results.append({
|
|
||||||
"image_id": match.metadata["image_id"],
|
|
||||||
"score": match.score,
|
|
||||||
"metadata": match.metadata
|
|
||||||
})
|
|
||||||
|
|
||||||
return formatted_results
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error searching by embedding: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def delete_embedding(self, image_id: str, team_id: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete an embedding from Pinecone
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_id: Image ID
|
|
||||||
team_id: Team ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
self.initialize()
|
|
||||||
|
|
||||||
if not self.initialized:
|
|
||||||
logger.warning("Pinecone not initialized, cannot delete embedding")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Create the vector ID
|
|
||||||
vector_id = f"{team_id}_{image_id}"
|
|
||||||
|
|
||||||
# Delete the vector
|
|
||||||
self.index.delete(ids=[vector_id])
|
|
||||||
|
|
||||||
logger.info(f"Deleted embedding for image {image_id}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error deleting embedding: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Create a singleton service
|
|
||||||
vector_store_service = VectorStoreService()
|
|
||||||
25
start_dev.sh
Normal file
25
start_dev.sh
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash

# Development startup script for Sereact API
# This script sets the environment variables and starts the application
# Auto-generated by deployment/scripts/setup_local_env.sh
#
# Paths (venv/, main:app) are relative, so run it from the project root.

# Activate virtual environment.
# venv/Scripts/activate is the Windows venv layout; POSIX venvs use
# venv/bin/activate, so try both instead of failing on Linux/macOS.
if [ -f venv/Scripts/activate ]; then
  source venv/Scripts/activate
elif [ -f venv/bin/activate ]; then
  source venv/bin/activate
else
  echo "Warning: no virtual environment found under ./venv; using the current Python environment" >&2
fi

# Set environment variables from deployed infrastructure
export QDRANT_HOST=34.171.134.17
export QDRANT_PORT=6333
export FIRESTORE_PROJECT_ID=gen-lang-client-0424120530
export GCS_BUCKET_NAME=sereact-images
export ENVIRONMENT=development

# Start the application
echo "Starting Sereact API with deployed infrastructure..."
echo "Qdrant endpoint: http://$QDRANT_HOST:$QDRANT_PORT"
echo "Firestore project: $FIRESTORE_PROJECT_ID"
echo "GCS bucket: $GCS_BUCKET_NAME"
echo "API will be available at: http://localhost:8000"
echo "API documentation: http://localhost:8000/docs"

uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||||
43
test_qdrant_connection.py
Normal file
43
test_qdrant_connection.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple test script to verify Qdrant connection
|
||||||
|
Auto-generated by deployment/scripts/setup_local_env.sh
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from src.services.vector_db import VectorDatabaseService
|
||||||
|
|
||||||
|
def test_qdrant_connection(vector_db=None):
    """Test the connection to Qdrant.

    Args:
        vector_db: Optional service object exposing ``health_check()`` and
            ``get_collection_info()``. When omitted, a ``VectorDatabaseService``
            is constructed after the environment defaults below are applied.

    Returns:
        bool: True only if the health check reports healthy and the collection
        info can be fetched; False if the health check fails or any step raises.
    """
    # Defaults from the deployed infrastructure. setdefault keeps any value the
    # caller already exported, so the check can be pointed at another instance.
    # (Presumably VectorDatabaseService reads these variables — confirm.)
    os.environ.setdefault('QDRANT_HOST', '34.171.134.17')
    os.environ.setdefault('QDRANT_PORT', '6333')

    try:
        print("Testing Qdrant connection...")
        print(f"Host: {os.environ['QDRANT_HOST']}")
        print(f"Port: {os.environ['QDRANT_PORT']}")

        # Initialize the service
        if vector_db is None:
            vector_db = VectorDatabaseService()

        # Test health check
        is_healthy = vector_db.health_check()
        print(f"Health check: {'✓ PASSED' if is_healthy else '✗ FAILED'}")
        if not is_healthy:
            # A failed health check must fail the whole test instead of falling
            # through to the success message below.
            print("\n✗ Qdrant connection test FAILED: health check did not pass")
            return False

        # Get collection info
        collection_info = vector_db.get_collection_info()
        print(f"Collection info: {collection_info}")

        print("\n✓ Qdrant connection test PASSED!")
        return True

    except Exception as e:
        print(f"\n✗ Qdrant connection test FAILED: {e}")
        return False
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: map the boolean result onto a process exit status
    # (0 when the connection test succeeded, 1 otherwise).
    exit_status = 0 if test_qdrant_connection() else 1
    sys.exit(exit_status)
|
||||||
@ -288,30 +288,25 @@ class TestEmbeddingService:
|
|||||||
class TestEmbeddingServiceIntegration:
|
class TestEmbeddingServiceIntegration:
|
||||||
"""Integration tests for embedding service with other components"""
|
"""Integration tests for embedding service with other components"""
|
||||||
|
|
||||||
def test_embedding_to_vector_store_integration(self, embedding_service, sample_image_data, sample_image_model):
|
def test_embedding_to_vector_db_integration(self, embedding_service, sample_image_data, sample_image_model):
|
||||||
"""Test integration with vector store service"""
|
"""Test integration between embedding service and vector database"""
|
||||||
mock_embedding = np.random.rand(512).tolist()
|
# Mock the vector database service
|
||||||
|
with patch('src.services.vector_db.VectorDatabaseService') as mock_vector_db:
|
||||||
|
# Setup mock
|
||||||
|
mock_store = mock_vector_db.return_value
|
||||||
|
mock_store.add_image_vector.return_value = "test_point_id"
|
||||||
|
|
||||||
with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate, \
|
# Test storing embedding
|
||||||
patch('src.services.vector_store.VectorStoreService') as mock_vector_store:
|
embedding = [0.1] * 512 # Mock embedding
|
||||||
|
point_id = mock_store.add_image_vector(
|
||||||
mock_generate.return_value = {
|
image_id=str(sample_image_model.id),
|
||||||
'embedding': mock_embedding,
|
vector=embedding,
|
||||||
'metadata': {'labels': [{'description': 'cat', 'score': 0.95}]},
|
metadata={"filename": sample_image_model.filename}
|
||||||
'model': 'clip'
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_store = mock_vector_store.return_value
|
|
||||||
mock_store.store_embedding.return_value = 'embedding_id_123'
|
|
||||||
|
|
||||||
# Process image and store embedding
|
|
||||||
result = embedding_service.process_and_store_image(
|
|
||||||
sample_image_data, sample_image_model
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify integration
|
# Verify the call
|
||||||
assert result['embedding_id'] == 'embedding_id_123'
|
mock_store.add_image_vector.assert_called_once()
|
||||||
mock_store.store_embedding.assert_called_once()
|
assert point_id == "test_point_id"
|
||||||
|
|
||||||
def test_pubsub_trigger_integration(self, embedding_service):
|
def test_pubsub_trigger_integration(self, embedding_service):
|
||||||
"""Test integration with Pub/Sub message processing"""
|
"""Test integration with Pub/Sub message processing"""
|
||||||
|
|||||||
@ -1,391 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import numpy as np
|
|
||||||
from unittest.mock import patch, MagicMock, AsyncMock
|
|
||||||
from bson import ObjectId
|
|
||||||
|
|
||||||
from src.services.vector_store import VectorStoreService
|
|
||||||
from src.models.image import ImageModel
|
|
||||||
|
|
||||||
|
|
||||||
class TestVectorStoreService:
    """Test vector store operations for semantic search.

    Every test runs against a ``VectorStoreService`` whose ``index`` attribute is
    replaced by a ``MagicMock``, so no real vector database is contacted.
    """

    @pytest.fixture
    def mock_pinecone_index(self):
        """Mock Pinecone index exposing the operations the tests assert on."""
        mock_index = MagicMock()
        mock_index.upsert = MagicMock()
        mock_index.query = MagicMock()
        mock_index.delete = MagicMock()
        mock_index.describe_index_stats = MagicMock()
        return mock_index

    @pytest.fixture
    def vector_store_service(self, mock_pinecone_index):
        """Create vector store service with mocked dependencies."""
        with patch('src.services.vector_store.pinecone') as mock_pinecone:
            mock_pinecone.Index.return_value = mock_pinecone_index
            service = VectorStoreService()
            # Inject the mock directly so assertions can inspect index calls.
            service.index = mock_pinecone_index
            return service

    @pytest.fixture
    def sample_embedding(self):
        """Generate a sample 512-dimensional embedding vector.

        A fixed seed keeps the vector identical across runs so failures are
        reproducible.
        """
        return np.random.default_rng(0).random(512).tolist()

    @pytest.fixture
    def sample_image(self):
        """Create a sample image model."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId(),
            description="A test image",
            tags=["test", "image"]
        )

    def test_store_embedding(self, vector_store_service, sample_embedding, sample_image):
        """Test storing an embedding in the vector database."""
        # Store the embedding
        embedding_id = vector_store_service.store_embedding(
            image_id=str(sample_image.id),
            embedding=sample_embedding,
            metadata={
                "filename": sample_image.filename,
                "team_id": str(sample_image.team_id),
                "tags": sample_image.tags,
                "description": sample_image.description
            }
        )

        # Verify the embedding was stored
        assert embedding_id is not None
        vector_store_service.index.upsert.assert_called_once()

        # Check the upsert call arguments (call_args[1] holds the keyword args)
        call_args = vector_store_service.index.upsert.call_args
        vectors = call_args[1]['vectors']
        assert len(vectors) == 1
        assert vectors[0]['id'] == embedding_id
        assert len(vectors[0]['values']) == len(sample_embedding)
        assert 'metadata' in vectors[0]

    def test_search_similar_images(self, vector_store_service, sample_embedding):
        """Test searching for similar images using vector similarity."""
        # Mock search results
        mock_results = {
            'matches': [
                {
                    'id': 'embedding1',
                    'score': 0.95,
                    'metadata': {
                        'image_id': str(ObjectId()),
                        'filename': 'similar1.jpg',
                        'team_id': str(ObjectId()),
                        'tags': ['cat', 'animal']
                    }
                },
                {
                    'id': 'embedding2',
                    'score': 0.87,
                    'metadata': {
                        'image_id': str(ObjectId()),
                        'filename': 'similar2.jpg',
                        'team_id': str(ObjectId()),
                        'tags': ['dog', 'animal']
                    }
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Perform search
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
            score_threshold=0.8
        )

        # Verify search was performed
        vector_store_service.index.query.assert_called_once()

        # Check results
        assert len(results) == 2
        assert results[0]['score'] == 0.95
        assert results[1]['score'] == 0.87
        assert all('image_id' in result for result in results)

    def test_search_with_filters(self, vector_store_service, sample_embedding):
        """Test searching with metadata filters."""
        team_id = str(ObjectId())
        # Give the mocked query a well-formed (empty) response, matching the
        # shape used by the other search tests, instead of a bare MagicMock.
        vector_store_service.index.query.return_value = {'matches': []}

        # Perform search with team filter
        vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=team_id,
            top_k=5,
            filters={"tags": {"$in": ["cat", "dog"]}}
        )

        # Verify filter was applied
        call_args = vector_store_service.index.query.call_args
        assert 'filter' in call_args[1]
        assert call_args[1]['filter']['team_id'] == team_id

    def test_delete_embedding(self, vector_store_service):
        """Test deleting an embedding from the vector database."""
        embedding_id = "test-embedding-123"

        # Delete the embedding
        # NOTE(review): the service's delete_embedding seen elsewhere in this
        # change builds the vector id as f"{team_id}_{image_id}"; confirm that
        # this single-argument call and the expected ids match its signature.
        success = vector_store_service.delete_embedding(embedding_id)

        # Verify deletion was attempted
        vector_store_service.index.delete.assert_called_once_with(ids=[embedding_id])
        assert success is True

    def test_batch_store_embeddings(self, vector_store_service, sample_embedding):
        """Test storing multiple embeddings in batch."""
        # Create batch data
        batch_data = []
        for i in range(5):
            batch_data.append({
                'image_id': str(ObjectId()),
                'embedding': sample_embedding,
                'metadata': {
                    'filename': f'image{i}.jpg',
                    'team_id': str(ObjectId()),
                    'tags': [f'tag{i}']
                }
            })

        # Store batch
        embedding_ids = vector_store_service.batch_store_embeddings(batch_data)

        # Verify batch storage
        assert len(embedding_ids) == 5
        vector_store_service.index.upsert.assert_called_once()

        # Check batch upsert call
        call_args = vector_store_service.index.upsert.call_args
        vectors = call_args[1]['vectors']
        assert len(vectors) == 5

    def test_get_index_stats(self, vector_store_service):
        """Test getting vector database statistics."""
        # Mock stats response
        mock_stats = {
            'total_vector_count': 1000,
            'dimension': 512,
            'index_fullness': 0.1
        }
        vector_store_service.index.describe_index_stats.return_value = mock_stats

        # Get stats
        stats = vector_store_service.get_index_stats()

        # Verify stats retrieval
        vector_store_service.index.describe_index_stats.assert_called_once()
        assert stats['total_vector_count'] == 1000
        assert stats['dimension'] == 512

    def test_search_with_score_threshold(self, vector_store_service, sample_embedding):
        """Test filtering search results by score threshold."""
        # Mock results with varying scores
        mock_results = {
            'matches': [
                {'id': 'emb1', 'score': 0.95, 'metadata': {'image_id': '1'}},
                {'id': 'emb2', 'score': 0.75, 'metadata': {'image_id': '2'}},
                {'id': 'emb3', 'score': 0.65, 'metadata': {'image_id': '3'}},
                {'id': 'emb4', 'score': 0.45, 'metadata': {'image_id': '4'}}
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Search with score threshold
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
            score_threshold=0.7
        )

        # Only results above threshold should be returned
        assert len(results) == 2
        assert all(result['score'] >= 0.7 for result in results)

    def test_update_embedding_metadata(self, vector_store_service):
        """Test updating metadata for an existing embedding."""
        embedding_id = "test-embedding-123"
        new_metadata = {
            'tags': ['updated', 'tag'],
            'description': 'Updated description'
        }

        # Update metadata
        success = vector_store_service.update_embedding_metadata(
            embedding_id, new_metadata
        )

        # Verify update was attempted
        # Only the reported outcome is checked; which index call performs the
        # update depends on the actual implementation.
        assert success is True

    def test_search_by_image_id(self, vector_store_service):
        """Test searching for a specific image's embedding."""
        image_id = str(ObjectId())

        # Mock search by metadata
        mock_results = {
            'matches': [
                {
                    'id': 'embedding1',
                    'score': 1.0,
                    'metadata': {
                        'image_id': image_id,
                        'filename': 'target.jpg'
                    }
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Search by image ID
        result = vector_store_service.get_embedding_by_image_id(image_id)

        # Verify result
        assert result is not None
        assert result['metadata']['image_id'] == image_id

    def test_bulk_delete_embeddings(self, vector_store_service):
        """Test deleting multiple embeddings."""
        embedding_ids = ['emb1', 'emb2', 'emb3']

        # Delete multiple embeddings
        success = vector_store_service.bulk_delete_embeddings(embedding_ids)

        # Verify bulk deletion
        vector_store_service.index.delete.assert_called_once_with(ids=embedding_ids)
        assert success is True

    def test_search_pagination(self, vector_store_service, sample_embedding):
        """Test paginated search results."""
        # Report this placeholder as skipped rather than letting an empty body
        # count as a pass. Implementation depends on how pagination is handled
        # in the vector store.
        pytest.skip("pagination test not implemented")

    def test_vector_dimension_validation(self, vector_store_service):
        """Test validation of embedding dimensions."""
        # Test with wrong dimension
        wrong_dimension_embedding = np.random.rand(256).tolist()  # Wrong size

        with pytest.raises(ValueError):
            vector_store_service.store_embedding(
                image_id=str(ObjectId()),
                embedding=wrong_dimension_embedding,
                metadata={}
            )

    def test_connection_error_handling(self, vector_store_service):
        """Test handling of connection errors."""
        # Mock connection error
        vector_store_service.index.query.side_effect = Exception("Connection failed")

        # The failure is expected to surface to the caller as an exception
        with pytest.raises(Exception):
            vector_store_service.search_similar(
                query_embedding=[0.1] * 512,
                team_id=str(ObjectId()),
                top_k=10
            )

    def test_empty_search_results(self, vector_store_service, sample_embedding):
        """Test handling of empty search results."""
        # Mock empty results
        vector_store_service.index.query.return_value = {'matches': []}

        # Search should return empty list
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10
        )

        assert results == []
|
|
||||||
|
|
||||||
|
|
||||||
class TestVectorStoreIntegration:
    """Integration tests for vector store with other services.

    The fixtures are declared on this class because pytest only exposes fixtures
    defined inside a test class to that class's own tests; fixtures declared on
    another test class are not visible here.
    """

    @pytest.fixture
    def vector_store_service(self):
        """Create a vector store service whose index is a MagicMock."""
        mock_index = MagicMock()
        with patch('src.services.vector_store.pinecone') as mock_pinecone:
            mock_pinecone.Index.return_value = mock_index
            service = VectorStoreService()
            # Inject the mock directly so assertions can inspect index calls.
            service.index = mock_index
            return service

    @pytest.fixture
    def sample_embedding(self):
        """Generate a reproducible 512-dimensional embedding vector."""
        return np.random.default_rng(0).random(512).tolist()

    @pytest.fixture
    def sample_image(self):
        """Create a sample image model."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId(),
            description="A test image",
            tags=["test", "image"]
        )

    def test_embedding_lifecycle(self, vector_store_service, sample_embedding, sample_image):
        """Test complete embedding lifecycle: store, search, delete."""
        # Store embedding
        embedding_id = vector_store_service.store_embedding(
            image_id=str(sample_image.id),
            embedding=sample_embedding,
            metadata={'filename': sample_image.filename}
        )

        # Search for similar embeddings
        mock_results = {
            'matches': [
                {
                    'id': embedding_id,
                    'score': 1.0,
                    'metadata': {'image_id': str(sample_image.id)}
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(sample_image.team_id),
            top_k=1
        )

        assert len(results) == 1
        assert results[0]['id'] == embedding_id

        # Delete embedding
        success = vector_store_service.delete_embedding(embedding_id)
        assert success is True

    def test_team_isolation(self, vector_store_service, sample_embedding):
        """Test that team data is properly isolated."""
        team1_id = str(ObjectId())
        team2_id = str(ObjectId())

        # Mock search results that should be filtered by team
        mock_results = {
            'matches': [
                {
                    'id': 'emb1',
                    'score': 0.9,
                    'metadata': {'image_id': '1', 'team_id': team1_id}
                },
                {
                    'id': 'emb2',
                    'score': 0.8,
                    'metadata': {'image_id': '2', 'team_id': team2_id}
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Search on behalf of team1; the returned matches are not inspected,
        # only the filter sent to the index is.
        vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=team1_id,
            top_k=10
        )

        # Verify team filter was applied in the query
        call_args = vector_store_service.index.query.call_args
        assert 'filter' in call_args[1]
        assert call_args[1]['filter']['team_id'] == team1_id
|
|
||||||
Loading…
x
Reference in New Issue
Block a user