"""Unit tests for the Firestore repository classes.

No real database is contacted. The base ``FirestoreRepository`` tests patch
``get_firestore_db`` in the repository module to return a mock client. The
collection-specific subclass tests patch repository methods (``get_all`` or
the base-class CRUD methods) directly.
"""
import pytest
from unittest.mock import Mock, patch, AsyncMock
from pydantic import BaseModel

from src.db.repositories.firestore_repositories import (
    FirestoreRepository,
    FirestoreTeamRepository,
    FirestoreUserRepository,
    FirestoreApiKeyRepository,
    FirestoreImageRepository,
)
from src.models.team import TeamModel
from src.models.user import UserModel
from src.models.api_key import ApiKeyModel
from src.models.image import ImageModel

# Dotted path patched so the repository module receives a mock Firestore client.
_FIRESTORE_DB_TARGET = "src.db.repositories.firestore_repositories.get_firestore_db"


def _mock_with_name(name, **attrs):
    """Return a ``Mock`` whose ``.name`` attribute equals *name*.

    ``Mock(name=...)`` cannot be used for this. ``name`` is a reserved
    constructor argument that only labels the mock's repr, so ``mock.name``
    would return an auto-created child ``Mock`` instead of the string.
    The attribute must be assigned after construction.

    Args:
        name: Value to expose as the mock's ``.name`` attribute.
        **attrs: Any other attributes to configure on the mock.

    Returns:
        The configured ``Mock``.
    """
    mock = Mock(**attrs)
    mock.name = name
    return mock


class TestFirestoreRepository:
    """Test cases for the base FirestoreRepository class"""

    @pytest.fixture
    def mock_firestore_db(self):
        """Mock Firestore client whose collection()/document() chain returns mocks.

        ``stream()`` defaults to an empty iterable. Individual tests override
        the collection or document behaviour they need.
        """
        mock_db = Mock()
        mock_collection = Mock()
        mock_doc = Mock()
        mock_db.collection.return_value = mock_collection
        mock_collection.document.return_value = mock_doc
        mock_collection.stream.return_value = []
        return mock_db

    @pytest.fixture
    def test_model_class(self):
        """Create a minimal pydantic model class for the repository to hydrate."""

        class TestModel(BaseModel):
            name: str
            value: int

        return TestModel

    @pytest.fixture
    def repository(self, test_model_class):
        """Create a FirestoreRepository bound to "test_collection"."""
        return FirestoreRepository("test_collection", test_model_class)

    def test_init(self, repository, test_model_class):
        """Test repository initialization stores the collection name and model class."""
        assert repository.collection_name == "test_collection"
        assert repository.model_class == test_model_class

    @pytest.mark.asyncio
    async def test_create(self, repository, test_model_class, mock_firestore_db):
        """Test creating a document"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            # collection.add() is mocked to return a (write_time, doc_ref)
            # pair. Only the document reference's id is populated.
            mock_doc_ref = Mock()
            mock_doc_ref.id = "test_id"
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.add.return_value = (None, mock_doc_ref)

            test_instance = test_model_class(name="test", value=42)

            result = await repository.create(test_instance)

            assert result.name == "test"
            assert result.value == 42

            # The repository should target its own collection and add exactly once.
            mock_firestore_db.collection.assert_called_once_with("test_collection")
            mock_collection.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_id_found(self, repository, test_model_class, mock_firestore_db):
        """Test getting a document by ID when it exists"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            # Snapshot that reports it exists and carries model data.
            mock_doc_snapshot = Mock()
            mock_doc_snapshot.exists = True
            mock_doc_snapshot.to_dict.return_value = {"name": "test", "value": 42}
            mock_doc_snapshot.id = "test_id"

            mock_doc_ref = Mock()
            mock_doc_ref.get.return_value = mock_doc_snapshot
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.document.return_value = mock_doc_ref

            result = await repository.get_by_id("test_id")

            assert result.name == "test"
            assert result.value == 42

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(self, repository, mock_firestore_db):
        """Test getting a document by ID when it doesn't exist"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            # Snapshot whose ``exists`` flag is False -> repository should return None.
            mock_doc_snapshot = Mock()
            mock_doc_snapshot.exists = False

            mock_doc_ref = Mock()
            mock_doc_ref.get.return_value = mock_doc_snapshot
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.document.return_value = mock_doc_ref

            result = await repository.get_by_id("nonexistent_id")

            assert result is None

    @pytest.mark.asyncio
    async def test_get_all(self, repository, test_model_class, mock_firestore_db):
        """Test getting all documents"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            # Two streamed snapshots, returned in this order.
            mock_docs = [
                Mock(to_dict=lambda: {"name": "test1", "value": 1}, id="id1"),
                Mock(to_dict=lambda: {"name": "test2", "value": 2}, id="id2"),
            ]
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.stream.return_value = mock_docs

            result = await repository.get_all()

            assert len(result) == 2
            assert result[0].name == "test1"
            assert result[1].name == "test2"

    @pytest.mark.asyncio
    async def test_update_success(self, repository, test_model_class, mock_firestore_db):
        """Test updating a document successfully"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            mock_doc_ref = Mock()
            mock_doc_ref.update.return_value = None  # Firestore update returns None on success
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.document.return_value = mock_doc_ref

            # update() is expected to re-read the document via get_by_id;
            # stub that read with the post-update instance.
            updated_instance = test_model_class(name="updated", value=99)
            with patch.object(repository, 'get_by_id', return_value=updated_instance):
                result = await repository.update("test_id", {"name": "updated", "value": 99})

                assert result.name == "updated"
                assert result.value == 99
                mock_doc_ref.update.assert_called_once_with({"name": "updated", "value": 99})

    @pytest.mark.asyncio
    async def test_update_failure(self, repository, mock_firestore_db):
        """Test updating a document that doesn't exist"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            # Simulate the client raising because the document is missing.
            mock_doc_ref = Mock()
            mock_doc_ref.update.side_effect = Exception("Document not found")
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.document.return_value = mock_doc_ref

            # update() is expected to propagate the failure to the caller.
            with pytest.raises(Exception):
                await repository.update("nonexistent_id", {"name": "updated"})

    @pytest.mark.asyncio
    async def test_delete_success(self, repository, mock_firestore_db):
        """Test deleting a document successfully"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            mock_doc_ref = Mock()
            mock_doc_ref.delete.return_value = None  # Firestore delete returns None on success
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.document.return_value = mock_doc_ref

            result = await repository.delete("test_id")

            assert result is True
            mock_doc_ref.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_failure(self, repository, mock_firestore_db):
        """Test deleting a document that doesn't exist"""
        with patch(_FIRESTORE_DB_TARGET, return_value=mock_firestore_db):
            mock_doc_ref = Mock()
            mock_doc_ref.delete.side_effect = Exception("Document not found")
            mock_collection = mock_firestore_db.collection.return_value
            mock_collection.document.return_value = mock_doc_ref

            # Unlike update(), delete() is expected to report failure as False.
            result = await repository.delete("nonexistent_id")

            assert result is False


class TestFirestoreTeamRepository:
    """Test cases for FirestoreTeamRepository"""

    @pytest.fixture
    def repository(self):
        """Create a FirestoreTeamRepository instance for testing"""
        return FirestoreTeamRepository()

    def test_init(self, repository):
        """Test repository initialization"""
        assert repository.collection_name == "teams"
        assert repository.model_class == TeamModel

    @pytest.mark.asyncio
    async def test_get_by_id(self, repository):
        """Test getting team by ID delegates to the base-class implementation"""
        # Patch the direct base class so any override in the subclass still
        # reaches the (mocked) base implementation.
        base_cls = repository.__class__.__bases__[0]
        with patch.object(base_cls, 'get_by_id') as mock_get:
            mock_get.return_value = _mock_with_name("Test Team", id="team_id")

            result = await repository.get_by_id("team_id")

            assert result.id == "team_id"
            assert result.name == "Test Team"
            mock_get.assert_called_once_with("team_id")

    @pytest.mark.asyncio
    async def test_update(self, repository):
        """Test updating team delegates to the base-class implementation"""
        base_cls = repository.__class__.__bases__[0]
        with patch.object(base_cls, 'update') as mock_update:
            mock_update.return_value = _mock_with_name("Updated Team")

            result = await repository.update("team_id", {"name": "Updated Team"})

            assert result.name == "Updated Team"
            mock_update.assert_called_once_with("team_id", {"name": "Updated Team"})

    @pytest.mark.asyncio
    async def test_delete(self, repository):
        """Test deleting team delegates to the base-class implementation"""
        base_cls = repository.__class__.__bases__[0]
        with patch.object(base_cls, 'delete') as mock_delete:
            mock_delete.return_value = True

            result = await repository.delete("team_id")

            assert result is True
            mock_delete.assert_called_once_with("team_id")


class TestFirestoreUserRepository:
    """Test cases for FirestoreUserRepository"""

    @pytest.fixture
    def repository(self):
        """Create a FirestoreUserRepository instance for testing"""
        return FirestoreUserRepository()

    def test_init(self, repository):
        """Test repository initialization"""
        assert repository.collection_name == "users"
        assert repository.model_class == UserModel

    @pytest.mark.asyncio
    async def test_get_by_email(self, repository):
        """Test getting user by email filters the result of get_all()"""
        mock_users = [
            Mock(email="test1@example.com"),
            Mock(email="test2@example.com"),
            Mock(email="target@example.com"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_users

            result = await repository.get_by_email("target@example.com")

            assert result == mock_users[2]
            mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_email_not_found(self, repository):
        """Test getting user by email when not found returns None"""
        mock_users = [
            Mock(email="test1@example.com"),
            Mock(email="test2@example.com"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_users

            result = await repository.get_by_email("notfound@example.com")

            assert result is None
            mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_team_id(self, repository):
        """Test getting users by team ID keeps matches in original order"""
        mock_users = [
            Mock(team_id="team1"),
            Mock(team_id="team2"),
            Mock(team_id="team1"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_users

            result = await repository.get_by_team_id("team1")

            assert len(result) == 2
            assert result[0] == mock_users[0]
            assert result[1] == mock_users[2]
            mock_get_all.assert_called_once()


class TestFirestoreApiKeyRepository:
    """Test cases for FirestoreApiKeyRepository"""

    @pytest.fixture
    def repository(self):
        """Create a FirestoreApiKeyRepository instance for testing"""
        return FirestoreApiKeyRepository()

    def test_init(self, repository):
        """Test repository initialization"""
        assert repository.collection_name == "api_keys"
        assert repository.model_class == ApiKeyModel

    @pytest.mark.asyncio
    async def test_get_by_key_hash(self, repository):
        """Test getting API key by hash filters the result of get_all()"""
        mock_api_keys = [
            Mock(key_hash="hash1"),
            Mock(key_hash="hash2"),
            Mock(key_hash="target_hash"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_api_keys

            result = await repository.get_by_key_hash("target_hash")

            assert result == mock_api_keys[2]
            mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_user_id(self, repository):
        """Test getting API keys by user ID keeps matches in original order"""
        mock_api_keys = [
            Mock(user_id="user1"),
            Mock(user_id="user2"),
            Mock(user_id="user1"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_api_keys

            result = await repository.get_by_user_id("user1")

            assert len(result) == 2
            assert result[0] == mock_api_keys[0]
            assert result[1] == mock_api_keys[2]
            mock_get_all.assert_called_once()


class TestFirestoreImageRepository:
    """Test cases for FirestoreImageRepository"""

    @pytest.fixture
    def repository(self):
        """Create a FirestoreImageRepository instance for testing"""
        return FirestoreImageRepository()

    def test_init(self, repository):
        """Test repository initialization"""
        assert repository.collection_name == "images"
        assert repository.model_class == ImageModel

    @pytest.mark.asyncio
    async def test_get_by_team_id(self, repository):
        """Test getting images by team ID keeps matches in original order"""
        mock_images = [
            Mock(team_id="team1"),
            Mock(team_id="team2"),
            Mock(team_id="team1"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_images

            result = await repository.get_by_team_id("team1")

            assert len(result) == 2
            assert result[0] == mock_images[0]
            assert result[1] == mock_images[2]
            mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_uploader_id(self, repository):
        """Test getting images by uploader ID keeps matches in original order"""
        mock_images = [
            Mock(uploader_id="user1"),
            Mock(uploader_id="user2"),
            Mock(uploader_id="user1"),
        ]

        with patch.object(repository, 'get_all') as mock_get_all:
            mock_get_all.return_value = mock_images

            result = await repository.get_by_uploader_id("user1")

            assert len(result) == 2
            assert result[0] == mock_images[0]
            assert result[1] == mock_images[2]
            mock_get_all.assert_called_once()