import asyncio
from unittest.mock import AsyncMock, Mock, patch

import pytest

from src.db.providers.firestore_provider import FirestoreProvider
from src.config.config import settings


class TestFirestoreProvider:
    """Unit tests for FirestoreProvider.

    Every test runs against ``unittest.mock`` stand-ins; no real Firestore
    client is created. The async CRUD tests patch ``provider.get_collection``
    so that each test controls exactly which collection mock the provider
    sees.
    """

    @pytest.fixture
    def provider(self):
        """Return a fresh, unconnected FirestoreProvider instance."""
        return FirestoreProvider()

    @pytest.fixture
    def mock_firestore_client(self):
        """Return a Mock client wired as client -> collection -> document.

        ``collection()`` yields a collection mock whose ``document()`` and
        ``add()`` both lead to the same document-reference mock (id
        ``"test_doc_id"``); that reference's ``get()`` yields a document mock.
        """
        mock_client = Mock()
        mock_collection = Mock()
        mock_doc_ref = Mock()
        mock_doc = Mock()

        # Setup mock chain
        mock_client.collection.return_value = mock_collection
        mock_collection.document.return_value = mock_doc_ref
        mock_collection.add.return_value = (None, mock_doc_ref)
        mock_doc_ref.get.return_value = mock_doc
        mock_doc_ref.id = "test_doc_id"

        return mock_client

    def test_init(self, provider):
        """A new provider has no client or db and knows its collections."""
        assert provider.client is None
        assert provider._db is None
        assert "teams" in provider._collections
        assert "users" in provider._collections
        assert "api_keys" in provider._collections
        assert "images" in provider._collections

    # NOTE: stacked @patch decorators are applied bottom-up, so the mock
    # arguments arrive in the reverse of the decorator order.
    @patch('src.db.providers.firestore_provider.firestore.Client')
    @patch('google.oauth2.service_account')
    @patch('os.path.exists')
    def test_connect_with_credentials_file(
        self, mock_exists, mock_service_account, mock_client, provider
    ):
        """connect() builds the client from service-account credentials.

        ``os.path.exists`` is forced to True so the credentials file appears
        to be present; the test then checks that the credentials loaded via
        ``from_service_account_file`` are handed to ``firestore.Client``.
        """
        # Setup mocks
        mock_exists.return_value = True
        mock_credentials = Mock()
        mock_service_account.Credentials.from_service_account_file.return_value = (
            mock_credentials
        )
        mock_client_instance = Mock()
        mock_client.return_value = mock_client_instance

        # Test connection
        result = provider.connect()

        # Assertions
        assert result is True
        assert provider.client == mock_client_instance
        assert provider._db == mock_client_instance
        mock_service_account.Credentials.from_service_account_file.assert_called_once()
        mock_client.assert_called_once_with(
            project=settings.FIRESTORE_PROJECT_ID,
            credentials=mock_credentials,
            database=settings.FIRESTORE_DATABASE_NAME,
        )

    @patch('src.db.providers.firestore_provider.firestore.Client')
    @patch('os.path.exists')
    def test_connect_with_default_credentials(self, mock_exists, mock_client, provider):
        """connect() omits explicit credentials when no file is found.

        ``os.path.exists`` is forced to False; the client must then be built
        with only the project and database settings.
        """
        # Setup mocks
        mock_exists.return_value = False
        mock_client_instance = Mock()
        mock_client.return_value = mock_client_instance

        # Test connection
        result = provider.connect()

        # Assertions
        assert result is True
        assert provider.client == mock_client_instance
        mock_client.assert_called_once_with(
            project=settings.FIRESTORE_PROJECT_ID,
            database=settings.FIRESTORE_DATABASE_NAME,
        )

    def test_disconnect(self, provider):
        """disconnect() resets both ``client`` and ``_db`` to None."""
        provider.client = Mock()
        provider._db = Mock()

        provider.disconnect()

        assert provider.client is None
        assert provider._db is None

    def test_get_collection(self, provider, mock_firestore_client):
        """get_collection() delegates to ``client.collection(name)``."""
        provider.client = mock_firestore_client

        collection = provider.get_collection("test_collection")

        mock_firestore_client.collection.assert_called_once_with("test_collection")
        assert collection == mock_firestore_client.collection.return_value

    def test_get_collection_without_client(self, provider):
        """get_collection() calls connect() once when there is no client."""
        provider.client = None

        with patch.object(provider, 'connect') as mock_connect:
            # Simulate successful connection by setting client after connect is called
            def set_client():
                provider.client = Mock()

            mock_connect.side_effect = set_client

            provider.get_collection("test_collection")

            mock_connect.assert_called_once()

    @pytest.mark.asyncio
    async def test_add_document(self, provider, mock_firestore_client):
        """add_document() without ``_id`` returns the id of the added ref."""
        provider.client = mock_firestore_client
        test_data = {"name": "Test", "value": 123}

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_doc_ref.id = "generated_id"
            # collection.add() is mocked to return a 2-tuple whose second
            # element is the new document reference.
            mock_collection.add.return_value = (None, mock_doc_ref)
            mock_get_collection.return_value = mock_collection

            doc_id = await provider.add_document("test_collection", test_data)

            assert doc_id == "generated_id"
            mock_collection.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_add_document_with_id(self, provider, mock_firestore_client):
        """add_document() with ``_id`` writes via ``document(_id).set()``."""
        provider.client = mock_firestore_client
        test_data = {"_id": "custom_id", "name": "Test"}

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_collection.document.return_value = mock_doc_ref
            mock_get_collection.return_value = mock_collection

            doc_id = await provider.add_document("test_collection", test_data)

            assert doc_id == "custom_id"
            mock_collection.document.assert_called_once_with("custom_id")
            mock_doc_ref.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_document_exists(self, provider, mock_firestore_client):
        """get_document() returns the stored dict plus an ``_id`` key."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_doc = Mock()
            mock_doc.exists = True
            mock_doc.to_dict.return_value = {"name": "Test", "value": 123}

            mock_collection.document.return_value = mock_doc_ref
            mock_doc_ref.get.return_value = mock_doc
            mock_get_collection.return_value = mock_collection

            result = await provider.get_document("test_collection", "test_id")

            assert result == {"name": "Test", "value": 123, "_id": "test_id"}
            mock_collection.document.assert_called_once_with("test_id")

    @pytest.mark.asyncio
    async def test_get_document_not_exists(self, provider, mock_firestore_client):
        """get_document() returns None when the snapshot does not exist."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_doc = Mock()
            mock_doc.exists = False

            mock_collection.document.return_value = mock_doc_ref
            mock_doc_ref.get.return_value = mock_doc
            mock_get_collection.return_value = mock_collection

            result = await provider.get_document("test_collection", "test_id")

            assert result is None

    @pytest.mark.asyncio
    async def test_list_documents(self, provider, mock_firestore_client):
        """list_documents() returns each streamed doc's dict with its ``_id``."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc1 = Mock()
            mock_doc1.id = "doc1"
            mock_doc1.to_dict.return_value = {"name": "Test1"}
            mock_doc2 = Mock()
            mock_doc2.id = "doc2"
            mock_doc2.to_dict.return_value = {"name": "Test2"}

            mock_collection.stream.return_value = [mock_doc1, mock_doc2]
            mock_get_collection.return_value = mock_collection

            result = await provider.list_documents("test_collection")

            assert len(result) == 2
            assert result[0] == {"name": "Test1", "_id": "doc1"}
            assert result[1] == {"name": "Test2", "_id": "doc2"}

    @pytest.mark.asyncio
    async def test_update_document(self, provider, mock_firestore_client):
        """update_document() drops ``_id`` from the payload and returns True."""
        provider.client = mock_firestore_client
        update_data = {"name": "Updated", "_id": "should_be_removed"}

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_collection.document.return_value = mock_doc_ref
            mock_get_collection.return_value = mock_collection

            result = await provider.update_document(
                "test_collection", "test_id", update_data
            )

            assert result is True
            mock_collection.document.assert_called_once_with("test_id")
            # Verify _id was removed from update data
            mock_doc_ref.update.assert_called_once_with({"name": "Updated"})

    @pytest.mark.asyncio
    async def test_delete_document(self, provider, mock_firestore_client):
        """delete_document() deletes the referenced doc and returns True."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_collection.document.return_value = mock_doc_ref
            mock_get_collection.return_value = mock_collection

            result = await provider.delete_document("test_collection", "test_id")

            assert result is True
            mock_collection.document.assert_called_once_with("test_id")
            mock_doc_ref.delete.assert_called_once()

    def test_convert_to_model(self, provider):
        """convert_to_model() builds a Pydantic model from a plain dict."""
        # Imported locally, as in the original, so pydantic is only needed
        # by this one test.
        from pydantic import BaseModel

        class TestModel(BaseModel):
            name: str
            value: int

        doc_data = {"name": "Test", "value": 123}
        result = provider.convert_to_model(TestModel, doc_data)

        assert isinstance(result, TestModel)
        assert result.name == "Test"
        assert result.value == 123