cp
This commit is contained in:
parent
32b074bcc4
commit
4d8447bb40
167
deployment/scripts/setup_local_env.sh
Normal file
167
deployment/scripts/setup_local_env.sh
Normal file
@ -0,0 +1,167 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Setup Local Environment Script
|
||||
# This script extracts configuration from terraform and sets up local development
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"
|
||||
TERRAFORM_DIR="$PROJECT_ROOT/deployment/terraform"
|
||||
|
||||
echo "Setting up local development environment..."
|
||||
|
||||
# Check if terraform directory exists
|
||||
if [ ! -d "$TERRAFORM_DIR" ]; then
|
||||
echo "Error: Terraform directory not found at $TERRAFORM_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Change to terraform directory
|
||||
cd "$TERRAFORM_DIR"
|
||||
|
||||
# Check if terraform state exists
|
||||
if [ ! -f "terraform.tfstate" ]; then
|
||||
echo "Error: Terraform state not found. Please run 'terraform apply' first."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Extracting configuration from terraform..."
|
||||
|
||||
# Get terraform outputs
|
||||
QDRANT_HOST=$(terraform output -raw vector_db_vm_external_ip 2>/dev/null || echo "")
|
||||
QDRANT_PORT="6333"
|
||||
FIRESTORE_PROJECT_ID=$(terraform output -raw firestore_database_id 2>/dev/null | cut -d'/' -f2 || echo "")
|
||||
GCS_BUCKET_NAME=$(terraform output -raw storage_bucket_name 2>/dev/null || echo "")
|
||||
|
||||
if [ -z "$QDRANT_HOST" ]; then
|
||||
echo "Error: Could not extract Qdrant host from terraform outputs"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Configuration extracted:"
|
||||
echo " Qdrant Host: $QDRANT_HOST"
|
||||
echo " Qdrant Port: $QDRANT_PORT"
|
||||
echo " Firestore Project: $FIRESTORE_PROJECT_ID"
|
||||
echo " GCS Bucket: $GCS_BUCKET_NAME"
|
||||
|
||||
# Go back to project root
|
||||
cd "$PROJECT_ROOT"
|
||||
|
||||
# Update start_dev.sh with extracted values
|
||||
echo "Updating start_dev.sh..."
|
||||
|
||||
cat > start_dev.sh << EOF
|
||||
#!/bin/bash
|
||||
|
||||
# Development startup script for Sereact API
|
||||
# This script sets the environment variables and starts the application
|
||||
# Auto-generated by deployment/scripts/setup_local_env.sh
|
||||
|
||||
# Activate virtual environment
|
||||
source venv/Scripts/activate
|
||||
|
||||
# Set environment variables from deployed infrastructure
|
||||
export QDRANT_HOST=$QDRANT_HOST
|
||||
export QDRANT_PORT=$QDRANT_PORT
|
||||
export FIRESTORE_PROJECT_ID=$FIRESTORE_PROJECT_ID
|
||||
export GCS_BUCKET_NAME=$GCS_BUCKET_NAME
|
||||
export ENVIRONMENT=development
|
||||
|
||||
# Start the application
|
||||
echo "Starting Sereact API with deployed infrastructure..."
|
||||
echo "Qdrant endpoint: http://\$QDRANT_HOST:\$QDRANT_PORT"
|
||||
echo "Firestore project: \$FIRESTORE_PROJECT_ID"
|
||||
echo "GCS bucket: \$GCS_BUCKET_NAME"
|
||||
echo "API will be available at: http://localhost:8000"
|
||||
echo "API documentation: http://localhost:8000/docs"
|
||||
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
EOF
|
||||
|
||||
chmod +x start_dev.sh
|
||||
|
||||
# Update docker-compose.yml with extracted values
|
||||
echo "Updating docker-compose.yml..."
|
||||
|
||||
cat > docker-compose.yml << EOF
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
api:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- .:/app
|
||||
- \${GOOGLE_APPLICATION_CREDENTIALS:-./firestore-credentials.json}:/app/firestore-credentials.json:ro
|
||||
environment:
|
||||
- PYTHONUNBUFFERED=1
|
||||
- ENVIRONMENT=development
|
||||
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
|
||||
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
|
||||
- FIRESTORE_PROJECT_ID=\${FIRESTORE_PROJECT_ID:-$FIRESTORE_PROJECT_ID}
|
||||
- QDRANT_HOST=$QDRANT_HOST
|
||||
- QDRANT_PORT=$QDRANT_PORT
|
||||
- GCS_BUCKET_NAME=$GCS_BUCKET_NAME
|
||||
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
EOF
|
||||
|
||||
# Update test script
|
||||
echo "Updating test_qdrant_connection.py..."
|
||||
|
||||
cat > test_qdrant_connection.py << EOF
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple test script to verify Qdrant connection
|
||||
Auto-generated by deployment/scripts/setup_local_env.sh
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from src.services.vector_db import VectorDatabaseService
|
||||
|
||||
def test_qdrant_connection():
|
||||
"""Test the connection to Qdrant"""
|
||||
|
||||
# Set environment variables from deployed infrastructure
|
||||
os.environ['QDRANT_HOST'] = '$QDRANT_HOST'
|
||||
os.environ['QDRANT_PORT'] = '$QDRANT_PORT'
|
||||
|
||||
try:
|
||||
print("Testing Qdrant connection...")
|
||||
print(f"Host: {os.environ['QDRANT_HOST']}")
|
||||
print(f"Port: {os.environ['QDRANT_PORT']}")
|
||||
|
||||
# Initialize the service
|
||||
vector_db = VectorDatabaseService()
|
||||
|
||||
# Test health check
|
||||
is_healthy = vector_db.health_check()
|
||||
print(f"Health check: {'✓ PASSED' if is_healthy else '✗ FAILED'}")
|
||||
|
||||
# Get collection info
|
||||
collection_info = vector_db.get_collection_info()
|
||||
print(f"Collection info: {collection_info}")
|
||||
|
||||
print("\n✓ Qdrant connection test PASSED!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Qdrant connection test FAILED: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_qdrant_connection()
|
||||
sys.exit(0 if success else 1)
|
||||
EOF
|
||||
|
||||
echo ""
|
||||
echo "✓ Local development environment setup complete!"
|
||||
echo ""
|
||||
echo "You can now:"
|
||||
echo " 1. Run './start_dev.sh' to start the API with deployed infrastructure"
|
||||
echo " 2. Run 'docker-compose up' to use Docker with deployed Qdrant"
|
||||
echo " 3. Run 'python test_qdrant_connection.py' to test Qdrant connection"
|
||||
echo ""
|
||||
echo "All configuration has been automatically extracted from your terraform deployment."
|
||||
@ -52,6 +52,12 @@ resource "google_cloud_run_service" "sereact" {
|
||||
name = "sereact"
|
||||
location = var.region
|
||||
|
||||
metadata {
|
||||
annotations = {
|
||||
"run.googleapis.com/ingress" = "all"
|
||||
}
|
||||
}
|
||||
|
||||
template {
|
||||
spec {
|
||||
containers {
|
||||
@ -75,13 +81,8 @@ resource "google_cloud_run_service" "sereact" {
|
||||
}
|
||||
|
||||
env {
|
||||
name = "FIRESTORE_CREDENTIALS_FILE"
|
||||
value = "/var/secrets/google/key.json"
|
||||
}
|
||||
|
||||
env {
|
||||
name = "GOOGLE_APPLICATION_CREDENTIALS"
|
||||
value = "/var/secrets/google/key.json"
|
||||
name = "FIRESTORE_DATABASE_NAME"
|
||||
value = var.firestore_db_name
|
||||
}
|
||||
|
||||
env {
|
||||
@ -99,6 +100,21 @@ resource "google_cloud_run_service" "sereact" {
|
||||
value = var.vector_db_index_name
|
||||
}
|
||||
|
||||
env {
|
||||
name = "QDRANT_HOST"
|
||||
value = var.use_static_ip ? google_compute_address.vector_db_static_ip.address : google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip
|
||||
}
|
||||
|
||||
env {
|
||||
name = "QDRANT_PORT"
|
||||
value = "6333"
|
||||
}
|
||||
|
||||
env {
|
||||
name = "QDRANT_API_KEY"
|
||||
value = var.qdrant_api_key
|
||||
}
|
||||
|
||||
env {
|
||||
name = "LOG_LEVEL"
|
||||
value = "INFO"
|
||||
@ -109,7 +125,6 @@ resource "google_cloud_run_service" "sereact" {
|
||||
metadata {
|
||||
annotations = {
|
||||
"autoscaling.knative.dev/maxScale" = "10"
|
||||
"run.googleapis.com/ingress" = "all"
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -119,7 +134,7 @@ resource "google_cloud_run_service" "sereact" {
|
||||
latest_revision = true
|
||||
}
|
||||
|
||||
depends_on = [google_project_service.services]
|
||||
depends_on = [google_project_service.services, google_compute_instance.vector_db_vm]
|
||||
}
|
||||
|
||||
# Make the Cloud Run service publicly accessible
|
||||
|
||||
@ -48,3 +48,19 @@ output "qdrant_grpc_endpoint" {
|
||||
value = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6334"
|
||||
description = "The gRPC endpoint for Qdrant vector database"
|
||||
}
|
||||
|
||||
# Cloud Run environment configuration
|
||||
output "cloud_run_qdrant_host" {
|
||||
value = var.use_static_ip ? google_compute_address.vector_db_static_ip.address : google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip
|
||||
description = "The Qdrant host configured for Cloud Run"
|
||||
}
|
||||
|
||||
output "deployment_summary" {
|
||||
value = {
|
||||
cloud_run_url = google_cloud_run_service.sereact.status[0].url
|
||||
qdrant_endpoint = "http://${google_compute_instance.vector_db_vm.network_interface[0].access_config[0].nat_ip}:6333"
|
||||
firestore_database = var.firestore_db_name
|
||||
storage_bucket = var.storage_bucket_name
|
||||
}
|
||||
description = "Summary of deployed resources"
|
||||
}
|
||||
@ -4,11 +4,6 @@ locals {
|
||||
cloud_function_service_account = var.cloud_function_service_account != "" ? var.cloud_function_service_account : "${data.google_project.current.number}-compute@developer.gserviceaccount.com"
|
||||
}
|
||||
|
||||
# Get current project information
|
||||
data "google_project" "current" {
|
||||
project_id = var.project_id
|
||||
}
|
||||
|
||||
# Pub/Sub topic for image processing tasks
|
||||
resource "google_pubsub_topic" "image_processing" {
|
||||
name = var.pubsub_topic_name
|
||||
@ -31,10 +26,10 @@ resource "google_pubsub_subscription" "image_processing" {
|
||||
maximum_backoff = "600s"
|
||||
}
|
||||
|
||||
# Dead letter policy after 3 failed attempts
|
||||
# Dead letter policy after 5 failed attempts
|
||||
dead_letter_policy {
|
||||
dead_letter_topic = google_pubsub_topic.image_processing_dlq.id
|
||||
max_delivery_attempts = 3
|
||||
max_delivery_attempts = 5
|
||||
}
|
||||
|
||||
# Message retention
|
||||
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -23,7 +23,7 @@ variable "storage_bucket_name" {
|
||||
variable "firestore_db_name" {
|
||||
description = "The name of the Firestore database"
|
||||
type = string
|
||||
default = "imagedb"
|
||||
default = "sereact-imagedb"
|
||||
}
|
||||
|
||||
variable "environment" {
|
||||
|
||||
@ -13,5 +13,8 @@ services:
|
||||
- ENVIRONMENT=development
|
||||
- FIRESTORE_CREDENTIALS_FILE=/app/firestore-credentials.json
|
||||
- GOOGLE_APPLICATION_CREDENTIALS=/app/firestore-credentials.json
|
||||
- FIRESTORE_PROJECT_ID=${FIRESTORE_PROJECT_ID:-}
|
||||
- FIRESTORE_PROJECT_ID=${FIRESTORE_PROJECT_ID:-gen-lang-client-0424120530}
|
||||
- QDRANT_HOST=34.171.134.17
|
||||
- QDRANT_PORT=6333
|
||||
- GCS_BUCKET_NAME=sereact-images
|
||||
command: uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
@ -1,2 +1 @@
|
||||
from src.api.v1 import teams, auth
|
||||
from src.api.v1 import users, images, search
|
||||
# API v1 package
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
from fastapi import APIRouter, Depends, Query, Request, HTTPException
|
||||
|
||||
from src.api.v1.auth import get_current_user
|
||||
from src.db.repositories.image_repository import image_repository
|
||||
from src.services.vector_store import VectorStoreService
|
||||
from src.services.vector_db import VectorDatabaseService
|
||||
from src.services.embedding_service import EmbeddingService
|
||||
from src.db.repositories.image_repository import image_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
from src.models.user import UserModel
|
||||
from src.schemas.image import ImageResponse
|
||||
from src.schemas.search import SearchResponse, SearchRequest
|
||||
@ -16,7 +17,7 @@ logger = logging.getLogger(__name__)
|
||||
router = APIRouter(tags=["Search"], prefix="/search")
|
||||
|
||||
# Initialize services
|
||||
vector_store_service = VectorStoreService()
|
||||
vector_db_service = VectorDatabaseService()
|
||||
embedding_service = EmbeddingService()
|
||||
|
||||
@router.get("", response_model=SearchResponse)
|
||||
@ -51,11 +52,11 @@ async def search_images(
|
||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||
|
||||
# Search in vector database
|
||||
search_results = await vector_store_service.search_similar(
|
||||
query_embedding,
|
||||
search_results = vector_db_service.search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
limit=limit,
|
||||
threshold=threshold,
|
||||
team_id=str(current_user.team_id)
|
||||
score_threshold=threshold,
|
||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||
)
|
||||
|
||||
if not search_results:
|
||||
@ -68,8 +69,8 @@ async def search_images(
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['id'] for result in search_results]
|
||||
scores = {result['id']: result['score'] for result in search_results}
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||
|
||||
# Get image metadata from database
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
@ -155,11 +156,11 @@ async def search_images_advanced(
|
||||
raise HTTPException(status_code=400, detail="Failed to generate search embedding")
|
||||
|
||||
# Search in vector database
|
||||
search_results = await vector_store_service.search_similar(
|
||||
query_embedding,
|
||||
search_results = vector_db_service.search_similar_images(
|
||||
query_vector=query_embedding,
|
||||
limit=search_request.limit,
|
||||
threshold=search_request.threshold,
|
||||
team_id=str(current_user.team_id)
|
||||
score_threshold=search_request.threshold,
|
||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||
)
|
||||
|
||||
if not search_results:
|
||||
@ -172,8 +173,8 @@ async def search_images_advanced(
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['id'] for result in search_results]
|
||||
scores = {result['id']: result['score'] for result in search_results}
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||
|
||||
# Get image metadata from database
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
@ -288,20 +289,22 @@ async def find_similar_images(
|
||||
raise HTTPException(status_code=400, detail="Reference image does not have embeddings")
|
||||
|
||||
# Get the embedding for the reference image
|
||||
reference_embedding = await vector_store_service.get_embedding(reference_image.embedding_id)
|
||||
if not reference_embedding:
|
||||
reference_data = vector_db_service.get_image_vector(image_id)
|
||||
if not reference_data or not reference_data.get('vector'):
|
||||
raise HTTPException(status_code=400, detail="Failed to get reference image embedding")
|
||||
|
||||
reference_embedding = reference_data['vector']
|
||||
|
||||
# Search for similar images
|
||||
search_results = await vector_store_service.search_similar(
|
||||
reference_embedding,
|
||||
search_results = vector_db_service.search_similar_images(
|
||||
query_vector=reference_embedding,
|
||||
limit=limit + 1, # +1 to account for the reference image itself
|
||||
threshold=threshold,
|
||||
team_id=str(current_user.team_id)
|
||||
score_threshold=threshold,
|
||||
filter_conditions={"team_id": str(current_user.team_id)} if current_user.team_id else None
|
||||
)
|
||||
|
||||
# Remove the reference image from results
|
||||
search_results = [result for result in search_results if result['id'] != image_id][:limit]
|
||||
search_results = [result for result in search_results if result['image_id'] != image_id][:limit]
|
||||
|
||||
if not search_results:
|
||||
return SearchResponse(
|
||||
@ -313,8 +316,8 @@ async def find_similar_images(
|
||||
)
|
||||
|
||||
# Get image IDs and scores from search results
|
||||
image_ids = [result['id'] for result in search_results]
|
||||
scores = {result['id']: result['score'] for result in search_results}
|
||||
image_ids = [result['image_id'] for result in search_results if result['image_id']]
|
||||
scores = {result['image_id']: result['score'] for result in search_results if result['image_id']}
|
||||
|
||||
# Get image metadata from database
|
||||
images = await image_repository.get_by_ids(image_ids)
|
||||
|
||||
@ -20,7 +20,7 @@ class VectorDatabaseService:
|
||||
def __init__(
|
||||
self,
|
||||
host: str = None,
|
||||
port: int = 6333,
|
||||
port: int = None,
|
||||
api_key: str = None,
|
||||
collection_name: str = "image_vectors"
|
||||
):
|
||||
@ -34,7 +34,7 @@ class VectorDatabaseService:
|
||||
collection_name: Name of the collection to use
|
||||
"""
|
||||
self.host = host or os.getenv("QDRANT_HOST", "localhost")
|
||||
self.port = port
|
||||
self.port = port or int(os.getenv("QDRANT_PORT", "6333"))
|
||||
self.api_key = api_key or os.getenv("QDRANT_API_KEY")
|
||||
self.collection_name = collection_name
|
||||
|
||||
|
||||
@ -1,267 +0,0 @@
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
import pinecone
|
||||
from bson import ObjectId
|
||||
|
||||
from src.config.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VectorStoreService:
|
||||
"""Service for managing vector embeddings in Pinecone"""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = settings.VECTOR_DB_API_KEY
|
||||
self.environment = settings.VECTOR_DB_ENVIRONMENT
|
||||
self.index_name = settings.VECTOR_DB_INDEX_NAME
|
||||
self.dimension = 512 # CLIP model embedding dimension
|
||||
self.initialized = False
|
||||
self.index = None
|
||||
|
||||
def initialize(self):
|
||||
"""
|
||||
Initialize Pinecone connection and create index if needed
|
||||
"""
|
||||
if self.initialized:
|
||||
return
|
||||
|
||||
if not self.api_key or not self.environment:
|
||||
logger.warning("Pinecone API key or environment not provided, vector search disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info(f"Initializing Pinecone with environment {self.environment}")
|
||||
|
||||
# Initialize Pinecone
|
||||
pinecone.init(
|
||||
api_key=self.api_key,
|
||||
environment=self.environment
|
||||
)
|
||||
|
||||
# Check if index exists
|
||||
if self.index_name not in pinecone.list_indexes():
|
||||
logger.info(f"Creating Pinecone index: {self.index_name}")
|
||||
|
||||
# Create index
|
||||
pinecone.create_index(
|
||||
name=self.index_name,
|
||||
dimension=self.dimension,
|
||||
metric="cosine"
|
||||
)
|
||||
|
||||
# Connect to index
|
||||
self.index = pinecone.Index(self.index_name)
|
||||
self.initialized = True
|
||||
logger.info("Pinecone initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Pinecone: {e}")
|
||||
# Don't raise - we want to gracefully handle this and fall back to non-vector search
|
||||
|
||||
def store_embedding(self, image_id: str, team_id: str, embedding: List[float], metadata: Dict[str, Any] = None) -> Optional[str]:
|
||||
"""
|
||||
Store an embedding in Pinecone
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
team_id: Team ID
|
||||
embedding: Image embedding
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
Vector ID if successful, None otherwise
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
if not self.initialized:
|
||||
logger.warning("Pinecone not initialized, cannot store embedding")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Create metadata dict
|
||||
meta = {
|
||||
"image_id": image_id,
|
||||
"team_id": team_id
|
||||
}
|
||||
|
||||
if metadata:
|
||||
meta.update(metadata)
|
||||
|
||||
# Create a unique vector ID
|
||||
vector_id = f"{team_id}_{image_id}"
|
||||
|
||||
# Upsert the vector
|
||||
self.index.upsert(
|
||||
vectors=[
|
||||
(vector_id, embedding, meta)
|
||||
]
|
||||
)
|
||||
|
||||
logger.info(f"Stored embedding for image {image_id}")
|
||||
return vector_id
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing embedding: {e}")
|
||||
return None
|
||||
|
||||
async def search_similar(
|
||||
self,
|
||||
query_embedding: List[float],
|
||||
limit: int = 10,
|
||||
threshold: float = 0.7,
|
||||
team_id: str = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for similar embeddings
|
||||
|
||||
Args:
|
||||
query_embedding: Query embedding vector
|
||||
limit: Maximum number of results
|
||||
threshold: Similarity threshold (0-1)
|
||||
team_id: Optional team filter
|
||||
|
||||
Returns:
|
||||
List of results with id and score
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
if not self.initialized:
|
||||
logger.warning("Pinecone not initialized, cannot search")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Create filter for team_id if provided
|
||||
filter_dict = None
|
||||
if team_id:
|
||||
filter_dict = {
|
||||
"team_id": {"$eq": team_id}
|
||||
}
|
||||
|
||||
# Query the index
|
||||
results = self.index.query(
|
||||
vector=query_embedding,
|
||||
filter=filter_dict,
|
||||
top_k=limit,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
# Format the results and apply threshold
|
||||
formatted_results = []
|
||||
for match in results.matches:
|
||||
if match.score >= threshold:
|
||||
formatted_results.append({
|
||||
"id": match.metadata.get("image_id", match.id),
|
||||
"score": match.score,
|
||||
"metadata": match.metadata
|
||||
})
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching similar embeddings: {e}")
|
||||
return []
|
||||
|
||||
async def get_embedding(self, embedding_id: str) -> Optional[List[float]]:
|
||||
"""
|
||||
Get embedding by ID
|
||||
|
||||
Args:
|
||||
embedding_id: Embedding ID
|
||||
|
||||
Returns:
|
||||
Embedding vector or None if not found
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
if not self.initialized:
|
||||
logger.warning("Pinecone not initialized, cannot get embedding")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Fetch the vector
|
||||
result = self.index.fetch(ids=[embedding_id])
|
||||
|
||||
if embedding_id in result.vectors:
|
||||
return result.vectors[embedding_id].values
|
||||
else:
|
||||
logger.warning(f"Embedding not found: {embedding_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting embedding: {e}")
|
||||
return None
|
||||
|
||||
def search_by_embedding(self, team_id: str, query_embedding: List[float], limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for similar images by embedding
|
||||
|
||||
Args:
|
||||
team_id: Team ID
|
||||
query_embedding: Query embedding
|
||||
limit: Maximum number of results
|
||||
|
||||
Returns:
|
||||
List of results with image ID and similarity score
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
if not self.initialized:
|
||||
logger.warning("Pinecone not initialized, cannot search by embedding")
|
||||
return []
|
||||
|
||||
try:
|
||||
# Create filter for team_id
|
||||
filter_dict = {
|
||||
"team_id": {"$eq": team_id}
|
||||
}
|
||||
|
||||
# Query the index
|
||||
results = self.index.query(
|
||||
vector=query_embedding,
|
||||
filter=filter_dict,
|
||||
top_k=limit,
|
||||
include_metadata=True
|
||||
)
|
||||
|
||||
# Format the results
|
||||
formatted_results = []
|
||||
for match in results.matches:
|
||||
formatted_results.append({
|
||||
"image_id": match.metadata["image_id"],
|
||||
"score": match.score,
|
||||
"metadata": match.metadata
|
||||
})
|
||||
|
||||
return formatted_results
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching by embedding: {e}")
|
||||
return []
|
||||
|
||||
def delete_embedding(self, image_id: str, team_id: str) -> bool:
|
||||
"""
|
||||
Delete an embedding from Pinecone
|
||||
|
||||
Args:
|
||||
image_id: Image ID
|
||||
team_id: Team ID
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
self.initialize()
|
||||
|
||||
if not self.initialized:
|
||||
logger.warning("Pinecone not initialized, cannot delete embedding")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Create the vector ID
|
||||
vector_id = f"{team_id}_{image_id}"
|
||||
|
||||
# Delete the vector
|
||||
self.index.delete(ids=[vector_id])
|
||||
|
||||
logger.info(f"Deleted embedding for image {image_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting embedding: {e}")
|
||||
return False
|
||||
|
||||
# Create a singleton service
|
||||
vector_store_service = VectorStoreService()
|
||||
25
start_dev.sh
Normal file
25
start_dev.sh
Normal file
@ -0,0 +1,25 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Development startup script for Sereact API
|
||||
# This script sets the environment variables and starts the application
|
||||
# Auto-generated by deployment/scripts/setup_local_env.sh
|
||||
|
||||
# Activate virtual environment
|
||||
source venv/Scripts/activate
|
||||
|
||||
# Set environment variables from deployed infrastructure
|
||||
export QDRANT_HOST=34.171.134.17
|
||||
export QDRANT_PORT=6333
|
||||
export FIRESTORE_PROJECT_ID=gen-lang-client-0424120530
|
||||
export GCS_BUCKET_NAME=sereact-images
|
||||
export ENVIRONMENT=development
|
||||
|
||||
# Start the application
|
||||
echo "Starting Sereact API with deployed infrastructure..."
|
||||
echo "Qdrant endpoint: http://$QDRANT_HOST:$QDRANT_PORT"
|
||||
echo "Firestore project: $FIRESTORE_PROJECT_ID"
|
||||
echo "GCS bucket: $GCS_BUCKET_NAME"
|
||||
echo "API will be available at: http://localhost:8000"
|
||||
echo "API documentation: http://localhost:8000/docs"
|
||||
|
||||
uvicorn main:app --host 0.0.0.0 --port 8000 --reload
|
||||
43
test_qdrant_connection.py
Normal file
43
test_qdrant_connection.py
Normal file
@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple test script to verify Qdrant connection
|
||||
Auto-generated by deployment/scripts/setup_local_env.sh
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from src.services.vector_db import VectorDatabaseService
|
||||
|
||||
def test_qdrant_connection():
|
||||
"""Test the connection to Qdrant"""
|
||||
|
||||
# Set environment variables from deployed infrastructure
|
||||
os.environ['QDRANT_HOST'] = '34.171.134.17'
|
||||
os.environ['QDRANT_PORT'] = '6333'
|
||||
|
||||
try:
|
||||
print("Testing Qdrant connection...")
|
||||
print(f"Host: {os.environ['QDRANT_HOST']}")
|
||||
print(f"Port: {os.environ['QDRANT_PORT']}")
|
||||
|
||||
# Initialize the service
|
||||
vector_db = VectorDatabaseService()
|
||||
|
||||
# Test health check
|
||||
is_healthy = vector_db.health_check()
|
||||
print(f"Health check: {'✓ PASSED' if is_healthy else '✗ FAILED'}")
|
||||
|
||||
# Get collection info
|
||||
collection_info = vector_db.get_collection_info()
|
||||
print(f"Collection info: {collection_info}")
|
||||
|
||||
print("\n✓ Qdrant connection test PASSED!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n✗ Qdrant connection test FAILED: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = test_qdrant_connection()
|
||||
sys.exit(0 if success else 1)
|
||||
@ -288,30 +288,25 @@ class TestEmbeddingService:
|
||||
class TestEmbeddingServiceIntegration:
|
||||
"""Integration tests for embedding service with other components"""
|
||||
|
||||
def test_embedding_to_vector_store_integration(self, embedding_service, sample_image_data, sample_image_model):
|
||||
"""Test integration with vector store service"""
|
||||
mock_embedding = np.random.rand(512).tolist()
|
||||
def test_embedding_to_vector_db_integration(self, embedding_service, sample_image_data, sample_image_model):
|
||||
"""Test integration between embedding service and vector database"""
|
||||
# Mock the vector database service
|
||||
with patch('src.services.vector_db.VectorDatabaseService') as mock_vector_db:
|
||||
# Setup mock
|
||||
mock_store = mock_vector_db.return_value
|
||||
mock_store.add_image_vector.return_value = "test_point_id"
|
||||
|
||||
with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate, \
|
||||
patch('src.services.vector_store.VectorStoreService') as mock_vector_store:
|
||||
|
||||
mock_generate.return_value = {
|
||||
'embedding': mock_embedding,
|
||||
'metadata': {'labels': [{'description': 'cat', 'score': 0.95}]},
|
||||
'model': 'clip'
|
||||
}
|
||||
|
||||
mock_store = mock_vector_store.return_value
|
||||
mock_store.store_embedding.return_value = 'embedding_id_123'
|
||||
|
||||
# Process image and store embedding
|
||||
result = embedding_service.process_and_store_image(
|
||||
sample_image_data, sample_image_model
|
||||
# Test storing embedding
|
||||
embedding = [0.1] * 512 # Mock embedding
|
||||
point_id = mock_store.add_image_vector(
|
||||
image_id=str(sample_image_model.id),
|
||||
vector=embedding,
|
||||
metadata={"filename": sample_image_model.filename}
|
||||
)
|
||||
|
||||
# Verify integration
|
||||
assert result['embedding_id'] == 'embedding_id_123'
|
||||
mock_store.store_embedding.assert_called_once()
|
||||
# Verify the call
|
||||
mock_store.add_image_vector.assert_called_once()
|
||||
assert point_id == "test_point_id"
|
||||
|
||||
def test_pubsub_trigger_integration(self, embedding_service):
|
||||
"""Test integration with Pub/Sub message processing"""
|
||||
|
||||
@ -1,391 +0,0 @@
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from bson import ObjectId
|
||||
|
||||
from src.services.vector_store import VectorStoreService
|
||||
from src.models.image import ImageModel
|
||||
|
||||
|
||||
class TestVectorStoreService:
    """Test vector store operations for semantic search.

    All tests run against a fully mocked Pinecone index (see the
    ``mock_pinecone_index`` fixture): no network calls are made, and the
    assertions verify how ``VectorStoreService`` drives the index API
    (``upsert``/``query``/``delete``/``describe_index_stats``).
    """

    @pytest.fixture
    def mock_pinecone_index(self):
        """Mock Pinecone index for testing"""
        # Explicitly stub each index method used by the service so call
        # assertions (assert_called_once, call_args) are available.
        mock_index = MagicMock()
        mock_index.upsert = MagicMock()
        mock_index.query = MagicMock()
        mock_index.delete = MagicMock()
        mock_index.describe_index_stats = MagicMock()
        return mock_index

    @pytest.fixture
    def vector_store_service(self, mock_pinecone_index):
        """Create vector store service with mocked dependencies"""
        # Patch the pinecone module used inside the service so the
        # constructor never touches a real Pinecone deployment, then
        # overwrite .index directly as a belt-and-braces measure.
        with patch('src.services.vector_store.pinecone') as mock_pinecone:
            mock_pinecone.Index.return_value = mock_pinecone_index
            service = VectorStoreService()
            service.index = mock_pinecone_index
            return service

    @pytest.fixture
    def sample_embedding(self):
        """Generate a sample embedding vector"""
        # NOTE(review): unseeded randomness — fine here because tests only
        # compare lengths / pass the vector through, never its values.
        return np.random.rand(512).tolist()  # 512-dimensional vector

    @pytest.fixture
    def sample_image(self):
        """Create a sample image model"""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId(),
            description="A test image",
            tags=["test", "image"]
        )

    def test_store_embedding(self, vector_store_service, sample_embedding, sample_image):
        """Test storing an embedding in the vector database"""
        # Store the embedding
        embedding_id = vector_store_service.store_embedding(
            image_id=str(sample_image.id),
            embedding=sample_embedding,
            metadata={
                "filename": sample_image.filename,
                "team_id": str(sample_image.team_id),
                "tags": sample_image.tags,
                "description": sample_image.description
            }
        )

        # Verify the embedding was stored
        assert embedding_id is not None
        vector_store_service.index.upsert.assert_called_once()

        # Check the upsert call arguments
        # call_args[1] is the kwargs dict; the service is expected to call
        # upsert(vectors=[{id, values, metadata}]).
        call_args = vector_store_service.index.upsert.call_args
        vectors = call_args[1]['vectors']
        assert len(vectors) == 1
        assert vectors[0]['id'] == embedding_id
        assert len(vectors[0]['values']) == len(sample_embedding)
        assert 'metadata' in vectors[0]

    def test_search_similar_images(self, vector_store_service, sample_embedding):
        """Test searching for similar images using vector similarity"""
        # Mock search results
        # Canned Pinecone-style query response: 'matches' sorted by score.
        mock_results = {
            'matches': [
                {
                    'id': 'embedding1',
                    'score': 0.95,
                    'metadata': {
                        'image_id': str(ObjectId()),
                        'filename': 'similar1.jpg',
                        'team_id': str(ObjectId()),
                        'tags': ['cat', 'animal']
                    }
                },
                {
                    'id': 'embedding2',
                    'score': 0.87,
                    'metadata': {
                        'image_id': str(ObjectId()),
                        'filename': 'similar2.jpg',
                        'team_id': str(ObjectId()),
                        'tags': ['dog', 'animal']
                    }
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Perform search
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
            score_threshold=0.8
        )

        # Verify search was performed
        vector_store_service.index.query.assert_called_once()

        # Check results
        # Both mocked matches clear the 0.8 threshold, so both come back,
        # flattened so that metadata's image_id is reachable per result.
        assert len(results) == 2
        assert results[0]['score'] == 0.95
        assert results[1]['score'] == 0.87
        assert all('image_id' in result for result in results)

    def test_search_with_filters(self, vector_store_service, sample_embedding):
        """Test searching with metadata filters"""
        team_id = str(ObjectId())

        # Perform search with team filter
        vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=team_id,
            top_k=5,
            filters={"tags": {"$in": ["cat", "dog"]}}
        )

        # Verify filter was applied
        # The service must merge team_id into the metadata filter passed
        # to index.query(filter=...) — this is the tenant-isolation hook.
        call_args = vector_store_service.index.query.call_args
        assert 'filter' in call_args[1]
        assert call_args[1]['filter']['team_id'] == team_id

    def test_delete_embedding(self, vector_store_service):
        """Test deleting an embedding from the vector database"""
        embedding_id = "test-embedding-123"

        # Delete the embedding
        success = vector_store_service.delete_embedding(embedding_id)

        # Verify deletion was attempted
        # Single deletes are expected to go through the batched API
        # (ids=[...]) rather than a scalar form.
        vector_store_service.index.delete.assert_called_once_with(ids=[embedding_id])
        assert success is True

    def test_batch_store_embeddings(self, vector_store_service, sample_embedding):
        """Test storing multiple embeddings in batch"""
        # Create batch data
        batch_data = []
        for i in range(5):
            batch_data.append({
                'image_id': str(ObjectId()),
                'embedding': sample_embedding,
                'metadata': {
                    'filename': f'image{i}.jpg',
                    'team_id': str(ObjectId()),
                    'tags': [f'tag{i}']
                }
            })

        # Store batch
        embedding_ids = vector_store_service.batch_store_embeddings(batch_data)

        # Verify batch storage
        # One upsert call for the whole batch — guards against a per-item
        # upsert loop sneaking into the implementation.
        assert len(embedding_ids) == 5
        vector_store_service.index.upsert.assert_called_once()

        # Check batch upsert call
        call_args = vector_store_service.index.upsert.call_args
        vectors = call_args[1]['vectors']
        assert len(vectors) == 5

    def test_get_index_stats(self, vector_store_service):
        """Test getting vector database statistics"""
        # Mock stats response
        mock_stats = {
            'total_vector_count': 1000,
            'dimension': 512,
            'index_fullness': 0.1
        }
        vector_store_service.index.describe_index_stats.return_value = mock_stats

        # Get stats
        stats = vector_store_service.get_index_stats()

        # Verify stats retrieval
        vector_store_service.index.describe_index_stats.assert_called_once()
        assert stats['total_vector_count'] == 1000
        assert stats['dimension'] == 512

    def test_search_with_score_threshold(self, vector_store_service, sample_embedding):
        """Test filtering search results by score threshold"""
        # Mock results with varying scores
        mock_results = {
            'matches': [
                {'id': 'emb1', 'score': 0.95, 'metadata': {'image_id': '1'}},
                {'id': 'emb2', 'score': 0.75, 'metadata': {'image_id': '2'}},
                {'id': 'emb3', 'score': 0.65, 'metadata': {'image_id': '3'}},
                {'id': 'emb4', 'score': 0.45, 'metadata': {'image_id': '4'}}
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Search with score threshold
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
            score_threshold=0.7
        )

        # Only results above threshold should be returned
        # (client-side threshold: 0.95 and 0.75 pass, 0.65 and 0.45 drop)
        assert len(results) == 2
        assert all(result['score'] >= 0.7 for result in results)

    def test_update_embedding_metadata(self, vector_store_service):
        """Test updating metadata for an existing embedding"""
        embedding_id = "test-embedding-123"
        new_metadata = {
            'tags': ['updated', 'tag'],
            'description': 'Updated description'
        }

        # Update metadata
        success = vector_store_service.update_embedding_metadata(
            embedding_id, new_metadata
        )

        # Verify update was attempted
        # This would depend on the actual implementation
        # NOTE(review): only the return value is asserted — no mock-call
        # verification, so this test passes even if nothing is written.
        # Consider asserting the underlying index call once the update
        # mechanism (upsert vs. dedicated update API) is settled.
        assert success is True

    def test_search_by_image_id(self, vector_store_service):
        """Test searching for a specific image's embedding"""
        image_id = str(ObjectId())

        # Mock search by metadata
        mock_results = {
            'matches': [
                {
                    'id': 'embedding1',
                    'score': 1.0,
                    'metadata': {
                        'image_id': image_id,
                        'filename': 'target.jpg'
                    }
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Search by image ID
        result = vector_store_service.get_embedding_by_image_id(image_id)

        # Verify result
        assert result is not None
        assert result['metadata']['image_id'] == image_id

    def test_bulk_delete_embeddings(self, vector_store_service):
        """Test deleting multiple embeddings"""
        embedding_ids = ['emb1', 'emb2', 'emb3']

        # Delete multiple embeddings
        success = vector_store_service.bulk_delete_embeddings(embedding_ids)

        # Verify bulk deletion
        vector_store_service.index.delete.assert_called_once_with(ids=embedding_ids)
        assert success is True

    def test_search_pagination(self, vector_store_service, sample_embedding):
        """Test paginated search results"""
        # This would test pagination if implemented
        # Implementation depends on how pagination is handled in the vector store
        # NOTE(review): empty stub — it silently reports as passing. Mark
        # with pytest.mark.skip/xfail (or implement) so coverage reports
        # don't count it as a real test.
        pass

    def test_vector_dimension_validation(self, vector_store_service):
        """Test validation of embedding dimensions"""
        # Test with wrong dimension
        # Index is 512-dim (see sample_embedding); 256 must be rejected
        # before any upsert is attempted.
        wrong_dimension_embedding = np.random.rand(256).tolist()  # Wrong size

        with pytest.raises(ValueError):
            vector_store_service.store_embedding(
                image_id=str(ObjectId()),
                embedding=wrong_dimension_embedding,
                metadata={}
            )

    def test_connection_error_handling(self, vector_store_service):
        """Test handling of connection errors"""
        # Mock connection error
        vector_store_service.index.query.side_effect = Exception("Connection failed")

        # Search should handle the error gracefully
        # NOTE(review): pytest.raises(Exception) matches any exception —
        # tighten to the service's specific error type when one exists.
        with pytest.raises(Exception):
            vector_store_service.search_similar(
                query_embedding=[0.1] * 512,
                team_id=str(ObjectId()),
                top_k=10
            )

    def test_empty_search_results(self, vector_store_service, sample_embedding):
        """Test handling of empty search results"""
        # Mock empty results
        vector_store_service.index.query.return_value = {'matches': []}

        # Search should return empty list
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10
        )

        assert results == []
|
||||
|
||||
|
||||
class TestVectorStoreIntegration:
    """Integration tests for vector store with other services.

    Fix: the tests below request the ``vector_store_service``,
    ``sample_embedding`` and ``sample_image`` fixtures, but those were
    defined only as methods of ``TestVectorStoreService``. pytest does not
    share class-scoped fixtures across sibling classes, so collection of
    this class would fail with "fixture ... not found" (unless an unseen
    conftest.py also defines them — in that case these class-level copies
    simply shadow it with identical definitions). The fixtures are
    duplicated here verbatim to make the class self-contained.
    """

    @pytest.fixture
    def mock_pinecone_index(self):
        """Mock Pinecone index for testing"""
        # Stub each index method the service calls so call assertions work.
        mock_index = MagicMock()
        mock_index.upsert = MagicMock()
        mock_index.query = MagicMock()
        mock_index.delete = MagicMock()
        mock_index.describe_index_stats = MagicMock()
        return mock_index

    @pytest.fixture
    def vector_store_service(self, mock_pinecone_index):
        """Create vector store service with mocked dependencies"""
        # Keep Pinecone fully patched so no network I/O happens.
        with patch('src.services.vector_store.pinecone') as mock_pinecone:
            mock_pinecone.Index.return_value = mock_pinecone_index
            service = VectorStoreService()
            service.index = mock_pinecone_index
            return service

    @pytest.fixture
    def sample_embedding(self):
        """Generate a sample embedding vector"""
        return np.random.rand(512).tolist()  # 512-dimensional vector

    @pytest.fixture
    def sample_image(self):
        """Create a sample image model"""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId(),
            description="A test image",
            tags=["test", "image"]
        )

    def test_embedding_lifecycle(self, vector_store_service, sample_embedding, sample_image):
        """Test complete embedding lifecycle: store, search, update, delete"""
        # Store embedding
        embedding_id = vector_store_service.store_embedding(
            image_id=str(sample_image.id),
            embedding=sample_embedding,
            metadata={'filename': sample_image.filename}
        )

        # Search for similar embeddings
        # Echo back the stored id so the round trip can be asserted.
        mock_results = {
            'matches': [
                {
                    'id': embedding_id,
                    'score': 1.0,
                    'metadata': {'image_id': str(sample_image.id)}
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(sample_image.team_id),
            top_k=1
        )

        assert len(results) == 1
        assert results[0]['id'] == embedding_id

        # Delete embedding
        success = vector_store_service.delete_embedding(embedding_id)
        assert success is True

    def test_team_isolation(self, vector_store_service, sample_embedding):
        """Test that team data is properly isolated"""
        team1_id = str(ObjectId())
        team2_id = str(ObjectId())

        # Mock search results that should be filtered by team
        mock_results = {
            'matches': [
                {
                    'id': 'emb1',
                    'score': 0.9,
                    'metadata': {'image_id': '1', 'team_id': team1_id}
                },
                {
                    'id': 'emb2',
                    'score': 0.8,
                    'metadata': {'image_id': '2', 'team_id': team2_id}
                }
            ]
        }
        vector_store_service.index.query.return_value = mock_results

        # Search for team1 should only return team1 results
        results = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=team1_id,
            top_k=10
        )

        # Verify team filter was applied in the query
        # (server-side filtering: the mock returns both teams' rows, but
        # the query kwargs must carry the team_id filter)
        call_args = vector_store_service.index.query.call_args
        assert 'filter' in call_args[1]
        assert call_args[1]['filter']['team_id'] == team1_id
|
||||
Loading…
x
Reference in New Issue
Block a user