# image_management_api/tests/services/test_embedding_service.py
# 2025-05-24 18:33:51 +02:00
#
# 355 lines
# 16 KiB
# Python

import pytest
import numpy as np
from unittest.mock import patch, MagicMock, AsyncMock
from bson import ObjectId
from io import BytesIO
from src.services.embedding_service import EmbeddingService
from src.models.image import ImageModel
class TestEmbeddingService:
    """Unit tests for EmbeddingService embedding generation and feature extraction.

    The Google Cloud Vision client is replaced with a MagicMock, and most tests
    additionally patch the service's ``_extract_features`` hook so they exercise
    how EmbeddingService handles the (mocked) feature vector.
    """

    @staticmethod
    def _fake_embedding(dim=512, seed=0):
        """Return a reproducible pseudo-random embedding of ``dim`` floats.

        A seeded generator is used instead of ``np.random.rand`` so that a
        failing run can be replayed with exactly the same vector.
        """
        return np.random.default_rng(seed).random(dim).tolist()

    @pytest.fixture
    def mock_vision_client(self):
        """Mock Google Cloud Vision client whose ``annotate_image`` returns empty annotations."""
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.image_properties_annotation.dominant_colors.colors = []
        mock_response.label_annotations = []
        mock_response.object_localizations = []
        mock_response.text_annotations = []
        mock_client.annotate_image.return_value = mock_response
        return mock_client

    @pytest.fixture
    def embedding_service(self, mock_vision_client):
        """Yield an EmbeddingService wired to the mocked Vision client.

        ``yield`` (instead of ``return``) keeps the ``vision`` module patch
        active for the whole test. If the service looks up ``vision.*`` names
        while its methods run, those lookups then also resolve to the mock.
        """
        with patch('src.services.embedding_service.vision') as mock_vision:
            mock_vision.ImageAnnotatorClient.return_value = mock_vision_client
            service = EmbeddingService()
            # Assigned explicitly in case the constructor obtains its client another way.
            service.client = mock_vision_client
            yield service

    @pytest.fixture
    def sample_image_data(self):
        """Return a BytesIO holding a valid 1x1 RGB PNG.

        The PNG is assembled chunk by chunk so every length field and CRC is
        computed rather than hand-typed. The previous hand-typed literal
        declared an IDAT length of 10 for a 12-byte payload, which made it
        unparseable.
        """
        import struct
        import zlib

        def chunk(tag, payload):
            # PNG chunk layout: big-endian length, 4-byte type, data, CRC-32 over type + data.
            crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
            return struct.pack('>I', len(payload)) + tag + payload + struct.pack('>I', crc)

        # IHDR: width=1, height=1, bit depth 8, colour type 2 (RGB), default compression/filter/interlace.
        header = struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0)
        # One scanline: filter byte 0 followed by a single black RGB pixel.
        pixels = zlib.compress(b'\x00' * 4)
        png = (
            b'\x89PNG\r\n\x1a\n'
            + chunk(b'IHDR', header)
            + chunk(b'IDAT', pixels)
            + chunk(b'IEND', b'')
        )
        return BytesIO(png)

    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model"""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )

    def test_generate_embedding_from_image(self, embedding_service, sample_image_data):
        """Test generating embeddings from image data"""
        mock_embedding = self._fake_embedding()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding
            embedding = embedding_service.generate_embedding(sample_image_data)

        assert embedding is not None
        assert len(embedding) == 512
        assert isinstance(embedding, list)
        assert all(isinstance(x, (int, float)) for x in embedding)

    def test_extract_image_features(self, embedding_service, sample_image_data):
        """Test extracting features from images using Vision API"""
        mock_response = MagicMock()
        mock_response.label_annotations = [
            MagicMock(description="cat", score=0.95),
            MagicMock(description="animal", score=0.87),
            MagicMock(description="pet", score=0.82)
        ]
        cat_object = MagicMock(score=0.9)
        # ``name`` is a reserved MagicMock constructor argument (it only sets the
        # mock's repr) and does not create a ``.name`` attribute, so assign it here.
        cat_object.name = "Cat"
        mock_response.object_localizations = [cat_object]
        mock_response.image_properties_annotation.dominant_colors.colors = [
            MagicMock(color=MagicMock(red=255, green=100, blue=50), score=0.8)
        ]
        mock_response.text_annotations = []
        embedding_service.client.annotate_image.return_value = mock_response

        features = embedding_service.extract_image_features(sample_image_data)

        assert 'labels' in features
        assert 'objects' in features
        assert 'colors' in features
        assert len(features['labels']) == 3
        assert features['labels'][0]['description'] == "cat"
        assert features['labels'][0]['score'] == 0.95

    def test_generate_embedding_with_metadata(self, embedding_service, sample_image_data, sample_image_model):
        """Test generating embeddings with image metadata"""
        mock_embedding = self._fake_embedding()
        mock_features = {
            'labels': [{'description': 'cat', 'score': 0.95}],
            'objects': [{'name': 'Cat', 'score': 0.9}],
            'colors': [{'red': 255, 'green': 100, 'blue': 50, 'score': 0.8}]
        }
        with patch.object(embedding_service, '_extract_features') as mock_extract_features, \
                patch.object(embedding_service, 'extract_image_features') as mock_extract_metadata:
            mock_extract_features.return_value = mock_embedding
            mock_extract_metadata.return_value = mock_features
            result = embedding_service.generate_embedding_with_metadata(
                sample_image_data, sample_image_model
            )

        assert 'embedding' in result
        assert 'metadata' in result
        assert 'model' in result
        assert len(result['embedding']) == 512
        assert result['metadata']['labels'][0]['description'] == 'cat'
        # NOTE(review): the expected model identifier is assumed to be 'clip'; confirm against the service.
        assert result['model'] == 'clip'

    def test_batch_generate_embeddings(self, embedding_service):
        """Test generating embeddings for multiple images in batch"""
        image_batch = []
        for i in range(3):
            image_data = BytesIO(b'fake_image_data_' + str(i).encode())
            image_model = ImageModel(
                filename=f"image{i}.jpg",
                original_filename=f"image{i}.jpg",
                file_size=1024,
                content_type="image/jpeg",
                storage_path=f"images/image{i}.jpg",
                team_id=ObjectId(),
                uploader_id=ObjectId()
            )
            image_batch.append((image_data, image_model))

        mock_embedding = self._fake_embedding()
        with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate:
            mock_generate.return_value = {
                'embedding': mock_embedding,
                'metadata': {'labels': []},
                'model': 'clip'
            }
            results = embedding_service.batch_generate_embeddings(image_batch)

        assert len(results) == 3
        assert all('embedding' in result for result in results)
        assert all('metadata' in result for result in results)

    def test_embedding_model_consistency(self, embedding_service, sample_image_data):
        """Test that the same image produces consistent embeddings"""
        mock_embedding = self._fake_embedding()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding
            embedding1 = embedding_service.generate_embedding(sample_image_data)
            sample_image_data.seek(0)  # rewind so the second call reads the same bytes
            embedding2 = embedding_service.generate_embedding(sample_image_data)

        assert embedding1 == embedding2

    def test_embedding_dimension_validation(self, embedding_service):
        """Test that embeddings have the correct dimensions"""
        # validate_embedding_dimensions only inspects its argument, so no patching is needed.
        assert embedding_service.validate_embedding_dimensions(self._fake_embedding(512)) is True
        assert embedding_service.validate_embedding_dimensions(self._fake_embedding(256)) is False

    def test_handle_unsupported_image_format(self, embedding_service):
        """Test handling of unsupported image formats"""
        invalid_data = BytesIO(b'not_an_image')
        with pytest.raises(ValueError):
            embedding_service.generate_embedding(invalid_data)

    def test_handle_corrupted_image(self, embedding_service):
        """Test handling of corrupted image data"""
        # PNG signature followed by garbage instead of a valid IHDR chunk.
        corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')
        # NOTE(review): the concrete exception type is not pinned here; tighten once the service's contract is known.
        with pytest.raises(Exception):
            embedding_service.generate_embedding(corrupted_data)

    def test_vision_api_error_handling(self, embedding_service, sample_image_data):
        """Test handling of Vision API errors"""
        embedding_service.client.annotate_image.side_effect = Exception("Vision API error")
        with pytest.raises(Exception):
            embedding_service.extract_image_features(sample_image_data)

    def test_embedding_caching(self, embedding_service, sample_image_data):
        """Test caching of embeddings for the same image"""
        # Skip instead of `pass` so the placeholder is not reported as a passing test.
        pytest.skip("Embedding caching is not implemented/tested yet")

    def test_embedding_quality_metrics(self, embedding_service, sample_image_data):
        """Test quality metrics for generated embeddings"""
        mock_embedding = self._fake_embedding()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding
            embedding = embedding_service.generate_embedding(sample_image_data)
            quality_score = embedding_service.calculate_embedding_quality(embedding)

        assert 0 <= quality_score <= 1

    def test_different_image_types(self, embedding_service):
        """Test embedding generation for different image types"""
        image_types = [
            ('image/jpeg', b'fake_jpeg_data'),
            ('image/png', b'fake_png_data'),
            ('image/webp', b'fake_webp_data')
        ]
        mock_embedding = self._fake_embedding()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding
            # The content type only labels each case; the service receives raw bytes.
            for _content_type, data in image_types:
                embedding = embedding_service.generate_embedding(BytesIO(data))
                assert len(embedding) == 512

    def test_large_image_handling(self, embedding_service):
        """Test handling of large images"""
        large_image_data = BytesIO(b'x' * (10 * 1024 * 1024))  # 10 MiB of filler bytes
        mock_embedding = self._fake_embedding()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding
            embedding = embedding_service.generate_embedding(large_image_data)

        assert len(embedding) == 512

    def test_embedding_normalization(self, embedding_service, sample_image_data):
        """Test that embeddings are properly normalized"""
        raw_embedding = self._fake_embedding()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = raw_embedding
            embedding = embedding_service.generate_embedding(sample_image_data, normalize=True)

        # A normalized embedding has unit L2 norm (within floating-point tolerance).
        norm = np.linalg.norm(embedding)
        assert abs(norm - 1.0) < 0.001
class TestEmbeddingServiceIntegration:
    """Integration-style tests for EmbeddingService with other components.

    pytest fixtures declared as methods of a test class are only visible to
    tests in that class. The fixtures these tests need are therefore declared
    here as well, instead of relying on the ones on ``TestEmbeddingService``.
    """

    @pytest.fixture
    def embedding_service(self):
        """Yield an EmbeddingService wired to a mocked Vision client.

        The mock client's ``annotate_image`` returns a response with empty
        annotation lists. ``yield`` keeps the ``vision`` module patch active
        for the duration of the test.
        """
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.image_properties_annotation.dominant_colors.colors = []
        mock_response.label_annotations = []
        mock_response.object_localizations = []
        mock_response.text_annotations = []
        mock_client.annotate_image.return_value = mock_response
        with patch('src.services.embedding_service.vision') as mock_vision:
            mock_vision.ImageAnnotatorClient.return_value = mock_client
            service = EmbeddingService()
            # Assigned explicitly in case the constructor obtains its client another way.
            service.client = mock_client
            yield service

    @pytest.fixture
    def sample_image_data(self):
        """Return a BytesIO holding a valid 1x1 RGB PNG with computed chunk lengths and CRCs."""
        import struct
        import zlib

        def chunk(tag, payload):
            # PNG chunk layout: big-endian length, 4-byte type, data, CRC-32 over type + data.
            crc = zlib.crc32(tag + payload) & 0xFFFFFFFF
            return struct.pack('>I', len(payload)) + tag + payload + struct.pack('>I', crc)

        # IHDR: width=1, height=1, bit depth 8, colour type 2 (RGB), default compression/filter/interlace.
        header = struct.pack('>IIBBBBB', 1, 1, 8, 2, 0, 0, 0)
        # One scanline: filter byte 0 followed by a single black RGB pixel.
        pixels = zlib.compress(b'\x00' * 4)
        png = (
            b'\x89PNG\r\n\x1a\n'
            + chunk(b'IHDR', header)
            + chunk(b'IDAT', pixels)
            + chunk(b'IEND', b'')
        )
        return BytesIO(png)

    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model"""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )

    def test_embedding_to_vector_db_integration(self, sample_image_model):
        """Test the call contract used to store an embedding in the vector database.

        NOTE(review): VectorDatabaseService is fully mocked and EmbeddingService
        is not invoked, so this only checks the arguments passed to the mock.
        It does not exercise the real storage path.
        """
        with patch('src.services.vector_db.VectorDatabaseService') as mock_vector_db:
            mock_store = mock_vector_db.return_value
            mock_store.add_image_vector.return_value = "test_point_id"

            embedding = [0.1] * 512
            image_id = str(sample_image_model.id)
            metadata = {"filename": sample_image_model.filename}
            point_id = mock_store.add_image_vector(
                image_id=image_id,
                vector=embedding,
                metadata=metadata
            )

            mock_store.add_image_vector.assert_called_once_with(
                image_id=image_id,
                vector=embedding,
                metadata=metadata
            )
            assert point_id == "test_point_id"

    def test_pubsub_trigger_integration(self, embedding_service):
        """Test integration with Pub/Sub message processing"""
        mock_message = {
            'image_id': str(ObjectId()),
            'storage_path': 'images/test.jpg',
            'team_id': str(ObjectId())
        }
        with patch.object(embedding_service, 'process_image_from_storage') as mock_process:
            mock_process.return_value = {'embedding_id': 'emb123'}
            result = embedding_service.handle_pubsub_message(mock_message)

            assert result['embedding_id'] == 'emb123'
            # The handler is expected to forward (storage_path, image_id, team_id) positionally.
            mock_process.assert_called_once_with(
                mock_message['storage_path'],
                mock_message['image_id'],
                mock_message['team_id']
            )

    def test_cloud_function_deployment(self, embedding_service):
        """Test Cloud Function deployment compatibility"""
        # Skip instead of `pass` so the placeholder is not reported as a passing test.
        pytest.skip("Cloud Function environment checks are not implemented yet")

    def test_error_recovery_and_retry(self, embedding_service, sample_image_data):
        """Test error recovery and retry mechanisms"""
        mock_embedding = np.random.default_rng(0).random(512).tolist()
        with patch.object(embedding_service, '_extract_features') as mock_extract:
            # First call raises a transient error; the second returns the embedding.
            mock_extract.side_effect = [Exception("Transient error"), mock_embedding]
            embedding = embedding_service.generate_embedding_with_retry(
                sample_image_data, max_retries=2
            )

            assert len(embedding) == 512
            assert mock_extract.call_count == 2