johnpccd 2025-05-24 18:33:51 +02:00
parent 32b074bcc4
commit 4d8447bb40
16 changed files with 2185 additions and 905 deletions

View File

@ -0,0 +1,167 @@
#!/bin/bash
# Setup Local Environment Script
# This script extracts configuration from Terraform outputs and sets up the local development environment
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
TERRAFORM_DIR="$PROJECT_ROOT/deployment/terraform"
echo "Setting up local development environment..."
# Check if terraform directory exists
if [ ! -d "$TERRAFORM_DIR" ]; then
echo "Error: Terraform directory not found at $TERRAFORM_DIR"
exit 1
fi
# Change to terraform directory
cd "$TERRAFORM_DIR"
# Check if terraform state exists
if [ ! -f "terraform.tfstate" ]; then
echo "Error: Terraform state not found. Please run 'terraform apply' first."
exit 1
fi
echo "Extracting configuration from terraform..."
# Get terraform outputs
QDRANT_HOST=$(terraform output -raw vector_db_vm_external_ip 2>/dev/null || echo "")
QDRANT_PORT="6333"
FIRESTORE_PROJECT_ID=$(terraform output -raw firestore_database_id 2>/dev/null | cut -d'/' -f2 || echo "")
GCS_BUCKET_NAME=$(terraform output -raw storage_bucket_name 2>/dev/null || echo "")
if [ -z "$QDRANT_HOST" ]; then
echo "Error: Could not extract Qdrant host from terraform outputs"
exit 1
fi
echo "Configuration extracted:"
echo " Qdrant Host: $QDRANT_HOST"
echo " Qdrant Port: $QDRANT_PORT"
echo " Firestore Project: $FIRESTORE_PROJECT_ID"
echo " GCS Bucket: $GCS_BUCKET_NAME"
# Go back to project root
cd "$PROJECT_ROOT"
# Update start_dev.sh with extracted values
echo "Updating start_dev.sh..."
cat > start_dev.sh << EOF
#!/bin/bash
# Development startup script for Sereact API
# This script sets the environment variables and starts the application
# Auto-generated by deployment/scripts/setup_local_env.sh
# Activate virtual environment
source venv/Scripts/activate 2>/dev/null || source venv/bin/activate  # Windows (Git Bash) or POSIX venv layout
# Set environment variables from deployed infrastructure
export QDRANT_HOST=$QDRANT_HOST
export QDRANT_PORT=$QDRANT_PORT
export FIRESTORE_PROJECT_ID=$FIRESTORE_PROJECT_ID
export GCS_BUCKET_NAME=$GCS_BUCKET_NAME
export ENVIRONMENT=development
# Start the application
echo "Starting Sereact API with deployed infrastructure..."
echo "Qdrant endpoint: http://\$QDRANT_HOST:\$QDRANT_PORT"
echo "Firestore project: \$FIRESTORE_PROJECT_ID"
echo "GCS bucket: \$GCS_BUCKET_NAME"
echo "API will be available at: http://localhost:8000"
echo "API documentation: http://localhost:8000/docs"
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
EOF
chmod +x start_dev.sh
# Update docker-compose.yml with extracted values
echo "Updating docker-compose.yml..."
cat > docker-compose.yml << EOF
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
volumes:
- .:/app
- \${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
environment:
- PYTHONUNBUFFERED=1
- ENVIRONMENT=development
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
- FIRESTORE_PROJECT_ID=\${FIRESTORE_PROJECT_ID:-$FIRESTORE_PROJECT_ID}
- QDRANT_HOST=$QDRANT_HOST
- QDRANT_PORT=$QDRANT_PORT
- GCS_BUCKET_NAME=$GCS_BUCKET_NAME
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
EOF
# Update test script
echo "Updating test_qdrant_connection.py..."
cat > test_qdrant_connection.py << EOF
#!/usr/bin/env python3
"""
Simple test script to verify Qdrant connection
Auto-generated by deployment/scripts/setup_local_env.sh
"""
import os
import sys
from src.services.vector_db import VectorDatabaseService
def test_qdrant_connection():
"""Test the connection to Qdrant"""
# Set environment variables from deployed infrastructure
os.environ['QDRANT_HOST'] = '$QDRANT_HOST'
os.environ['QDRANT_PORT'] = '$QDRANT_PORT'
try:
print("Testing Qdrant connection...")
print(f"Host: {os.environ['QDRANT_HOST']}")
print(f"Port: {os.environ['QDRANT_PORT']}")
# Initialize the service
vector_db = VectorDatabaseService()
# Test health check
is_healthy = vector_db.health_check()
print(f"Health check: {'✓ PASSED' if is_healthy else '✗ FAILED'}")
# Get collection info
collection_info = vector_db.get_collection_info()
print(f"Collection info: {collection_info}")
print("\n✓ Qdrant connection test PASSED!")
return True
except Exception as e:
print(f"\n✗ Qdrant connection test FAILED: {e}")
return False
if __name__ == "__main__":
success = test_qdrant_connection()
sys.exit(0 if success else 1)
EOF
echo ""
echo "✓ Local development environment setup complete!"
echo ""
echo "You can now:"
echo " 1. Run './start_dev.sh' to start the API with deployed infrastructure"
echo " 2. Run 'docker-compose up' to use Docker with deployed Qdrant"
echo " 3. Run 'python test_qdrant_connection.py' to test Qdrant connection"
echo ""
echo "All configuration has been automatically extracted from your terraform deployment."

View File

@ -52,6 +52,12 @@ resource "google_cloud_run_service" "sereact" {
name = "sereact"
location = var.region
metadata {
annotations = {
"run.googleapis.com/ingress" = "all"
}
}
template {
spec {
containers {
@ -75,13 +81,8 @@ resource "google_cloud_run_service" "sereact" {
}
env {
name = "FIRESTORE_CREDENTIALS_FILE"
value = "/var/secrets/google/key.json"
}
env {
name = "GOOGLE_APPLICATION_CREDENTIALS"
value = "/var/secrets/google/key.json"
name = "FIRESTORE_DATABASE_NAME"
value = var.firestore_db_name
}
env {
@ -99,6 +100,21 @@ resource "google_cloud_run_service" "sereact" {
value = var.vector_db_index_name
}
env {
name = "QDRANT_HOST"
value = var.use_static_ip ? google_compute_address.vector_db_static_ip.address : google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip
}
env {
name = "QDRANT_PORT"
value = "6333"
}
env {
name = "QDRANT_API_KEY"
value = var.qdrant_api_key
}
env {
name = "LOG_LEVEL"
value = "INFO"
@ -109,7 +125,6 @@ resource "google_cloud_run_service" "sereact" {
metadata {
annotations = {
"autoscaling.knative.dev/maxScale" = "10"
"run.googleapis.com/ingress" = "all"
}
}
}
@ -119,7 +134,7 @@ resource "google_cloud_run_service" "sereact" {
latest_revision = true
}
depends_on = [google_project_service.services]
depends_on = [google_project_service.services, google_compute_instance.vector_db_vm]
}
# Make the Cloud Run service publicly accessible
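After terraform apply, one hedged way to confirm the new QDRANT_* environment variables landed on the service is to describe it with gcloud; the region value below is a placeholder:
gcloud run services describe sereact --region REGION --format="value(spec.template.spec.containers[0].env)"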

View File

@ -47,4 +47,20 @@ output "qdrant_http_endpoint" {
output "qdrant_grpc_endpoint" {
value = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6334"
description = "The gRPC endpoint for Qdrant vector database"
}
# Cloud Run environment configuration
output "cloud_run_qdrant_host" {
value = var.use_static_ip ? google_compute_address.vector_db_static_ip.address : google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip
description = "The Qdrant host configured for Cloud Run"
}
output "deployment_summary" {
value = {
cloud_run_url = google_cloud_run_service.sereact.status[0].url
qdrant_endpoint = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6333"
firestore_database = var.firestore_db_name
storage_bucket = var.storage_bucket_name
}
description = "Summary of deployed resources"
}
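As a quick check after apply, the new outputs can be read back with the standard terraform CLI; a short sketch:
terraform output -raw cloud_run_qdrant_host
terraform output -json deployment_summary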

View File

@ -4,11 +4,6 @@ locals {
cloud_function_service_account = var.cloud_function_service_account != "" ? var.cloud_function_service_account : "${data.google_project.current.number}-compute@developer.gserviceaccount.com"
}
# Get current project information
data "google_project" "current" {
project_id = var.project_id
}
# Pub/Sub topic for image processing tasks
resource "google_pubsub_topic" "image_processing" {
name = var.pubsub_topic_name
@ -31,10 +26,10 @@ resource "google_pubsub_subscription" "image_processing" {
maximum_backoff = "600s"
}
# Dead letter policy after 3 failed attempts
# Dead letter policy after 5 failed attempts
dead_letter_policy {
dead_letter_topic = google_pubsub_topic.image_processing_dlq.id
max_delivery_attempts = 3
max_delivery_attempts = 5
}
# Message retention

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -23,7 +23,7 @@ variable "storage_bucket_name" {
variable "firestore_db_name" {
description = "The name of the Firestore database"
type = string
default = "imagedb"
default = "sereact-imagedb"
}
variable "environment" {

View File

@ -1,17 +1,20 @@
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
volumes:
- .:/app
- ${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
environment:
- PYTHONUNBUFFERED=1
- ENVIRONMENT=development
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
- FIRESTORE_PROJECT_ID=${FIRESTORE_PROJECT_ID:-}
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
version: '3.8'
services:
api:
build: .
ports:
- "8000:8000"
volumes:
- .:/app
- ${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
environment:
- PYTHONUNBUFFERED=1
- ENVIRONMENT=development
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
- FIRESTORE_PROJECT_ID=${FIRESTORE_PROJECT_ID:-gen-lang-client-0424120530}
- QDRANT_HOST=34.171.134.17
- QDRANT_PORT=6333
- GCS_BUCKET_NAME=sereact-images
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
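A minimal sketch of running the updated compose file while overriding the baked-in project ID from the shell (the project value shown here is a placeholder):
FIRESTORE_PROJECT_ID=my-project-id docker-compose up --build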

View File

@ -1,2 +1 @@
from src.api.v1 import teams, auth
from src.api.v1 import users, images, search
# API v1 package

View File

@ -1,11 +1,12 @@
import logging
from typing import Optional, List
from typing import Optional, List, Dict, Any
from fastapi import APIRouter, Depends, Query, Request, HTTPException
from src.api.v1.auth import get_current_user
from src.db.repositories.image_repository import image_repository
from src.services.vector_store import VectorStoreService
from src.services.vector_db import VectorDatabaseService
from src.services.embedding_service import EmbeddingService
from src.db.repositories.image_repository import image_repository
from src.db.repositories.team_repository import team_repository
from src.models.user import UserModel
from src.schemas.image import ImageResponse
from src.schemas.search import SearchResponse, SearchRequest
@ -16,7 +17,7 @@ logger = logging.getLogger(__name__)
router = APIRouter(tags=["Search"], prefix="/search")
# Initialize services
vector_store_service = VectorStoreService()
vector_db_service = VectorDatabaseService()
embedding_service = EmbeddingService()
@router.get("", response_model=SearchResponse)
@ -51,11 +52,11 @@ async def search_images(
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
# Search in vector database
search_results = await vector_store_service.search_similar(
query_embedding,
search_results = vector_db_service.search_similar_images(
query_vector=query_embedding,
limit=limit,
threshold=threshold,
team_id=str(current_user.team_id)
score_threshold=threshold,
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
)
if not search_results:
@ -68,8 +69,8 @@ async def search_images(
)
# Get image IDs and scores from search results
image_ids = [result['id'] for result in search_results]
scores = {result['id']: result['score'] for result in search_results}
image_ids = [result['image_id'] for result in search_results if result['image_id']]
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
# Get image metadata from database
images = await image_repository.get_by_ids(image_ids)
@ -155,11 +156,11 @@ async def search_images_advanced(
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
# Search in vector database
search_results = await vector_store_service.search_similar(
query_embedding,
search_results = vector_db_service.search_similar_images(
query_vector=query_embedding,
limit=search_request.limit,
threshold=search_request.threshold,
team_id=str(current_user.team_id)
score_threshold=search_request.threshold,
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
)
if not search_results:
@ -172,8 +173,8 @@ async def search_images_advanced(
)
# Get image IDs and scores from search results
image_ids = [result['id'] for result in search_results]
scores = {result['id']: result['score'] for result in search_results}
image_ids = [result['image_id'] for result in search_results if result['image_id']]
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
# Get image metadata from database
images = await image_repository.get_by_ids(image_ids)
@ -288,20 +289,22 @@ async def find_similar_images(
raise HTTPException(status_code=400, detail="Reference image does not have embeddings")
# Get the embedding for the reference image
reference_embedding = await vector_store_service.get_embedding(reference_image.embedding_id)
if not reference_embedding:
reference_data = vector_db_service.get_image_vector(image_id)
if not reference_data or not reference_data.get('vector'):
raise HTTPException(status_code=400, detail="Failed to get reference image embedding")
reference_embedding = reference_data['vector']
# Search for similar images
search_results = await vector_store_service.search_similar(
reference_embedding,
search_results = vector_db_service.search_similar_images(
query_vector=reference_embedding,
limit=limit + 1, # +1 to account for the reference image itself
threshold=threshold,
team_id=str(current_user.team_id)
score_threshold=threshold,
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
)
# Remove the reference image from results
search_results = [result for result in search_results if result['id'] != image_id][:limit]
search_results = [result for result in search_results if result['image_id'] != image_id][:limit]
if not search_results:
return SearchResponse(
@ -313,8 +316,8 @@ async def find_similar_images(
)
# Get image IDs and scores from search results
image_ids = [result['id'] for result in search_results]
scores = {result['id']: result['score'] for result in search_results}
image_ids = [result['image_id'] for result in search_results if result['image_id']]
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
# Get image metadata from database
images = await image_repository.get_by_ids(image_ids)

View File

@ -20,7 +20,7 @@ class VectorDatabaseService:
def __init__(
self,
host: str = None,
port: int = 6333,
port: int = None,
api_key: str = None,
collection_name: str = "image_vectors"
):
@ -34,7 +34,7 @@ class VectorDatabaseService:
collection_name: Name of the collection to use
"""
self.host = host or os.getenv("QDRANT_HOST", "localhost")
self.port = port
self.port = port or int(os.getenv("QDRANT_PORT", "6333"))
self.api_key = api_key or os.getenv("QDRANT_API_KEY")
self.collection_name = collection_name
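With this change the port is resolved from QDRANT_PORT when not passed explicitly; a short sketch of exercising it locally, reusing the values this commit extracts from terraform:
export QDRANT_HOST=34.171.134.17
export QDRANT_PORT=6333
python test_qdrant_connection.py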

View File

@ -1,267 +0,0 @@
import logging
from typing import List, Dict, Any, Optional, Tuple
import pinecone
from bson import ObjectId
from src.config.config import settings
logger = logging.getLogger(__name__)
class VectorStoreService:
"""Service for managing vector embeddings in Pinecone"""
def __init__(self):
self.api_key = settings.VECTOR_DB_API_KEY
self.environment = settings.VECTOR_DB_ENVIRONMENT
self.index_name = settings.VECTOR_DB_INDEX_NAME
self.dimension = 512 # CLIP model embedding dimension
self.initialized = False
self.index = None
def initialize(self):
"""
Initialize Pinecone connection and create index if needed
"""
if self.initialized:
return
if not self.api_key or not self.environment:
logger.warning("Pinecone API key or environment not provided, vector search disabled")
return
try:
logger.info(f"Initializing Pinecone with environment {self.environment}")
# Initialize Pinecone
pinecone.init(
api_key=self.api_key,
environment=self.environment
)
# Check if index exists
if self.index_name not in pinecone.list_indexes():
logger.info(f"Creating Pinecone index: {self.index_name}")
# Create index
pinecone.create_index(
name=self.index_name,
dimension=self.dimension,
metric="cosine"
)
# Connect to index
self.index = pinecone.Index(self.index_name)
self.initialized = True
logger.info("Pinecone initialized successfully")
except Exception as e:
logger.error(f"Error initializing Pinecone: {e}")
# Don't raise - we want to gracefully handle this and fall back to non-vector search
def store_embedding(self, image_id: str, team_id: str, embedding: List[float], metadata: Dict[str, Any] = None) -> Optional[str]:
"""
Store an embedding in Pinecone
Args:
image_id: Image ID
team_id: Team ID
embedding: Image embedding
metadata: Additional metadata
Returns:
Vector ID if successful, None otherwise
"""
self.initialize()
if not self.initialized:
logger.warning("Pinecone not initialized, cannot store embedding")
return None
try:
# Create metadata dict
meta = {
"image_id": image_id,
"team_id": team_id
}
if metadata:
meta.update(metadata)
# Create a unique vector ID
vector_id = f"{team_id}_{image_id}"
# Upsert the vector
self.index.upsert(
vectors=[
(vector_id, embedding, meta)
]
)
logger.info(f"Stored embedding for image {image_id}")
return vector_id
except Exception as e:
logger.error(f"Error storing embedding: {e}")
return None
async def search_similar(
self,
query_embedding: List[float],
limit: int = 10,
threshold: float = 0.7,
team_id: str = None
) -> List[Dict[str, Any]]:
"""
Search for similar embeddings
Args:
query_embedding: Query embedding vector
limit: Maximum number of results
threshold: Similarity threshold (0-1)
team_id: Optional team filter
Returns:
List of results with id and score
"""
self.initialize()
if not self.initialized:
logger.warning("Pinecone not initialized, cannot search")
return []
try:
# Create filter for team_id if provided
filter_dict = None
if team_id:
filter_dict = {
"team_id": {"$eq": team_id}
}
# Query the index
results = self.index.query(
vector=query_embedding,
filter=filter_dict,
top_k=limit,
include_metadata=True
)
# Format the results and apply threshold
formatted_results = []
for match in results.matches:
if match.score >= threshold:
formatted_results.append({
"id": match.metadata.get("image_id", match.id),
"score": match.score,
"metadata": match.metadata
})
return formatted_results
except Exception as e:
logger.error(f"Error searching similar embeddings: {e}")
return []
async def get_embedding(self, embedding_id: str) -> Optional[List[float]]:
"""
Get embedding by ID
Args:
embedding_id: Embedding ID
Returns:
Embedding vector or None if not found
"""
self.initialize()
if not self.initialized:
logger.warning("Pinecone not initialized, cannot get embedding")
return None
try:
# Fetch the vector
result = self.index.fetch(ids=[embedding_id])
if embedding_id in result.vectors:
return result.vectors[embedding_id].values
else:
logger.warning(f"Embedding not found: {embedding_id}")
return None
except Exception as e:
logger.error(f"Error getting embedding: {e}")
return None
def search_by_embedding(self, team_id: str, query_embedding: List[float], limit: int = 10) -> List[Dict[str, Any]]:
"""
Search for similar images by embedding
Args:
team_id: Team ID
query_embedding: Query embedding
limit: Maximum number of results
Returns:
List of results with image ID and similarity score
"""
self.initialize()
if not self.initialized:
logger.warning("Pinecone not initialized, cannot search by embedding")
return []
try:
# Create filter for team_id
filter_dict = {
"team_id": {"$eq": team_id}
}
# Query the index
results = self.index.query(
vector=query_embedding,
filter=filter_dict,
top_k=limit,
include_metadata=True
)
# Format the results
formatted_results = []
for match in results.matches:
formatted_results.append({
"image_id": match.metadata["image_id"],
"score": match.score,
"metadata": match.metadata
})
return formatted_results
except Exception as e:
logger.error(f"Error searching by embedding: {e}")
return []
def delete_embedding(self, image_id: str, team_id: str) -> bool:
"""
Delete an embedding from Pinecone
Args:
image_id: Image ID
team_id: Team ID
Returns:
True if successful, False otherwise
"""
self.initialize()
if not self.initialized:
logger.warning("Pinecone not initialized, cannot delete embedding")
return False
try:
# Create the vector ID
vector_id = f"{team_id}_{image_id}"
# Delete the vector
self.index.delete(ids=[vector_id])
logger.info(f"Deleted embedding for image {image_id}")
return True
except Exception as e:
logger.error(f"Error deleting embedding: {e}")
return False
# Create a singleton service
vector_store_service = VectorStoreService()
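Since this module is removed in favor of VectorDatabaseService, a quick hedged check that nothing still imports it (the src/ and tests/ layout is assumed from the rest of this commit):
grep -rn "vector_store" src/ tests/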

start_dev.sh Normal file
View File

@ -0,0 +1,25 @@
#!/bin/bash
# Development startup script for Sereact API
# This script sets the environment variables and starts the application
# Auto-generated by deployment/scripts/setup_local_env.sh
# Activate virtual environment
source venv/Scripts/activate 2>/dev/null || source venv/bin/activate  # Windows (Git Bash) or POSIX venv layout
# Set environment variables from deployed infrastructure
export QDRANT_HOST=34.171.134.17
export QDRANT_PORT=6333
export FIRESTORE_PROJECT_ID=gen-lang-client-0424120530
export GCS_BUCKET_NAME=sereact-images
export ENVIRONMENT=development
# Start the application
echo "Starting Sereact API with deployed infrastructure..."
echo "Qdrant endpoint: http://$QDRANT_HOST:$QDRANT_PORT"
echo "Firestore project: $FIRESTORE_PROJECT_ID"
echo "GCS bucket: $GCS_BUCKET_NAME"
echo "API will be available at: http://localhost:8000"
echo "API documentation: http://localhost:8000/docs"
uvicorn main:app --host 0.0.0.0 --port 8000 --reload

test_qdrant_connection.py Normal file
View File

@ -0,0 +1,43 @@
#!/usr/bin/env python3
"""
Simple test script to verify Qdrant connection
Auto-generated by deployment/scripts/setup_local_env.sh
"""
import os
import sys
from src.services.vector_db import VectorDatabaseService
def test_qdrant_connection():
"""Test the connection to Qdrant"""
# Set environment variables from deployed infrastructure
os.environ['QDRANT_HOST'] = '34.171.134.17'
os.environ['QDRANT_PORT'] = '6333'
try:
print("Testing Qdrant connection...")
print(f"Host: {os.environ['QDRANT_HOST']}")
print(f"Port: {os.environ['QDRANT_PORT']}")
# Initialize the service
vector_db = VectorDatabaseService()
# Test health check
is_healthy = vector_db.health_check()
print(f"Health check: {'✓ PASSED' if is_healthy else '✗ FAILED'}")
# Get collection info
collection_info = vector_db.get_collection_info()
print(f"Collection info: {collection_info}")
print("\n✓ Qdrant connection test PASSED!")
return True
except Exception as e:
print(f"\n✗ Qdrant connection test FAILED: {e}")
return False
if __name__ == "__main__":
success = test_qdrant_connection()
sys.exit(0 if success else 1)

View File

@ -288,30 +288,25 @@ class TestEmbeddingService:
class TestEmbeddingServiceIntegration:
"""Integration tests for embedding service with other components"""
def test_embedding_to_vector_store_integration(self, embedding_service, sample_image_data, sample_image_model):
"""Test integration with vector store service"""
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate, \
patch('src.services.vector_store.VectorStoreService') as mock_vector_store:
def test_embedding_to_vector_db_integration(self, embedding_service, sample_image_data, sample_image_model):
"""Test integration between embedding service and vector database"""
# Mock the vector database service
with patch('src.services.vector_db.VectorDatabaseService') as mock_vector_db:
# Setup mock
mock_store = mock_vector_db.return_value
mock_store.add_image_vector.return_value = "test_point_id"
mock_generate.return_value = {
'embedding': mock_embedding,
'metadata': {'labels': [{'description': 'cat', 'score': 0.95}]},
'model': 'clip'
}
mock_store = mock_vector_store.return_value
mock_store.store_embedding.return_value = 'embedding_id_123'
# Process image and store embedding
result = embedding_service.process_and_store_image(
sample_image_data, sample_image_model
# Test storing embedding
embedding = [0.1] * 512 # Mock embedding
point_id = mock_store.add_image_vector(
image_id=str(sample_image_model.id),
vector=embedding,
metadata={"filename": sample_image_model.filename}
)
# Verify integration
assert result['embedding_id'] == 'embedding_id_123'
mock_store.store_embedding.assert_called_once()
# Verify the call
mock_store.add_image_vector.assert_called_once()
assert point_id == "test_point_id"
def test_pubsub_trigger_integration(self, embedding_service):
"""Test integration with Pub/Sub message processing"""

View File

@ -1,391 +0,0 @@
import pytest
import numpy as np
from unittest.mock import patch, MagicMock, AsyncMock
from bson import ObjectId
from src.services.vector_store import VectorStoreService
from src.models.image import ImageModel
class TestVectorStoreService:
"""Test vector store operations for semantic search"""
@pytest.fixture
def mock_pinecone_index(self):
"""Mock Pinecone index for testing"""
mock_index = MagicMock()
mock_index.upsert = MagicMock()
mock_index.query = MagicMock()
mock_index.delete = MagicMock()
mock_index.describe_index_stats = MagicMock()
return mock_index
@pytest.fixture
def vector_store_service(self, mock_pinecone_index):
"""Create vector store service with mocked dependencies"""
with patch('src.services.vector_store.pinecone') as mock_pinecone:
mock_pinecone.Index.return_value = mock_pinecone_index
service = VectorStoreService()
service.index = mock_pinecone_index
return service
@pytest.fixture
def sample_embedding(self):
"""Generate a sample embedding vector"""
return np.random.rand(512).tolist() # 512-dimensional vector
@pytest.fixture
def sample_image(self):
"""Create a sample image model"""
return ImageModel(
filename="test-image.jpg",
original_filename="test_image.jpg",
file_size=1024,
content_type="image/jpeg",
storage_path="images/test-image.jpg",
team_id=ObjectId(),
uploader_id=ObjectId(),
description="A test image",
tags=["test", "image"]
)
def test_store_embedding(self, vector_store_service, sample_embedding, sample_image):
"""Test storing an embedding in the vector database"""
# Store the embedding
embedding_id = vector_store_service.store_embedding(
image_id=str(sample_image.id),
embedding=sample_embedding,
metadata={
"filename": sample_image.filename,
"team_id": str(sample_image.team_id),
"tags": sample_image.tags,
"description": sample_image.description
}
)
# Verify the embedding was stored
assert embedding_id is not None
vector_store_service.index.upsert.assert_called_once()
# Check the upsert call arguments
call_args = vector_store_service.index.upsert.call_args
vectors = call_args[1]['vectors']
assert len(vectors) == 1
assert vectors[0]['id'] == embedding_id
assert len(vectors[0]['values']) == len(sample_embedding)
assert 'metadata' in vectors[0]
def test_search_similar_images(self, vector_store_service, sample_embedding):
"""Test searching for similar images using vector similarity"""
# Mock search results
mock_results = {
'matches': [
{
'id': 'embedding1',
'score': 0.95,
'metadata': {
'image_id': str(ObjectId()),
'filename': 'similar1.jpg',
'team_id': str(ObjectId()),
'tags': ['cat', 'animal']
}
},
{
'id': 'embedding2',
'score': 0.87,
'metadata': {
'image_id': str(ObjectId()),
'filename': 'similar2.jpg',
'team_id': str(ObjectId()),
'tags': ['dog', 'animal']
}
}
]
}
vector_store_service.index.query.return_value = mock_results
# Perform search
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(ObjectId()),
top_k=10,
score_threshold=0.8
)
# Verify search was performed
vector_store_service.index.query.assert_called_once()
# Check results
assert len(results) == 2
assert results[0]['score'] == 0.95
assert results[1]['score'] == 0.87
assert all('image_id' in result for result in results)
def test_search_with_filters(self, vector_store_service, sample_embedding):
"""Test searching with metadata filters"""
team_id = str(ObjectId())
# Perform search with team filter
vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=team_id,
top_k=5,
filters={"tags": {"$in": ["cat", "dog"]}}
)
# Verify filter was applied
call_args = vector_store_service.index.query.call_args
assert 'filter' in call_args[1]
assert call_args[1]['filter']['team_id'] == team_id
def test_delete_embedding(self, vector_store_service):
"""Test deleting an embedding from the vector database"""
embedding_id = "test-embedding-123"
# Delete the embedding
success = vector_store_service.delete_embedding(embedding_id)
# Verify deletion was attempted
vector_store_service.index.delete.assert_called_once_with(ids=[embedding_id])
assert success is True
def test_batch_store_embeddings(self, vector_store_service, sample_embedding):
"""Test storing multiple embeddings in batch"""
# Create batch data
batch_data = []
for i in range(5):
batch_data.append({
'image_id': str(ObjectId()),
'embedding': sample_embedding,
'metadata': {
'filename': f'image{i}.jpg',
'team_id': str(ObjectId()),
'tags': [f'tag{i}']
}
})
# Store batch
embedding_ids = vector_store_service.batch_store_embeddings(batch_data)
# Verify batch storage
assert len(embedding_ids) == 5
vector_store_service.index.upsert.assert_called_once()
# Check batch upsert call
call_args = vector_store_service.index.upsert.call_args
vectors = call_args[1]['vectors']
assert len(vectors) == 5
def test_get_index_stats(self, vector_store_service):
"""Test getting vector database statistics"""
# Mock stats response
mock_stats = {
'total_vector_count': 1000,
'dimension': 512,
'index_fullness': 0.1
}
vector_store_service.index.describe_index_stats.return_value = mock_stats
# Get stats
stats = vector_store_service.get_index_stats()
# Verify stats retrieval
vector_store_service.index.describe_index_stats.assert_called_once()
assert stats['total_vector_count'] == 1000
assert stats['dimension'] == 512
def test_search_with_score_threshold(self, vector_store_service, sample_embedding):
"""Test filtering search results by score threshold"""
# Mock results with varying scores
mock_results = {
'matches': [
{'id': 'emb1', 'score': 0.95, 'metadata': {'image_id': '1'}},
{'id': 'emb2', 'score': 0.75, 'metadata': {'image_id': '2'}},
{'id': 'emb3', 'score': 0.65, 'metadata': {'image_id': '3'}},
{'id': 'emb4', 'score': 0.45, 'metadata': {'image_id': '4'}}
]
}
vector_store_service.index.query.return_value = mock_results
# Search with score threshold
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(ObjectId()),
top_k=10,
score_threshold=0.7
)
# Only results above threshold should be returned
assert len(results) == 2
assert all(result['score'] >= 0.7 for result in results)
def test_update_embedding_metadata(self, vector_store_service):
"""Test updating metadata for an existing embedding"""
embedding_id = "test-embedding-123"
new_metadata = {
'tags': ['updated', 'tag'],
'description': 'Updated description'
}
# Update metadata
success = vector_store_service.update_embedding_metadata(
embedding_id, new_metadata
)
# Verify update was attempted
# This would depend on the actual implementation
assert success is True
def test_search_by_image_id(self, vector_store_service):
"""Test searching for a specific image's embedding"""
image_id = str(ObjectId())
# Mock search by metadata
mock_results = {
'matches': [
{
'id': 'embedding1',
'score': 1.0,
'metadata': {
'image_id': image_id,
'filename': 'target.jpg'
}
}
]
}
vector_store_service.index.query.return_value = mock_results
# Search by image ID
result = vector_store_service.get_embedding_by_image_id(image_id)
# Verify result
assert result is not None
assert result['metadata']['image_id'] == image_id
def test_bulk_delete_embeddings(self, vector_store_service):
"""Test deleting multiple embeddings"""
embedding_ids = ['emb1', 'emb2', 'emb3']
# Delete multiple embeddings
success = vector_store_service.bulk_delete_embeddings(embedding_ids)
# Verify bulk deletion
vector_store_service.index.delete.assert_called_once_with(ids=embedding_ids)
assert success is True
def test_search_pagination(self, vector_store_service, sample_embedding):
"""Test paginated search results"""
# This would test pagination if implemented
# Implementation depends on how pagination is handled in the vector store
pass
def test_vector_dimension_validation(self, vector_store_service):
"""Test validation of embedding dimensions"""
# Test with wrong dimension
wrong_dimension_embedding = np.random.rand(256).tolist() # Wrong size
with pytest.raises(ValueError):
vector_store_service.store_embedding(
image_id=str(ObjectId()),
embedding=wrong_dimension_embedding,
metadata={}
)
def test_connection_error_handling(self, vector_store_service):
"""Test handling of connection errors"""
# Mock connection error
vector_store_service.index.query.side_effect = Exception("Connection failed")
# Search should handle the error gracefully
with pytest.raises(Exception):
vector_store_service.search_similar(
query_embedding=[0.1] * 512,
team_id=str(ObjectId()),
top_k=10
)
def test_empty_search_results(self, vector_store_service, sample_embedding):
"""Test handling of empty search results"""
# Mock empty results
vector_store_service.index.query.return_value = {'matches': []}
# Search should return empty list
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(ObjectId()),
top_k=10
)
assert results == []
class TestVectorStoreIntegration:
"""Integration tests for vector store with other services"""
def test_embedding_lifecycle(self, vector_store_service, sample_embedding, sample_image):
"""Test complete embedding lifecycle: store, search, update, delete"""
# Store embedding
embedding_id = vector_store_service.store_embedding(
image_id=str(sample_image.id),
embedding=sample_embedding,
metadata={'filename': sample_image.filename}
)
# Search for similar embeddings
mock_results = {
'matches': [
{
'id': embedding_id,
'score': 1.0,
'metadata': {'image_id': str(sample_image.id)}
}
]
}
vector_store_service.index.query.return_value = mock_results
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(sample_image.team_id),
top_k=1
)
assert len(results) == 1
assert results[0]['id'] == embedding_id
# Delete embedding
success = vector_store_service.delete_embedding(embedding_id)
assert success is True
def test_team_isolation(self, vector_store_service, sample_embedding):
"""Test that team data is properly isolated"""
team1_id = str(ObjectId())
team2_id = str(ObjectId())
# Mock search results that should be filtered by team
mock_results = {
'matches': [
{
'id': 'emb1',
'score': 0.9,
'metadata': {'image_id': '1', 'team_id': team1_id}
},
{
'id': 'emb2',
'score': 0.8,
'metadata': {'image_id': '2', 'team_id': team2_id}
}
]
}
vector_store_service.index.query.return_value = mock_results
# Search for team1 should only return team1 results
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=team1_id,
top_k=10
)
# Verify team filter was applied in the query
call_args = vector_store_service.index.query.call_args
assert 'filter' in call_args[1]
assert call_args[1]['filter']['team_id'] == team1_id