108 lines
4.0 KiB
Python
108 lines
4.0 KiB
Python
import pytest
|
|
from unittest.mock import Mock, patch
|
|
from src.db import Database, db
|
|
|
|
|
|
class TestDatabase:
    """Test cases for the Database class.

    Most tests replace ``database.provider`` with a ``Mock`` so that only the
    Database wrapper logic (connect / close / get_database) is exercised.
    """

    @pytest.fixture
    def database(self):
        """Create a Database instance for testing.

        NOTE(review): this runs the real ``Database()`` constructor, so any
        provider setup it performs happens here as well — confirm it does not
        reach external services.
        """
        return Database()

    def test_init(self, database):
        """Test database initialization: a provider exposing ``client`` exists."""
        assert database.provider is not None
        assert hasattr(database.provider, 'client')

    @patch('src.db.FirestoreProvider')
    def test_connect_to_database_success(self, mock_provider_class, database):
        """Test successful database connection."""
        mock_provider = Mock()
        # Make any FirestoreProvider() built inside src.db return the same
        # mock, in case connect_to_database constructs a fresh provider
        # (presumably — confirm against src.db), and also install it directly.
        mock_provider_class.return_value = mock_provider
        database.provider = mock_provider

        database.connect_to_database()

        mock_provider.connect.assert_called_once()

    @patch('src.db.FirestoreProvider')
    def test_connect_to_database_failure(self, mock_provider_class, database):
        """Test that a provider connection error propagates to the caller."""
        mock_provider = Mock()
        mock_provider.connect.side_effect = Exception("Connection failed")
        # Same dual installation as in the success test (see comment there).
        mock_provider_class.return_value = mock_provider
        database.provider = mock_provider

        with pytest.raises(Exception, match="Connection failed"):
            database.connect_to_database()

        mock_provider.connect.assert_called_once()

    def test_close_database_connection_success(self, database):
        """Test successful database disconnection."""
        mock_provider = Mock()
        database.provider = mock_provider

        database.close_database_connection()

        mock_provider.disconnect.assert_called_once()

    def test_close_database_connection_failure(self, database):
        """Test that a disconnect error is not propagated to the caller."""
        mock_provider = Mock()
        mock_provider.disconnect.side_effect = Exception("Disconnect failed")
        database.provider = mock_provider

        # Should not raise exception, just log error
        database.close_database_connection()

        mock_provider.disconnect.assert_called_once()

    def test_get_database_with_client(self, database):
        """Test that the existing provider client is returned as-is."""
        mock_provider = Mock()
        mock_client = Mock()
        mock_provider.client = mock_client
        database.provider = mock_provider

        result = database.get_database()

        # Identity, not equality: the very same client object must come back.
        assert result is mock_client

    def test_get_database_without_client_reconnect_success(self, database):
        """Test getting database when client is None and reconnection succeeds."""
        mock_provider = Mock()
        mock_provider.client = None
        database.provider = mock_provider
        new_client = Mock()

        def set_client(*args, **kwargs):
            # Simulate a successful reconnect: the provider gains a client.
            # Accept any arguments so an unexpected call signature cannot
            # raise here and be silently turned into a None result.
            mock_provider.client = new_client

        with patch.object(
            database, 'connect_to_database', side_effect=set_client
        ) as mock_connect:
            result = database.get_database()

            mock_connect.assert_called_once()
            assert result is new_client

    def test_get_database_without_client_reconnect_failure(self, database):
        """Test getting database when client is None and reconnection fails."""
        mock_provider = Mock()
        mock_provider.client = None
        database.provider = mock_provider

        with patch.object(
            database,
            'connect_to_database',
            side_effect=Exception("Reconnection failed"),
        ) as mock_connect:
            result = database.get_database()

            mock_connect.assert_called_once()
            # The reconnect error is expected to be swallowed, yielding None.
            assert result is None

    def test_singleton_db_instance(self):
        """Test that the module-level ``db`` is a ready Database instance."""
        assert isinstance(db, Database)
        assert db.provider is not None