"""Unit tests for the Firestore database provider (FirestoreProvider)."""
import pytest
|
|
import asyncio
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
from src.db.providers.firestore_provider import FirestoreProvider
|
|
from src.config.config import settings
|
|
|
|
|
|
class TestFirestoreProvider:
    """Unit tests for FirestoreProvider.

    No test talks to a real Firestore instance. The client class, the
    credential helpers and ``os.path.exists`` are patched for the connection
    tests, and the document-level tests patch ``provider.get_collection`` so
    every collection/document interaction goes through Mock objects.
    """

    @pytest.fixture
    def provider(self):
        """Return a fresh, not-yet-connected FirestoreProvider."""
        return FirestoreProvider()

    @pytest.fixture
    def mock_firestore_client(self):
        """Return a Mock wired to look like a Firestore client.

        ``collection()`` returns a collection mock whose ``document()``
        returns a document-reference mock (``id == "test_doc_id"``) and whose
        ``add()`` returns ``(None, doc_ref)``.
        """
        mock_client = Mock()
        mock_collection = Mock()
        mock_doc_ref = Mock()
        mock_doc = Mock()

        # Setup mock chain: client -> collection -> document ref -> snapshot
        mock_client.collection.return_value = mock_collection
        mock_collection.document.return_value = mock_doc_ref
        # Mirrors CollectionReference.add, which returns an
        # (update_time, document_reference) tuple.
        mock_collection.add.return_value = (None, mock_doc_ref)
        mock_doc_ref.get.return_value = mock_doc
        mock_doc_ref.id = "test_doc_id"

        return mock_client

    def test_init(self, provider):
        """A new provider is unconnected and registers its core collections."""
        assert provider.client is None
        assert provider._db is None
        for name in ("teams", "users", "api_keys", "images"):
            assert name in provider._collections

    @patch('src.db.providers.firestore_provider.firestore.Client')
    # NOTE(review): patching at the source module only takes effect if the
    # provider looks up `service_account` at call time (e.g. a function-level
    # import); confirm against firestore_provider.py.
    @patch('google.oauth2.service_account')
    @patch('os.path.exists')
    def test_connect_with_credentials_file(self, mock_exists, mock_service_account, mock_client, provider):
        """connect() loads service-account credentials when the key file exists."""
        # @patch decorators apply bottom-up, so the mocks arrive in reverse order.
        mock_exists.return_value = True
        mock_credentials = Mock()
        mock_service_account.Credentials.from_service_account_file.return_value = mock_credentials
        mock_client_instance = Mock()
        mock_client.return_value = mock_client_instance

        result = provider.connect()

        assert result is True
        assert provider.client is mock_client_instance
        assert provider._db is mock_client_instance
        mock_service_account.Credentials.from_service_account_file.assert_called_once()
        mock_client.assert_called_once_with(
            project=settings.FIRESTORE_PROJECT_ID,
            credentials=mock_credentials,
            database=settings.FIRESTORE_DATABASE_NAME,
        )

    @patch('src.db.providers.firestore_provider.firestore.Client')
    @patch('os.path.exists')
    def test_connect_with_default_credentials(self, mock_exists, mock_client, provider):
        """Without a key file, connect() builds the client with no explicit credentials."""
        mock_exists.return_value = False
        mock_client_instance = Mock()
        mock_client.return_value = mock_client_instance

        result = provider.connect()

        assert result is True
        assert provider.client is mock_client_instance
        # No `credentials=` kwarg: the client is left to resolve credentials itself.
        mock_client.assert_called_once_with(
            project=settings.FIRESTORE_PROJECT_ID,
            database=settings.FIRESTORE_DATABASE_NAME,
        )

    def test_disconnect(self, provider):
        """disconnect() clears both the client and the database handle."""
        provider.client = Mock()
        provider._db = Mock()

        provider.disconnect()

        assert provider.client is None
        assert provider._db is None

    def test_get_collection(self, provider, mock_firestore_client):
        """get_collection() delegates to client.collection() and returns its result."""
        provider.client = mock_firestore_client

        collection = provider.get_collection("test_collection")

        mock_firestore_client.collection.assert_called_once_with("test_collection")
        assert collection is mock_firestore_client.collection.return_value

    def test_get_collection_without_client(self, provider):
        """get_collection() connects lazily, then queries the new client."""
        provider.client = None

        with patch.object(provider, 'connect') as mock_connect:
            # Stand-in for a successful connect(): it installs a fresh client.
            def set_client():
                provider.client = Mock()

            mock_connect.side_effect = set_client

            collection = provider.get_collection("test_collection")

            mock_connect.assert_called_once()
            # The client installed by connect() must be the one actually used.
            provider.client.collection.assert_called_once_with("test_collection")
            assert collection is provider.client.collection.return_value

    @pytest.mark.asyncio
    async def test_add_document(self, provider, mock_firestore_client):
        """Without an `_id`, add_document() uses collection.add() and returns the new id."""
        provider.client = mock_firestore_client
        test_data = {"name": "Test", "value": 123}

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_doc_ref.id = "generated_id"
            mock_collection.add.return_value = (None, mock_doc_ref)
            mock_get_collection.return_value = mock_collection

            doc_id = await provider.add_document("test_collection", test_data)

            assert doc_id == "generated_id"
            mock_collection.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_add_document_with_id(self, provider, mock_firestore_client):
        """With an `_id`, add_document() writes via document(_id).set() and returns that id."""
        provider.client = mock_firestore_client
        test_data = {"_id": "custom_id", "name": "Test"}

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_collection.document.return_value = mock_doc_ref
            mock_get_collection.return_value = mock_collection

            doc_id = await provider.add_document("test_collection", test_data)

            assert doc_id == "custom_id"
            mock_collection.document.assert_called_once_with("custom_id")
            mock_doc_ref.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_document_exists(self, provider, mock_firestore_client):
        """An existing document is returned as a dict with its id under `_id`."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_doc = Mock()
            mock_doc.exists = True
            mock_doc.to_dict.return_value = {"name": "Test", "value": 123}

            mock_collection.document.return_value = mock_doc_ref
            mock_doc_ref.get.return_value = mock_doc
            mock_get_collection.return_value = mock_collection

            result = await provider.get_document("test_collection", "test_id")

            assert result == {"name": "Test", "value": 123, "_id": "test_id"}
            mock_collection.document.assert_called_once_with("test_id")

    @pytest.mark.asyncio
    async def test_get_document_not_exists(self, provider, mock_firestore_client):
        """A missing document (snapshot.exists is False) yields None."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_doc = Mock()
            mock_doc.exists = False

            mock_collection.document.return_value = mock_doc_ref
            mock_doc_ref.get.return_value = mock_doc
            mock_get_collection.return_value = mock_collection

            result = await provider.get_document("test_collection", "test_id")

            assert result is None

    @pytest.mark.asyncio
    async def test_list_documents(self, provider, mock_firestore_client):
        """list_documents() streams the collection and tags each dict with its `_id`."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc1 = Mock()
            mock_doc1.id = "doc1"
            mock_doc1.to_dict.return_value = {"name": "Test1"}
            mock_doc2 = Mock()
            mock_doc2.id = "doc2"
            mock_doc2.to_dict.return_value = {"name": "Test2"}

            mock_collection.stream.return_value = [mock_doc1, mock_doc2]
            mock_get_collection.return_value = mock_collection

            result = await provider.list_documents("test_collection")

            assert len(result) == 2
            assert result[0] == {"name": "Test1", "_id": "doc1"}
            assert result[1] == {"name": "Test2", "_id": "doc2"}

    @pytest.mark.asyncio
    async def test_update_document(self, provider, mock_firestore_client):
        """update_document() strips `_id` from the payload before calling update()."""
        provider.client = mock_firestore_client
        update_data = {"name": "Updated", "_id": "should_be_removed"}

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_collection.document.return_value = mock_doc_ref
            mock_get_collection.return_value = mock_collection

            result = await provider.update_document("test_collection", "test_id", update_data)

            assert result is True
            mock_collection.document.assert_called_once_with("test_id")
            # `_id` identifies the document; it must not be written as a field.
            mock_doc_ref.update.assert_called_once_with({"name": "Updated"})

    @pytest.mark.asyncio
    async def test_delete_document(self, provider, mock_firestore_client):
        """delete_document() deletes the referenced document and reports success."""
        provider.client = mock_firestore_client

        with patch.object(provider, 'get_collection') as mock_get_collection:
            mock_collection = Mock()
            mock_doc_ref = Mock()
            mock_collection.document.return_value = mock_doc_ref
            mock_get_collection.return_value = mock_collection

            result = await provider.delete_document("test_collection", "test_id")

            assert result is True
            mock_collection.document.assert_called_once_with("test_id")
            mock_doc_ref.delete.assert_called_once()

    def test_convert_to_model(self, provider):
        """convert_to_model() builds an instance of the given Pydantic model."""
        from pydantic import BaseModel

        # Named without a `Test` prefix so it never reads as a test class.
        class SampleModel(BaseModel):
            name: str
            value: int

        doc_data = {"name": "Test", "value": 123}

        result = provider.convert_to_model(SampleModel, doc_data)

        assert isinstance(result, SampleModel)
        assert result.name == "Test"
        assert result.value == 123