355 lines
16 KiB
Python
355 lines
16 KiB
Python
import pytest
|
|
import numpy as np
|
|
from unittest.mock import patch, MagicMock, AsyncMock
|
|
from bson import ObjectId
|
|
from io import BytesIO
|
|
|
|
from src.services.embedding_service import EmbeddingService
|
|
from src.models.image import ImageModel
|
|
|
|
|
|
class TestEmbeddingService:
    """Test embedding generation for images (Vision client mocked out)."""

    @pytest.fixture
    def mock_vision_client(self):
        """Return a MagicMock Vision client whose annotate_image yields an empty response."""
        mock_client = MagicMock()
        mock_response = MagicMock()
        # Empty annotation lists so tests that don't configure Vision output see nothing.
        mock_response.image_properties_annotation.dominant_colors.colors = []
        mock_response.label_annotations = []
        mock_response.object_localizations = []
        mock_response.text_annotations = []
        mock_client.annotate_image.return_value = mock_response
        return mock_client

    @pytest.fixture
    def embedding_service(self, mock_vision_client):
        """Create an EmbeddingService wired to the mocked Vision client.

        ``vision`` is patched only while the service is constructed, so any
        ``vision.ImageAnnotatorClient()`` call made by the constructor returns
        the mock; ``client`` is also assigned explicitly so the mock remains in
        place after the patch is undone.
        """
        with patch('src.services.embedding_service.vision') as mock_vision:
            mock_vision.ImageAnnotatorClient.return_value = mock_vision_client
            service = EmbeddingService()
            service.client = mock_vision_client
            return service

    @pytest.fixture
    def sample_image_data(self):
        """Return a BytesIO holding a well-formed 1x1 RGB PNG.

        The PNG is assembled chunk by chunk so every length field and CRC is
        guaranteed to be correct for any PNG decoder.
        """
        import zlib

        def png_chunk(tag, payload):
            # PNG chunk layout: 4-byte big-endian length, tag, payload,
            # then CRC-32 computed over tag + payload.
            return (
                len(payload).to_bytes(4, 'big')
                + tag
                + payload
                + zlib.crc32(tag + payload).to_bytes(4, 'big')
            )

        # IHDR: width=1, height=1, bit depth 8, colour type 2 (RGB),
        # compression 0, filter 0, no interlace.
        ihdr = (1).to_bytes(4, 'big') + (1).to_bytes(4, 'big') + bytes([8, 2, 0, 0, 0])
        # One scanline: filter-type byte 0 followed by a single black RGB pixel.
        idat = zlib.compress(b'\x00\x00\x00\x00')
        image_data = (
            b'\x89PNG\r\n\x1a\n'
            + png_chunk(b'IHDR', ihdr)
            + png_chunk(b'IDAT', idat)
            + png_chunk(b'IEND', b'')
        )
        return BytesIO(image_data)

    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )

    def test_generate_embedding_from_image(self, embedding_service, sample_image_data):
        """Test generating embeddings from image data."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding = embedding_service.generate_embedding(sample_image_data)

            assert embedding is not None
            assert len(embedding) == 512
            assert isinstance(embedding, list)
            assert all(isinstance(x, (int, float)) for x in embedding)

    def test_extract_image_features(self, embedding_service, sample_image_data):
        """Test extracting features from images using the Vision API."""
        mock_response = MagicMock()
        mock_response.label_annotations = [
            MagicMock(description="cat", score=0.95),
            MagicMock(description="animal", score=0.87),
            MagicMock(description="pet", score=0.82)
        ]
        # ``name`` is a reserved MagicMock constructor argument (it names the
        # mock itself rather than setting a ``.name`` attribute), so the
        # attribute has to be assigned after construction.
        cat_object = MagicMock(score=0.9)
        cat_object.name = "Cat"
        mock_response.object_localizations = [cat_object]
        mock_response.image_properties_annotation.dominant_colors.colors = [
            MagicMock(color=MagicMock(red=255, green=100, blue=50), score=0.8)
        ]

        embedding_service.client.annotate_image.return_value = mock_response

        features = embedding_service.extract_image_features(sample_image_data)

        assert 'labels' in features
        assert 'objects' in features
        assert 'colors' in features
        assert len(features['labels']) == 3
        assert features['labels'][0]['description'] == "cat"
        assert features['labels'][0]['score'] == 0.95

    def test_generate_embedding_with_metadata(self, embedding_service, sample_image_data, sample_image_model):
        """Test generating embeddings together with image metadata."""
        mock_embedding = np.random.rand(512).tolist()
        mock_features = {
            'labels': [{'description': 'cat', 'score': 0.95}],
            'objects': [{'name': 'Cat', 'score': 0.9}],
            'colors': [{'red': 255, 'green': 100, 'blue': 50, 'score': 0.8}]
        }

        with patch.object(embedding_service, '_extract_features') as mock_extract_features, \
                patch.object(embedding_service, 'extract_image_features') as mock_extract_metadata:
            mock_extract_features.return_value = mock_embedding
            mock_extract_metadata.return_value = mock_features

            result = embedding_service.generate_embedding_with_metadata(
                sample_image_data, sample_image_model
            )

            assert 'embedding' in result
            assert 'metadata' in result
            assert 'model' in result
            assert len(result['embedding']) == 512
            assert result['metadata']['labels'][0]['description'] == 'cat'
            # NOTE(review): 'clip' is assumed to be the service's model
            # identifier -- confirm against EmbeddingService.
            assert result['model'] == 'clip'

    def test_batch_generate_embeddings(self, embedding_service):
        """Test generating embeddings for multiple images in one batch."""
        image_batch = []
        for i in range(3):
            image_data = BytesIO(b'fake_image_data_' + str(i).encode())
            image_model = ImageModel(
                filename=f"image{i}.jpg",
                original_filename=f"image{i}.jpg",
                file_size=1024,
                content_type="image/jpeg",
                storage_path=f"images/image{i}.jpg",
                team_id=ObjectId(),
                uploader_id=ObjectId()
            )
            image_batch.append((image_data, image_model))

        mock_embedding = np.random.rand(512).tolist()
        with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate:
            mock_generate.return_value = {
                'embedding': mock_embedding,
                'metadata': {'labels': []},
                'model': 'clip'
            }

            results = embedding_service.batch_generate_embeddings(image_batch)

            assert len(results) == 3
            assert all('embedding' in result for result in results)
            assert all('metadata' in result for result in results)

    def test_embedding_model_consistency(self, embedding_service, sample_image_data):
        """Test that the same image produces consistent embeddings."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding1 = embedding_service.generate_embedding(sample_image_data)
            sample_image_data.seek(0)  # Rewind so the second call reads the same bytes.
            embedding2 = embedding_service.generate_embedding(sample_image_data)

            assert embedding1 == embedding2

    def test_embedding_dimension_validation(self, embedding_service):
        """Test that embedding dimension validation accepts 512 and rejects other sizes."""
        correct_embedding = np.random.rand(512).tolist()
        assert embedding_service.validate_embedding_dimensions(correct_embedding) is True

        wrong_embedding = np.random.rand(256).tolist()
        assert embedding_service.validate_embedding_dimensions(wrong_embedding) is False

    def test_handle_unsupported_image_format(self, embedding_service):
        """Test that non-image bytes are rejected with ValueError."""
        invalid_data = BytesIO(b'not_an_image')

        with pytest.raises(ValueError):
            embedding_service.generate_embedding(invalid_data)

    def test_handle_corrupted_image(self, embedding_service):
        """Test that a truncated/corrupted PNG is rejected."""
        # Valid PNG signature followed by garbage instead of proper chunks.
        corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')

        # Corrupted input must raise rather than be silently embedded.
        with pytest.raises(Exception):
            embedding_service.generate_embedding(corrupted_data)

    def test_vision_api_error_handling(self, embedding_service, sample_image_data):
        """Test that Vision API errors propagate to the caller."""
        embedding_service.client.annotate_image.side_effect = Exception("Vision API error")

        with pytest.raises(Exception):
            embedding_service.extract_image_features(sample_image_data)

    def test_embedding_caching(self, embedding_service, sample_image_data):
        """Test caching of embeddings for the same image."""
        # An empty body would report a pass for behaviour that is never checked;
        # skip explicitly until this test is written.
        pytest.skip("Embedding caching test not implemented yet")

    def test_embedding_quality_metrics(self, embedding_service, sample_image_data):
        """Test that the embedding quality score lies in [0, 1]."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding = embedding_service.generate_embedding(sample_image_data)

            quality_score = embedding_service.calculate_embedding_quality(embedding)
            assert 0 <= quality_score <= 1

    def test_different_image_types(self, embedding_service):
        """Test embedding generation for different image content types."""
        image_types = [
            ('image/jpeg', b'fake_jpeg_data'),
            ('image/png', b'fake_png_data'),
            ('image/webp', b'fake_webp_data')
        ]
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            for content_type, data in image_types:
                embedding = embedding_service.generate_embedding(BytesIO(data))
                # Include the content type so a failure names the offending format.
                assert len(embedding) == 512, content_type

    def test_large_image_handling(self, embedding_service):
        """Test handling of large (10 MB) image payloads."""
        large_image_data = BytesIO(b'x' * (10 * 1024 * 1024))
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding = embedding_service.generate_embedding(large_image_data)
            assert len(embedding) == 512

    def test_embedding_normalization(self, embedding_service, sample_image_data):
        """Test that normalize=True yields an embedding with unit L2 norm."""
        raw_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = raw_embedding

            embedding = embedding_service.generate_embedding(sample_image_data, normalize=True)

            norm = np.linalg.norm(embedding)
            assert abs(norm - 1.0) < 0.001  # Tolerate floating-point rounding.
|
|
|
|
|
|
class TestEmbeddingServiceIntegration:
    """Integration tests for embedding service with other components.

    pytest fixtures declared inside another test class are only visible to
    that class, so the fixtures these tests request are declared here.
    """

    @pytest.fixture
    def embedding_service(self):
        """Create an EmbeddingService whose Vision client is a MagicMock returning an empty response."""
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.image_properties_annotation.dominant_colors.colors = []
        mock_response.label_annotations = []
        mock_response.object_localizations = []
        mock_response.text_annotations = []
        mock_client.annotate_image.return_value = mock_response

        # ``vision`` is patched only during construction; the client is then
        # assigned explicitly so the mock stays in place afterwards.
        with patch('src.services.embedding_service.vision') as mock_vision:
            mock_vision.ImageAnnotatorClient.return_value = mock_client
            service = EmbeddingService()
            service.client = mock_client
            return service

    @pytest.fixture
    def sample_image_data(self):
        """Return a BytesIO holding a well-formed 1x1 RGB PNG (lengths and CRCs computed)."""
        import zlib

        def png_chunk(tag, payload):
            # PNG chunk layout: 4-byte big-endian length, tag, payload,
            # then CRC-32 computed over tag + payload.
            return (
                len(payload).to_bytes(4, 'big')
                + tag
                + payload
                + zlib.crc32(tag + payload).to_bytes(4, 'big')
            )

        # IHDR: width=1, height=1, bit depth 8, colour type 2 (RGB), defaults otherwise.
        ihdr = (1).to_bytes(4, 'big') + (1).to_bytes(4, 'big') + bytes([8, 2, 0, 0, 0])
        # One scanline: filter-type byte 0 followed by a single black RGB pixel.
        idat = zlib.compress(b'\x00\x00\x00\x00')
        return BytesIO(
            b'\x89PNG\r\n\x1a\n'
            + png_chunk(b'IHDR', ihdr)
            + png_chunk(b'IDAT', idat)
            + png_chunk(b'IEND', b'')
        )

    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )

    def test_embedding_to_vector_db_integration(self, embedding_service, sample_image_data, sample_image_model):
        """Test handing an embedding produced by the service to the vector database service."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract, \
                patch('src.services.vector_db.VectorDatabaseService') as mock_vector_db:
            mock_extract.return_value = mock_embedding
            mock_store = mock_vector_db.return_value
            mock_store.add_image_vector.return_value = "test_point_id"

            # Produce the vector through the service instead of a hard-coded list.
            embedding = embedding_service.generate_embedding(sample_image_data)
            assert len(embedding) == 512

            image_id = str(sample_image_model.id)
            metadata = {"filename": sample_image_model.filename}
            # NOTE(review): the store is a mock called directly here; this does
            # not exercise a real embedding -> vector DB code path.
            point_id = mock_store.add_image_vector(
                image_id=image_id,
                vector=embedding,
                metadata=metadata
            )

            mock_store.add_image_vector.assert_called_once_with(
                image_id=image_id,
                vector=embedding,
                metadata=metadata
            )
            assert point_id == "test_point_id"

    def test_pubsub_trigger_integration(self, embedding_service):
        """Test that a Pub/Sub message is routed to process_image_from_storage."""
        mock_message = {
            'image_id': str(ObjectId()),
            'storage_path': 'images/test.jpg',
            'team_id': str(ObjectId())
        }

        with patch.object(embedding_service, 'process_image_from_storage') as mock_process:
            mock_process.return_value = {'embedding_id': 'emb123'}

            result = embedding_service.handle_pubsub_message(mock_message)

            assert result['embedding_id'] == 'emb123'
            mock_process.assert_called_once_with(
                mock_message['storage_path'],
                mock_message['image_id'],
                mock_message['team_id']
            )

    def test_cloud_function_deployment(self, embedding_service):
        """Test Cloud Function deployment compatibility (env loading, auth, etc.)."""
        # An empty body would silently report a pass; skip until written.
        pytest.skip("Cloud Function deployment test not implemented yet")

    def test_error_recovery_and_retry(self, embedding_service, sample_image_data):
        """Test that a transient failure is retried and then succeeds."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            # First call fails, second succeeds.
            mock_extract.side_effect = [Exception("Transient error"), mock_embedding]

            embedding = embedding_service.generate_embedding_with_retry(
                sample_image_data, max_retries=2
            )

            assert len(embedding) == 512
            assert mock_extract.call_count == 2