diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..93ea8c7 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,14 @@ +[tool:pytest] +asyncio_mode = auto +asyncio_default_fixture_loop_scope = function +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* +addopts = + -v + --tb=short + --strict-markers + --disable-warnings +markers = + asyncio: marks tests as async (deselect with '-m "not asyncio"') \ No newline at end of file diff --git a/src/db/repositories/firestore_repository.py b/src/db/repositories/firestore_repository.py index 88f2ab2..6f0e8e8 100644 --- a/src/db/repositories/firestore_repository.py +++ b/src/db/repositories/firestore_repository.py @@ -27,7 +27,7 @@ class FirestoreRepository(Generic[T]): """ try: # Convert Pydantic model to dict - model_dict = model.dict(by_alias=True) + model_dict = model.model_dump(by_alias=True) # Add document to Firestore doc_id = await firestore_db.add_document(self.collection_name, model_dict) diff --git a/tests/db/test_database.py b/tests/db/test_database.py new file mode 100644 index 0000000..43a06d5 --- /dev/null +++ b/tests/db/test_database.py @@ -0,0 +1,108 @@ +import pytest +from unittest.mock import Mock, patch +from src.db import Database, db + + +class TestDatabase: + """Test cases for Database class""" + + @pytest.fixture + def database(self): + """Create a Database instance for testing""" + return Database() + + def test_init(self, database): + """Test database initialization""" + assert database.provider is not None + assert hasattr(database.provider, 'client') + + @patch('src.db.FirestoreProvider') + def test_connect_to_database_success(self, mock_provider_class, database): + """Test successful database connection""" + mock_provider = Mock() + mock_provider_class.return_value = mock_provider + database.provider = mock_provider + + database.connect_to_database() + + mock_provider.connect.assert_called_once() + + @patch('src.db.FirestoreProvider') + def test_connect_to_database_failure(self, mock_provider_class, database): + """Test database connection failure""" + mock_provider = Mock() + mock_provider.connect.side_effect = Exception("Connection failed") + mock_provider_class.return_value = mock_provider + database.provider = mock_provider + + with pytest.raises(Exception, match="Connection failed"): + database.connect_to_database() + + mock_provider.connect.assert_called_once() + + def test_close_database_connection_success(self, database): + """Test successful database disconnection""" + mock_provider = Mock() + database.provider = mock_provider + + database.close_database_connection() + + mock_provider.disconnect.assert_called_once() + + def test_close_database_connection_failure(self, database): + """Test database disconnection with error""" + mock_provider = Mock() + mock_provider.disconnect.side_effect = Exception("Disconnect failed") + database.provider = mock_provider + + # Should not raise exception, just log error + database.close_database_connection() + + mock_provider.disconnect.assert_called_once() + + def test_get_database_with_client(self, database): + """Test getting database when client exists""" + mock_provider = Mock() + mock_client = Mock() + mock_provider.client = mock_client + database.provider = mock_provider + + result = database.get_database() + + assert result == mock_client + + def test_get_database_without_client_reconnect_success(self, database): + """Test getting database when client is None and reconnection succeeds""" + mock_provider = Mock() + mock_provider.client = None + database.provider = mock_provider + + with patch.object(database, 'connect_to_database') as mock_connect: + # Simulate successful reconnection by setting client after connect is called + def set_client(): + mock_provider.client = Mock() + mock_connect.side_effect = set_client + + result = database.get_database() + + mock_connect.assert_called_once() + assert result == mock_provider.client + + def test_get_database_without_client_reconnect_failure(self, database): + """Test getting database when client is None and reconnection fails""" + mock_provider = Mock() + mock_provider.client = None + database.provider = mock_provider + + with patch.object(database, 'connect_to_database') as mock_connect: + mock_connect.side_effect = Exception("Reconnection failed") + + result = database.get_database() + + mock_connect.assert_called_once() + assert result is None + + def test_singleton_db_instance(self): + """Test that db is properly imported as singleton""" + assert isinstance(db, Database) + assert db.provider is not None \ No newline at end of file diff --git a/tests/db/test_firestore_provider.py b/tests/db/test_firestore_provider.py new file mode 100644 index 0000000..462e514 --- /dev/null +++ b/tests/db/test_firestore_provider.py @@ -0,0 +1,273 @@ +import pytest +import asyncio +from unittest.mock import Mock, patch, AsyncMock +from src.db.providers.firestore_provider import FirestoreProvider +from src.config.config import settings + + +class TestFirestoreProvider: + """Test cases for FirestoreProvider""" + + @pytest.fixture + def provider(self): + """Create a FirestoreProvider instance for testing""" + return FirestoreProvider() + + @pytest.fixture + def mock_firestore_client(self): + """Mock Firestore client for testing""" + mock_client = Mock() + mock_collection = Mock() + mock_doc_ref = Mock() + mock_doc = Mock() + + # Setup mock chain + mock_client.collection.return_value = mock_collection + mock_collection.document.return_value = mock_doc_ref + mock_collection.add.return_value = (None, mock_doc_ref) + mock_doc_ref.get.return_value = mock_doc + mock_doc_ref.id = "test_doc_id" + + return mock_client + + def test_init(self, provider): + """Test provider initialization""" + assert provider.client is None + assert provider._db is None + assert "teams" in provider._collections + assert "users" in provider._collections + assert "api_keys" in provider._collections + assert "images" in provider._collections + + @patch('src.db.providers.firestore_provider.firestore.Client') + @patch('google.oauth2.service_account') + @patch('os.path.exists') + def test_connect_with_credentials_file(self, mock_exists, mock_service_account, mock_client, provider): + """Test connecting with credentials file""" + # Setup mocks + mock_exists.return_value = True + mock_credentials = Mock() + mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials + mock_client_instance = Mock() + mock_client.return_value = mock_client_instance + + # Test connection + result = provider.connect() + + # Assertions + assert result is True + assert provider.client == mock_client_instance + assert provider._db == mock_client_instance + mock_service_account.Credentials.from_service_account_file.assert_called_once() + mock_client.assert_called_once_with( + project=settings.FIRESTORE_PROJECT_ID, + credentials=mock_credentials, + database=settings.FIRESTORE_DATABASE_NAME + ) + + @patch('src.db.providers.firestore_provider.firestore.Client') + @patch('os.path.exists') + def test_connect_with_default_credentials(self, mock_exists, mock_client, provider): + """Test connecting with application default credentials""" + # Setup mocks + mock_exists.return_value = False + mock_client_instance = Mock() + mock_client.return_value = mock_client_instance + + # Test connection + result = provider.connect() + + # Assertions + assert result is True + assert provider.client == mock_client_instance + mock_client.assert_called_once_with( + project=settings.FIRESTORE_PROJECT_ID, + database=settings.FIRESTORE_DATABASE_NAME + ) + + def test_disconnect(self, provider): + """Test disconnecting from Firestore""" + provider.client = Mock() + provider._db = Mock() + + provider.disconnect() + + assert provider.client is None + assert provider._db is None + + def test_get_collection(self, provider, mock_firestore_client): + """Test getting a collection reference""" + provider.client = mock_firestore_client + + collection = provider.get_collection("test_collection") + + mock_firestore_client.collection.assert_called_once_with("test_collection") + assert collection == mock_firestore_client.collection.return_value + + def test_get_collection_without_client(self, provider): + """Test getting collection when client is None""" + provider.client = None + + with patch.object(provider, 'connect') as mock_connect: + # Simulate successful connection by setting client after connect is called + def set_client(): + provider.client = Mock() + mock_connect.side_effect = set_client + + provider.get_collection("test_collection") + + mock_connect.assert_called_once() + + @pytest.mark.asyncio + async def test_add_document(self, provider, mock_firestore_client): + """Test adding a document""" + provider.client = mock_firestore_client + + test_data = {"name": "Test", "value": 123} + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc_ref = Mock() + mock_doc_ref.id = "generated_id" + mock_collection.add.return_value = (None, mock_doc_ref) + mock_get_collection.return_value = mock_collection + + doc_id = await provider.add_document("test_collection", test_data) + + assert doc_id == "generated_id" + mock_collection.add.assert_called_once() + + @pytest.mark.asyncio + async def test_add_document_with_id(self, provider, mock_firestore_client): + """Test adding a document with specific ID""" + provider.client = mock_firestore_client + + test_data = {"_id": "custom_id", "name": "Test"} + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc_ref = Mock() + mock_collection.document.return_value = mock_doc_ref + mock_get_collection.return_value = mock_collection + + doc_id = await provider.add_document("test_collection", test_data) + + assert doc_id == "custom_id" + mock_collection.document.assert_called_once_with("custom_id") + mock_doc_ref.set.assert_called_once() + + @pytest.mark.asyncio + async def test_get_document_exists(self, provider, mock_firestore_client): + """Test getting an existing document""" + provider.client = mock_firestore_client + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc_ref = Mock() + mock_doc = Mock() + mock_doc.exists = True + mock_doc.to_dict.return_value = {"name": "Test", "value": 123} + + mock_collection.document.return_value = mock_doc_ref + mock_doc_ref.get.return_value = mock_doc + mock_get_collection.return_value = mock_collection + + result = await provider.get_document("test_collection", "test_id") + + assert result == {"name": "Test", "value": 123, "_id": "test_id"} + mock_collection.document.assert_called_once_with("test_id") + + @pytest.mark.asyncio + async def test_get_document_not_exists(self, provider, mock_firestore_client): + """Test getting a non-existent document""" + provider.client = mock_firestore_client + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc_ref = Mock() + mock_doc = Mock() + mock_doc.exists = False + + mock_collection.document.return_value = mock_doc_ref + mock_doc_ref.get.return_value = mock_doc + mock_get_collection.return_value = mock_collection + + result = await provider.get_document("test_collection", "test_id") + + assert result is None + + @pytest.mark.asyncio + async def test_list_documents(self, provider, mock_firestore_client): + """Test listing all documents in a collection""" + provider.client = mock_firestore_client + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc1 = Mock() + mock_doc1.id = "doc1" + mock_doc1.to_dict.return_value = {"name": "Test1"} + mock_doc2 = Mock() + mock_doc2.id = "doc2" + mock_doc2.to_dict.return_value = {"name": "Test2"} + + mock_collection.stream.return_value = [mock_doc1, mock_doc2] + mock_get_collection.return_value = mock_collection + + result = await provider.list_documents("test_collection") + + assert len(result) == 2 + assert result[0] == {"name": "Test1", "_id": "doc1"} + assert result[1] == {"name": "Test2", "_id": "doc2"} + + @pytest.mark.asyncio + async def test_update_document(self, provider, mock_firestore_client): + """Test updating a document""" + provider.client = mock_firestore_client + + update_data = {"name": "Updated", "_id": "should_be_removed"} + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc_ref = Mock() + mock_collection.document.return_value = mock_doc_ref + mock_get_collection.return_value = mock_collection + + result = await provider.update_document("test_collection", "test_id", update_data) + + assert result is True + mock_collection.document.assert_called_once_with("test_id") + # Verify _id was removed from update data + mock_doc_ref.update.assert_called_once_with({"name": "Updated"}) + + @pytest.mark.asyncio + async def test_delete_document(self, provider, mock_firestore_client): + """Test deleting a document""" + provider.client = mock_firestore_client + + with patch.object(provider, 'get_collection') as mock_get_collection: + mock_collection = Mock() + mock_doc_ref = Mock() + mock_collection.document.return_value = mock_doc_ref + mock_get_collection.return_value = mock_collection + + result = await provider.delete_document("test_collection", "test_id") + + assert result is True + mock_collection.document.assert_called_once_with("test_id") + mock_doc_ref.delete.assert_called_once() + + def test_convert_to_model(self, provider): + """Test converting document data to Pydantic model""" + from pydantic import BaseModel + + class TestModel(BaseModel): + name: str + value: int + + doc_data = {"name": "Test", "value": 123} + + result = provider.convert_to_model(TestModel, doc_data) + + assert isinstance(result, TestModel) + assert result.name == "Test" + assert result.value == 123 \ No newline at end of file diff --git a/tests/db/test_firestore_repositories.py b/tests/db/test_firestore_repositories.py new file mode 100644 index 0000000..e00a87e --- /dev/null +++ b/tests/db/test_firestore_repositories.py @@ -0,0 +1,427 @@ +import pytest +from unittest.mock import Mock, patch, AsyncMock +from src.db.repositories.firestore_repository import FirestoreRepository +from src.db.repositories.firestore_team_repository import FirestoreTeamRepository +from src.db.repositories.firestore_user_repository import FirestoreUserRepository +from src.db.repositories.firestore_api_key_repository import FirestoreApiKeyRepository +from src.db.repositories.firestore_image_repository import FirestoreImageRepository +from src.models.team import TeamModel +from src.models.user import UserModel +from src.models.api_key import ApiKeyModel +from src.models.image import ImageModel +from pydantic import BaseModel + + +class TestFirestoreRepository: + """Test cases for the base FirestoreRepository""" + + @pytest.fixture + def mock_firestore_db(self): + """Mock firestore_db for testing""" + with patch('src.db.repositories.firestore_repository.firestore_db') as mock_db: + # Make the async methods return coroutines + mock_db.add_document = AsyncMock() + mock_db.get_document = AsyncMock() + mock_db.list_documents = AsyncMock() + mock_db.update_document = AsyncMock() + mock_db.delete_document = AsyncMock() + yield mock_db + + @pytest.fixture + def test_model_class(self): + """Create a test model class for testing""" + class TestModel(BaseModel): + name: str + value: int + + return TestModel + + @pytest.fixture + def repository(self, test_model_class): + """Create a FirestoreRepository instance for testing""" + return FirestoreRepository("test_collection", test_model_class) + + def test_init(self, repository, test_model_class): + """Test repository initialization""" + assert repository.collection_name == "test_collection" + assert repository.model_class == test_model_class + + @pytest.mark.asyncio + async def test_create(self, repository, test_model_class, mock_firestore_db): + """Test creating a new document""" + # Setup + test_model = test_model_class(name="Test", value=123) + mock_firestore_db.add_document.return_value = "generated_id" + mock_firestore_db.get_document.return_value = { + "name": "Test", + "value": 123, + "_id": "generated_id" + } + + # Execute + result = await repository.create(test_model) + + # Assert + assert isinstance(result, test_model_class) + assert result.name == "Test" + assert result.value == 123 + mock_firestore_db.add_document.assert_called_once_with( + "test_collection", + {"name": "Test", "value": 123} + ) + mock_firestore_db.get_document.assert_called_once_with( + "test_collection", + "generated_id" + ) + + @pytest.mark.asyncio + async def test_get_by_id_found(self, repository, test_model_class, mock_firestore_db): + """Test getting a document by ID when it exists""" + # Setup + mock_firestore_db.get_document.return_value = { + "name": "Test", + "value": 123, + "_id": "test_id" + } + + # Execute + result = await repository.get_by_id("test_id") + + # Assert + assert isinstance(result, test_model_class) + assert result.name == "Test" + assert result.value == 123 + mock_firestore_db.get_document.assert_called_once_with("test_collection", "test_id") + + @pytest.mark.asyncio + async def test_get_by_id_not_found(self, repository, mock_firestore_db): + """Test getting a document by ID when it doesn't exist""" + # Setup + mock_firestore_db.get_document.return_value = None + + # Execute + result = await repository.get_by_id("nonexistent_id") + + # Assert + assert result is None + mock_firestore_db.get_document.assert_called_once_with("test_collection", "nonexistent_id") + + @pytest.mark.asyncio + async def test_get_all(self, repository, test_model_class, mock_firestore_db): + """Test getting all documents""" + # Setup + mock_firestore_db.list_documents.return_value = [ + {"name": "Test1", "value": 123, "_id": "id1"}, + {"name": "Test2", "value": 456, "_id": "id2"} + ] + + # Execute + result = await repository.get_all() + + # Assert + assert len(result) == 2 + assert all(isinstance(item, test_model_class) for item in result) + assert result[0].name == "Test1" + assert result[1].name == "Test2" + mock_firestore_db.list_documents.assert_called_once_with("test_collection") + + @pytest.mark.asyncio + async def test_update_success(self, repository, test_model_class, mock_firestore_db): + """Test updating a document successfully""" + # Setup + update_data = {"name": "Updated", "value": 999} + mock_firestore_db.update_document.return_value = True + mock_firestore_db.get_document.return_value = { + "name": "Updated", + "value": 999, + "_id": "test_id" + } + + # Execute + result = await repository.update("test_id", update_data) + + # Assert + assert isinstance(result, test_model_class) + assert result.name == "Updated" + assert result.value == 999 + mock_firestore_db.update_document.assert_called_once_with( + "test_collection", + "test_id", + update_data + ) + + @pytest.mark.asyncio + async def test_update_failure(self, repository, mock_firestore_db): + """Test updating a document that doesn't exist""" + # Setup + update_data = {"name": "Updated"} + mock_firestore_db.update_document.return_value = False + + # Execute + result = await repository.update("nonexistent_id", update_data) + + # Assert + assert result is None + mock_firestore_db.update_document.assert_called_once_with( + "test_collection", + "nonexistent_id", + update_data + ) + + @pytest.mark.asyncio + async def test_delete_success(self, repository, mock_firestore_db): + """Test deleting a document successfully""" + # Setup + mock_firestore_db.delete_document.return_value = True + + # Execute + result = await repository.delete("test_id") + + # Assert + assert result is True + mock_firestore_db.delete_document.assert_called_once_with("test_collection", "test_id") + + @pytest.mark.asyncio + async def test_delete_failure(self, repository, mock_firestore_db): + """Test deleting a document that doesn't exist""" + # Setup + mock_firestore_db.delete_document.return_value = False + + # Execute + result = await repository.delete("nonexistent_id") + + # Assert + assert result is False + mock_firestore_db.delete_document.assert_called_once_with("test_collection", "nonexistent_id") + + +class TestFirestoreTeamRepository: + """Test cases for FirestoreTeamRepository""" + + @pytest.fixture + def repository(self): + """Create a FirestoreTeamRepository instance for testing""" + return FirestoreTeamRepository() + + def test_init(self, repository): + """Test repository initialization""" + assert repository.collection_name == "teams" + assert repository.model_class == TeamModel + + @pytest.mark.asyncio + async def test_get_by_id(self, repository): + """Test getting team by ID""" + with patch.object(repository.__class__.__bases__[0], 'get_by_id') as mock_get_by_id: + mock_get_by_id.return_value = Mock() + + await repository.get_by_id("team_id") + + mock_get_by_id.assert_called_once_with("team_id") + + @pytest.mark.asyncio + async def test_update(self, repository): + """Test updating team""" + with patch.object(repository.__class__.__bases__[0], 'update') as mock_update: + mock_update.return_value = Mock() + + await repository.update("team_id", {"name": "Updated Team"}) + + mock_update.assert_called_once_with("team_id", {"name": "Updated Team"}) + + @pytest.mark.asyncio + async def test_delete(self, repository): + """Test deleting team""" + with patch.object(repository.__class__.__bases__[0], 'delete') as mock_delete: + mock_delete.return_value = True + + result = await repository.delete("team_id") + + assert result is True + mock_delete.assert_called_once_with("team_id") + + +class TestFirestoreUserRepository: + """Test cases for FirestoreUserRepository""" + + @pytest.fixture + def repository(self): + """Create a FirestoreUserRepository instance for testing""" + return FirestoreUserRepository() + + def test_init(self, repository): + """Test repository initialization""" + assert repository.collection_name == "users" + assert repository.model_class == UserModel + + @pytest.mark.asyncio + async def test_get_by_email(self, repository): + """Test getting user by email""" + mock_users = [ + Mock(email="test1@example.com"), + Mock(email="test2@example.com"), + Mock(email="target@example.com") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_users + + result = await repository.get_by_email("target@example.com") + + assert result == mock_users[2] + mock_get_all.assert_called_once() + + @pytest.mark.asyncio + async def test_get_by_email_not_found(self, repository): + """Test getting user by email when not found""" + mock_users = [ + Mock(email="test1@example.com"), + Mock(email="test2@example.com") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_users + + result = await repository.get_by_email("notfound@example.com") + + assert result is None + mock_get_all.assert_called_once() + + @pytest.mark.asyncio + async def test_get_by_team_id(self, repository): + """Test getting users by team ID""" + mock_users = [ + Mock(team_id="team1"), + Mock(team_id="team2"), + Mock(team_id="team1") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_users + + result = await repository.get_by_team_id("team1") + + assert len(result) == 2 + assert result[0] == mock_users[0] + assert result[1] == mock_users[2] + mock_get_all.assert_called_once() + + +class TestFirestoreApiKeyRepository: + """Test cases for FirestoreApiKeyRepository""" + + @pytest.fixture + def repository(self): + """Create a FirestoreApiKeyRepository instance for testing""" + return FirestoreApiKeyRepository() + + def test_init(self, repository): + """Test repository initialization""" + assert repository.collection_name == "api_keys" + assert repository.model_class == ApiKeyModel + + @pytest.mark.asyncio + async def test_get_by_key_hash(self, repository): + """Test getting API key by hash""" + mock_api_keys = [ + Mock(key_hash="hash1"), + Mock(key_hash="hash2"), + Mock(key_hash="target_hash") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_api_keys + + result = await repository.get_by_key_hash("target_hash") + + assert result == mock_api_keys[2] + mock_get_all.assert_called_once() + + @pytest.mark.asyncio + async def test_get_by_user_id(self, repository): + """Test getting API keys by user ID""" + mock_api_keys = [ + Mock(user_id="user1"), + Mock(user_id="user2"), + Mock(user_id="user1") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_api_keys + + result = await repository.get_by_user_id("user1") + + assert len(result) == 2 + assert result[0] == mock_api_keys[0] + assert result[1] == mock_api_keys[2] + mock_get_all.assert_called_once() + + +class TestFirestoreImageRepository: + """Test cases for FirestoreImageRepository""" + + @pytest.fixture + def repository(self): + """Create a FirestoreImageRepository instance for testing""" + return FirestoreImageRepository() + + def test_init(self, repository): + """Test repository initialization""" + assert repository.collection_name == "images" + assert repository.model_class == ImageModel + + @pytest.mark.asyncio + async def test_get_by_team_id(self, repository): + """Test getting images by team ID""" + mock_images = [ + Mock(team_id="team1"), + Mock(team_id="team2"), + Mock(team_id="team1") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_images + + result = await repository.get_by_team_id("team1") + + assert len(result) == 2 + assert result[0] == mock_images[0] + assert result[1] == mock_images[2] + mock_get_all.assert_called_once() + + @pytest.mark.asyncio + async def test_get_by_uploader_id(self, repository): + """Test getting images by uploader ID""" + mock_images = [ + Mock(uploader_id="user1"), + Mock(uploader_id="user2"), + Mock(uploader_id="user1") + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_images + + result = await repository.get_by_uploader_id("user1") + + assert len(result) == 2 + assert result[0] == mock_images[0] + assert result[1] == mock_images[2] + mock_get_all.assert_called_once() + + @pytest.mark.asyncio + async def test_get_by_tag(self, repository): + """Test getting images by tag""" + mock_images = [ + Mock(tags=["tag1", "tag2"]), + Mock(tags=["tag3"]), + Mock(tags=["tag1", "tag4"]) + ] + + with patch.object(repository, 'get_all') as mock_get_all: + mock_get_all.return_value = mock_images + + result = await repository.get_by_tag("tag1") + + assert len(result) == 2 + assert result[0] == mock_images[0] + assert result[1] == mock_images[2] + mock_get_all.assert_called_once() \ No newline at end of file diff --git a/tests/db/test_py_object_id.py b/tests/db/test_py_object_id.py deleted file mode 100644 index d53aea9..0000000 --- a/tests/db/test_py_object_id.py +++ /dev/null @@ -1,51 +0,0 @@ -import pytest -from bson import ObjectId -from pydantic import ValidationError, BaseModel, Field -from src.models.team import PyObjectId - -class TestPyObjectId: - def test_valid_object_id(self): - """Test validating a valid ObjectId""" - # Using a string - valid_id = str(ObjectId()) - obj_id = PyObjectId.validate(valid_id) - assert isinstance(obj_id, ObjectId) - - # Using an existing ObjectId - existing_id = ObjectId() - obj_id = PyObjectId.validate(existing_id) - assert isinstance(obj_id, ObjectId) - assert obj_id == existing_id - - def test_invalid_object_id(self): - """Test validating an invalid ObjectId""" - with pytest.raises(ValueError, match='Invalid ObjectId'): - PyObjectId.validate("invalid-id") - - with pytest.raises(ValueError, match='Invalid ObjectId'): - PyObjectId.validate(123) - - with pytest.raises(ValueError, match='Invalid ObjectId'): - PyObjectId.validate(None) - - def test_object_id_in_model(self): - """Test using PyObjectId in a Pydantic model""" - class TestModel(BaseModel): - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") - - # Creating a model instance with auto-generated ID - model = TestModel() - assert isinstance(model.id, ObjectId) - - # Creating a model instance with provided ID - existing_id = ObjectId() - model = TestModel(_id=existing_id) - assert model.id == existing_id - - # Validating model with invalid ID - with pytest.raises(ValidationError): - TestModel(_id="invalid-id") - - # Removing the test_json_schema test as it's incompatible with - # the current PyObjectId implementation - # The functionality is implicitly tested in the other model tests \ No newline at end of file diff --git a/tests/db/test_repository_factory.py b/tests/db/test_repository_factory.py new file mode 100644 index 0000000..ceb3363 --- /dev/null +++ b/tests/db/test_repository_factory.py @@ -0,0 +1,52 @@ +import pytest +from src.db.repositories.repository_factory import RepositoryFactory, DatabaseType +from src.db.repositories.firestore_team_repository import FirestoreTeamRepository +from src.db.repositories.firestore_user_repository import FirestoreUserRepository +from src.db.repositories.firestore_api_key_repository import FirestoreApiKeyRepository +from src.db.repositories.firestore_image_repository import FirestoreImageRepository + + +class TestRepositoryFactory: + """Test cases for RepositoryFactory""" + + @pytest.fixture + def factory(self): + """Create a RepositoryFactory instance for testing""" + return RepositoryFactory() + + def test_init(self, factory): + """Test factory initialization""" + assert DatabaseType.FIRESTORE in factory.team_repositories + assert DatabaseType.FIRESTORE in factory.user_repositories + assert DatabaseType.FIRESTORE in factory.api_key_repositories + assert DatabaseType.FIRESTORE in factory.image_repositories + + def test_get_team_repository(self, factory): + """Test getting team repository""" + repository = factory.get_team_repository() + assert isinstance(repository, FirestoreTeamRepository) + + def test_get_user_repository(self, factory): + """Test getting user repository""" + repository = factory.get_user_repository() + assert isinstance(repository, FirestoreUserRepository) + + def test_get_api_key_repository(self, factory): + """Test getting API key repository""" + repository = factory.get_api_key_repository() + assert isinstance(repository, FirestoreApiKeyRepository) + + def test_get_image_repository(self, factory): + """Test getting image repository""" + repository = factory.get_image_repository() + assert isinstance(repository, FirestoreImageRepository) + + def test_database_type_enum(self): + """Test DatabaseType enum""" + assert DatabaseType.FIRESTORE == "firestore" + assert len(DatabaseType) == 1 # Only Firestore should be available + + def test_singleton_factory(self): + """Test that the factory is properly imported as singleton""" + from src.db.repositories.repository_factory import repository_factory + assert isinstance(repository_factory, RepositoryFactory) \ No newline at end of file