import pytest
import numpy as np
from unittest.mock import patch, MagicMock, AsyncMock
from bson import ObjectId
from io import BytesIO

from src.services.embedding_service import EmbeddingService
from src.models.image import ImageModel


# Embedding length that the tests below assert on.
EMBEDDING_DIM = 512

# A simple test image (1x1 pixel PNG), split by chunk for readability.
_ONE_PIXEL_PNG = (
    b'\x89PNG\r\n\x1a\n'
    b'\x00\x00\x00\rIHDR'
    b'\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde'
    b'\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18'
    b'\x00\x00\x00\nIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xdd\x8d\xb4\x1c'
    b'\x00\x00\x00\x00IEND\xaeB`\x82'
)


def _make_embedding(dim=EMBEDDING_DIM, seed=0):
    """Return a reproducible list of ``dim`` floats in [0, 1).

    A seeded generator is used so that a failing test can be re-run with
    exactly the same data.
    """
    return np.random.RandomState(seed).rand(dim).tolist()


# The fixtures live at module level (not inside a test class) so that both
# TestEmbeddingService and TestEmbeddingServiceIntegration can request them;
# a fixture defined as a method is only visible to tests of that same class.

@pytest.fixture
def mock_vision_client():
    """Mock Google Cloud Vision client whose annotate_image() returns an empty response."""
    mock_client = MagicMock()
    mock_response = MagicMock()
    mock_response.image_properties_annotation.dominant_colors.colors = []
    mock_response.label_annotations = []
    mock_response.object_localizations = []
    mock_response.text_annotations = []
    mock_client.annotate_image.return_value = mock_response
    return mock_client


@pytest.fixture
def embedding_service(mock_vision_client):
    """Create an embedding service with mocked dependencies.

    The ``vision`` module is patched only while the service is constructed;
    the client is then replaced explicitly so later calls hit the mock.
    """
    with patch('src.services.embedding_service.vision') as mock_vision:
        mock_vision.ImageAnnotatorClient.return_value = mock_vision_client
        service = EmbeddingService()
        service.client = mock_vision_client
        return service


@pytest.fixture
def sample_image_data():
    """Return a fresh in-memory stream containing the 1x1 pixel PNG."""
    return BytesIO(_ONE_PIXEL_PNG)


@pytest.fixture
def sample_image_model():
    """Create a sample image model."""
    return ImageModel(
        filename="test-image.jpg",
        original_filename="test_image.jpg",
        file_size=1024,
        content_type="image/jpeg",
        storage_path="images/test-image.jpg",
        team_id=ObjectId(),
        uploader_id=ObjectId(),
    )


class TestEmbeddingService:
    """Test embedding generation for images"""

    def test_generate_embedding_from_image(self, embedding_service, sample_image_data):
        """Test generating embeddings from image data"""
        # Mock the embedding generation
        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # Generate embedding
            embedding = embedding_service.generate_embedding(sample_image_data)

            # Verify embedding was generated
            assert embedding is not None
            assert len(embedding) == EMBEDDING_DIM
            assert isinstance(embedding, list)
            assert all(isinstance(x, (int, float)) for x in embedding)

    def test_extract_image_features(self, embedding_service, sample_image_data):
        """Test extracting features from images using Vision API"""
        # ``name`` is a reserved keyword of the Mock constructor (it names the
        # mock itself), so the attribute has to be assigned after creation.
        cat_object = MagicMock(score=0.9)
        cat_object.name = "Cat"

        # Mock Vision API response
        mock_response = MagicMock()
        mock_response.label_annotations = [
            MagicMock(description="cat", score=0.95),
            MagicMock(description="animal", score=0.87),
            MagicMock(description="pet", score=0.82),
        ]
        mock_response.object_localizations = [cat_object]
        mock_response.image_properties_annotation.dominant_colors.colors = [
            MagicMock(color=MagicMock(red=255, green=100, blue=50), score=0.8)
        ]

        embedding_service.client.annotate_image.return_value = mock_response

        # Extract features
        features = embedding_service.extract_image_features(sample_image_data)

        # Verify features were extracted
        assert 'labels' in features
        assert 'objects' in features
        assert 'colors' in features
        assert len(features['labels']) == 3
        assert features['labels'][0]['description'] == "cat"
        assert features['labels'][0]['score'] == 0.95

    def test_generate_embedding_with_metadata(self, embedding_service, sample_image_data, sample_image_model):
        """Test generating embeddings with image metadata"""
        mock_embedding = _make_embedding()
        mock_features = {
            'labels': [{'description': 'cat', 'score': 0.95}],
            'objects': [{'name': 'Cat', 'score': 0.9}],
            'colors': [{'red': 255, 'green': 100, 'blue': 50, 'score': 0.8}],
        }

        with patch.object(embedding_service, '_extract_features') as mock_extract_features, \
                patch.object(embedding_service, 'extract_image_features') as mock_extract_metadata:
            mock_extract_features.return_value = mock_embedding
            mock_extract_metadata.return_value = mock_features

            # Generate embedding with metadata
            result = embedding_service.generate_embedding_with_metadata(
                sample_image_data, sample_image_model
            )

            # Verify result structure
            assert 'embedding' in result
            assert 'metadata' in result
            assert 'model' in result
            assert len(result['embedding']) == EMBEDDING_DIM
            assert result['metadata']['labels'][0]['description'] == 'cat'
            assert result['model'] == 'clip'  # or whatever model is used

    def test_batch_generate_embeddings(self, embedding_service):
        """Test generating embeddings for multiple images in batch"""
        # Create multiple image data samples
        image_batch = []
        for i in range(3):
            image_data = BytesIO(b'fake_image_data_' + str(i).encode())
            image_model = ImageModel(
                filename=f"image{i}.jpg",
                original_filename=f"image{i}.jpg",
                file_size=1024,
                content_type="image/jpeg",
                storage_path=f"images/image{i}.jpg",
                team_id=ObjectId(),
                uploader_id=ObjectId(),
            )
            image_batch.append((image_data, image_model))

        # Mock embedding generation
        mock_embedding = _make_embedding()

        with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate:
            mock_generate.return_value = {
                'embedding': mock_embedding,
                'metadata': {'labels': []},
                'model': 'clip',
            }

            # Generate batch embeddings
            results = embedding_service.batch_generate_embeddings(image_batch)

            # Verify batch results
            assert len(results) == 3
            assert all('embedding' in result for result in results)
            assert all('metadata' in result for result in results)

    def test_embedding_model_consistency(self, embedding_service, sample_image_data):
        """Test that the same image produces consistent embeddings"""
        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # Generate embedding twice
            embedding1 = embedding_service.generate_embedding(sample_image_data)
            sample_image_data.seek(0)  # Reset stream position
            embedding2 = embedding_service.generate_embedding(sample_image_data)

            # Embeddings should be identical for the same image
            assert embedding1 == embedding2

    def test_embedding_dimension_validation(self, embedding_service):
        """Test that embeddings have the correct dimensions"""
        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # Validate embedding dimensions
            assert embedding_service.validate_embedding_dimensions(mock_embedding) is True

            # Test wrong dimensions
            wrong_embedding = _make_embedding(dim=256)
            assert embedding_service.validate_embedding_dimensions(wrong_embedding) is False

    def test_handle_unsupported_image_format(self, embedding_service):
        """Test handling of unsupported image formats"""
        # Create invalid image data
        invalid_data = BytesIO(b'not_an_image')

        # Should raise appropriate exception
        with pytest.raises(ValueError):
            embedding_service.generate_embedding(invalid_data)

    def test_handle_corrupted_image(self, embedding_service):
        """Test handling of corrupted image data"""
        # Create corrupted image data (valid PNG signature, garbage after it)
        corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')

        # Should handle gracefully
        with pytest.raises(Exception):
            embedding_service.generate_embedding(corrupted_data)

    def test_vision_api_error_handling(self, embedding_service, sample_image_data):
        """Test handling of Vision API errors"""
        # Mock Vision API error
        embedding_service.client.annotate_image.side_effect = Exception("Vision API error")

        # Should handle the error gracefully
        with pytest.raises(Exception):
            embedding_service.extract_image_features(sample_image_data)

    @pytest.mark.skip(reason="caching test not implemented yet")
    def test_embedding_caching(self, embedding_service, sample_image_data):
        """Test caching of embeddings for the same image"""
        # This would test caching functionality if implemented. It is skipped
        # explicitly so the placeholder does not show up as a passing test.

    def test_embedding_quality_metrics(self, embedding_service, sample_image_data):
        """Test quality metrics for generated embeddings"""
        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # Generate embedding
            embedding = embedding_service.generate_embedding(sample_image_data)

            # Check embedding quality metrics
            quality_score = embedding_service.calculate_embedding_quality(embedding)
            assert 0 <= quality_score <= 1

    def test_different_image_types(self, embedding_service):
        """Test embedding generation for different image types"""
        image_types = [
            ('image/jpeg', b'fake_jpeg_data'),
            ('image/png', b'fake_png_data'),
            ('image/webp', b'fake_webp_data'),
        ]

        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            for content_type, data in image_types:
                image_data = BytesIO(data)

                # Should handle different image types
                embedding = embedding_service.generate_embedding(image_data)
                assert len(embedding) == EMBEDDING_DIM

    def test_large_image_handling(self, embedding_service):
        """Test handling of large images"""
        # Create large image data (simulated)
        large_image_data = BytesIO(b'x' * (10 * 1024 * 1024))  # 10MB

        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # Should handle large images
            embedding = embedding_service.generate_embedding(large_image_data)
            assert len(embedding) == EMBEDDING_DIM

    def test_embedding_normalization(self, embedding_service, sample_image_data):
        """Test that embeddings are properly normalized"""
        # Generate raw embedding
        raw_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = raw_embedding

            # Generate normalized embedding
            embedding = embedding_service.generate_embedding(sample_image_data, normalize=True)

            # Check if embedding is normalized (L2 norm should be 1)
            norm = np.linalg.norm(embedding)
            assert abs(norm - 1.0) < 0.001  # Allow small floating point errors


class TestEmbeddingServiceIntegration:
    """Integration tests for embedding service with other components"""

    def test_embedding_to_vector_db_integration(self, embedding_service, sample_image_data, sample_image_model):
        """Test integration between embedding service and vector database"""
        # NOTE(review): this test only calls the mocked store directly; it does
        # not route anything through embedding_service yet.
        # Mock the vector database service
        with patch('src.services.vector_db.VectorDatabaseService') as mock_vector_db:
            # Setup mock
            mock_store = mock_vector_db.return_value
            mock_store.add_image_vector.return_value = "test_point_id"

            # Test storing embedding
            embedding = [0.1] * EMBEDDING_DIM  # Mock embedding
            point_id = mock_store.add_image_vector(
                image_id=str(sample_image_model.id),
                vector=embedding,
                metadata={"filename": sample_image_model.filename},
            )

            # Verify the call
            mock_store.add_image_vector.assert_called_once()
            assert point_id == "test_point_id"

    def test_pubsub_trigger_integration(self, embedding_service):
        """Test integration with Pub/Sub message processing"""
        # Mock Pub/Sub message
        mock_message = {
            'image_id': str(ObjectId()),
            'storage_path': 'images/test.jpg',
            'team_id': str(ObjectId()),
        }

        with patch.object(embedding_service, 'process_image_from_storage') as mock_process:
            mock_process.return_value = {'embedding_id': 'emb123'}

            # Process Pub/Sub message
            result = embedding_service.handle_pubsub_message(mock_message)

            # Verify message processing
            assert result['embedding_id'] == 'emb123'
            mock_process.assert_called_once_with(
                mock_message['storage_path'],
                mock_message['image_id'],
                mock_message['team_id'],
            )

    @pytest.mark.skip(reason="Cloud Function deployment test not implemented yet")
    def test_cloud_function_deployment(self, embedding_service):
        """Test Cloud Function deployment compatibility"""
        # Test that the service can be initialized in a Cloud Function environment.
        # This would test environment variable loading, authentication, etc.
        # Skipped explicitly so the placeholder is not reported as passing.

    def test_error_recovery_and_retry(self, embedding_service, sample_image_data):
        """Test error recovery and retry mechanisms"""
        # Mock transient error followed by success
        mock_embedding = _make_embedding()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            # First call fails, second succeeds
            mock_extract.side_effect = [Exception("Transient error"), mock_embedding]

            # Should retry and succeed
            embedding = embedding_service.generate_embedding_with_retry(
                sample_image_data, max_retries=2
            )

            assert len(embedding) == EMBEDDING_DIM
            assert mock_extract.call_count == 2