433 lines
17 KiB
Python
433 lines
17 KiB
Python
import pytest
|
|
from unittest.mock import Mock, patch, AsyncMock
|
|
from pydantic import BaseModel
|
|
|
|
from src.db.repositories.firestore_repository import FirestoreRepository
|
|
from src.db.repositories.firestore_team_repository import FirestoreTeamRepository
|
|
from src.db.repositories.firestore_user_repository import FirestoreUserRepository
|
|
from src.db.repositories.firestore_api_key_repository import FirestoreApiKeyRepository
|
|
from src.db.repositories.firestore_image_repository import FirestoreImageRepository
|
|
from src.models.team import TeamModel
|
|
from src.models.user import UserModel
|
|
from src.models.api_key import ApiKeyModel
|
|
from src.models.image import ImageModel
|
|
|
|
|
|
class TestFirestoreRepository:
    """Test cases for the base FirestoreRepository class.

    Each async test swaps the provider object named by ``_PROVIDER_TARGET`` for a
    mock (via ``unittest.mock.patch``), configures its return values, and then
    checks both the calls the repository made and the models it returned.
    """

    # Dotted path of the provider object every async test replaces.
    # NOTE(review): patch() must target the name the repository module looks up at
    # call time. If firestore_repository does
    # `from src.db.providers.firestore_provider import firestore_db`, this should be
    # "src.db.repositories.firestore_repository.firestore_db" instead — confirm.
    _PROVIDER_TARGET = 'src.db.providers.firestore_provider.firestore_db'

    @pytest.fixture
    def test_model_class(self):
        """Create a minimal pydantic model class for the repository to work with."""
        from typing import Optional

        class TestModel(BaseModel):
            name: str
            # Optional so the annotation agrees with the None default.
            description: Optional[str] = None

        return TestModel

    @pytest.fixture
    def repository(self, test_model_class):
        """Create a FirestoreRepository bound to "test_collection" and TestModel."""
        return FirestoreRepository("test_collection", test_model_class)

    def test_init(self, repository, test_model_class):
        """Test repository initialization"""
        assert repository.collection_name == "test_collection"
        assert repository.model_class == test_model_class

    @pytest.mark.asyncio
    async def test_create(self, repository, test_model_class):
        """Test creating a document"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            # Stub: the provider returns a new id, then the stored document for it.
            mock_provider.add_document.return_value = "test_doc_id"
            mock_provider.get_document.return_value = {
                "_id": "test_doc_id",
                "name": "Test Model",
                "description": "Test Description",
            }

            test_model = test_model_class(
                name="Test Model",
                description="Test Description",
            )

            result = await repository.create(test_model)

            mock_provider.add_document.assert_called_once()
            mock_provider.get_document.assert_called_once_with("test_collection", "test_doc_id")

            assert result.name == "Test Model"
            assert result.description == "Test Description"

    @pytest.mark.asyncio
    async def test_get_by_id_found(self, repository):
        """Test getting a document by ID when it exists"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            mock_provider.get_document.return_value = {
                "_id": "test_doc_id",
                "name": "Test Model",
                "description": "Test Description",
            }

            result = await repository.get_by_id("test_doc_id")

            mock_provider.get_document.assert_called_once_with("test_collection", "test_doc_id")

            assert result is not None
            assert result.name == "Test Model"
            assert result.description == "Test Description"

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(self, repository):
        """Test getting a document by ID when it doesn't exist"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            # None from the provider stands for "document not found".
            mock_provider.get_document.return_value = None

            result = await repository.get_by_id("nonexistent_id")

            mock_provider.get_document.assert_called_once_with("test_collection", "nonexistent_id")

            assert result is None

    @pytest.mark.asyncio
    async def test_get_all(self, repository):
        """Test getting all documents"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            mock_provider.list_documents.return_value = [
                {
                    "_id": "doc1",
                    "name": "Test Model 1",
                    "description": "Test Description 1",
                },
                {
                    "_id": "doc2",
                    "name": "Test Model 2",
                    "description": "Test Description 2",
                },
            ]

            result = await repository.get_all()

            mock_provider.list_documents.assert_called_once_with("test_collection")

            assert len(result) == 2
            assert result[0].name == "Test Model 1"
            assert result[1].name == "Test Model 2"

    @pytest.mark.asyncio
    async def test_update_success(self, repository):
        """Test updating a document successfully"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            mock_provider.update_document.return_value = True
            mock_provider.get_document.return_value = {
                "_id": "test_doc_id",
                "name": "Updated Model",
                "description": "Updated Description",
            }

            result = await repository.update(
                "test_doc_id",
                {"name": "Updated Model", "description": "Updated Description"},
            )

            mock_provider.update_document.assert_called_once_with(
                "test_collection",
                "test_doc_id",
                {"name": "Updated Model", "description": "Updated Description"},
            )
            mock_provider.get_document.assert_called_once_with("test_collection", "test_doc_id")

            assert result is not None
            assert result.name == "Updated Model"
            assert result.description == "Updated Description"

    @pytest.mark.asyncio
    async def test_update_failure(self, repository):
        """Test updating a document that doesn't exist"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            # False from the provider stands for "update failed".
            mock_provider.update_document.return_value = False

            result = await repository.update("nonexistent_id", {"name": "updated"})

            mock_provider.update_document.assert_called_once_with(
                "test_collection", "nonexistent_id", {"name": "updated"}
            )

            assert result is None

    @pytest.mark.asyncio
    async def test_delete_success(self, repository):
        """Test deleting a document successfully"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            mock_provider.delete_document.return_value = True

            result = await repository.delete("test_doc_id")

            mock_provider.delete_document.assert_called_once_with("test_collection", "test_doc_id")

            assert result is True

    @pytest.mark.asyncio
    async def test_delete_failure(self, repository):
        """Test deleting a document that doesn't exist"""
        with patch(self._PROVIDER_TARGET) as mock_provider:
            mock_provider.delete_document.return_value = False

            result = await repository.delete("nonexistent_id")

            mock_provider.delete_document.assert_called_once_with("test_collection", "nonexistent_id")

            assert result is False
|
|
|
|
|
|
class TestFirestoreTeamRepository:
    """Test cases for FirestoreTeamRepository.

    The CRUD tests patch the method on the repository's immediate base class, so
    they check only that the team repository forwards arguments and results.
    """

    @pytest.fixture
    def repository(self):
        """Create a FirestoreTeamRepository instance for testing"""
        return FirestoreTeamRepository()

    def test_init(self, repository):
        """Test repository initialization"""
        assert repository.collection_name == "teams"
        assert repository.model_class == TeamModel

    @pytest.mark.asyncio
    async def test_get_by_id(self, repository):
        """Test getting team by ID"""
        team = Mock(id="team_id")
        # `name` is a reserved Mock constructor argument (it sets the mock's repr
        # name, not a `.name` attribute), so it must be assigned after creation.
        team.name = "Test Team"

        # __bases__[0] avoids hard-coding which class the team repository extends.
        with patch.object(repository.__class__.__bases__[0], 'get_by_id') as mock_get:
            mock_get.return_value = team

            result = await repository.get_by_id("team_id")

            assert result.id == "team_id"
            assert result.name == "Test Team"
            mock_get.assert_called_once_with("team_id")

    @pytest.mark.asyncio
    async def test_update(self, repository):
        """Test updating team"""
        updated_team = Mock()
        # See test_get_by_id: `Mock(name=...)` would not set a `.name` attribute.
        updated_team.name = "Updated Team"

        with patch.object(repository.__class__.__bases__[0], 'update') as mock_update:
            mock_update.return_value = updated_team

            result = await repository.update("team_id", {"name": "Updated Team"})

            assert result.name == "Updated Team"
            mock_update.assert_called_once_with("team_id", {"name": "Updated Team"})

    @pytest.mark.asyncio
    async def test_delete(self, repository):
        """Test deleting team"""
        with patch.object(repository.__class__.__bases__[0], 'delete') as mock_delete:
            mock_delete.return_value = True

            result = await repository.delete("team_id")

            assert result is True
            mock_delete.assert_called_once_with("team_id")
|
|
|
|
|
|
class TestFirestoreUserRepository:
    """Tests for FirestoreUserRepository wiring and its user lookup helpers.

    The lookup tests replace ``get_all`` on the instance with a mock, so they
    check only which of the returned users each helper selects.
    """

    @pytest.fixture
    def repository(self):
        """Provide a fresh FirestoreUserRepository for each test."""
        return FirestoreUserRepository()

    def test_init(self, repository):
        """The repository targets the users collection with UserModel."""
        assert repository.collection_name == "users"
        assert repository.model_class == UserModel

    @pytest.mark.asyncio
    async def test_get_by_email(self, repository):
        """get_by_email picks out the user with the requested address."""
        addresses = ("test1@example.com", "test2@example.com", "target@example.com")
        users = [Mock(email=address) for address in addresses]
        wanted = users[2]

        with patch.object(repository, 'get_all') as fake_get_all:
            fake_get_all.return_value = users

            found = await repository.get_by_email("target@example.com")

            assert found == wanted
            fake_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_email_not_found(self, repository):
        """get_by_email yields None when no stored user has the address."""
        addresses = ("test1@example.com", "test2@example.com")
        users = [Mock(email=address) for address in addresses]

        with patch.object(repository, 'get_all') as fake_get_all:
            fake_get_all.return_value = users

            found = await repository.get_by_email("notfound@example.com")

            assert found is None
            fake_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_team_id(self, repository):
        """get_by_team_id keeps only the members of the requested team, in order."""
        teams = ("team1", "team2", "team1")
        users = [Mock(team_id=team) for team in teams]

        with patch.object(repository, 'get_all') as fake_get_all:
            fake_get_all.return_value = users

            members = await repository.get_by_team_id("team1")

            assert len(members) == 2
            assert members[0] == users[0]
            assert members[1] == users[2]
            fake_get_all.assert_called_once()
|
|
|
|
|
|
class TestFirestoreApiKeyRepository:
    """Tests for FirestoreApiKeyRepository wiring and its API-key lookup helpers.

    ``get_all`` is mocked on the instance, so these tests check only which of
    the supplied keys each helper returns.
    """

    @pytest.fixture
    def repository(self):
        """Provide a fresh FirestoreApiKeyRepository for each test."""
        return FirestoreApiKeyRepository()

    def test_init(self, repository):
        """The repository targets the api_keys collection with ApiKeyModel."""
        assert repository.collection_name == "api_keys"
        assert repository.model_class == ApiKeyModel

    @pytest.mark.asyncio
    async def test_get_by_key_hash(self, repository):
        """get_by_key_hash returns the key whose hash matches exactly."""
        hashes = ("hash1", "hash2", "target_hash")
        stored = [Mock(key_hash=digest) for digest in hashes]
        expected = stored[2]

        with patch.object(repository, 'get_all') as stub:
            stub.return_value = stored

            found = await repository.get_by_key_hash("target_hash")

            assert found == expected
            stub.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_user_id(self, repository):
        """get_by_user_id keeps only the keys owned by the requested user, in order."""
        owners = ("user1", "user2", "user1")
        stored = [Mock(user_id=owner) for owner in owners]

        with patch.object(repository, 'get_all') as stub:
            stub.return_value = stored

            owned = await repository.get_by_user_id("user1")

            assert len(owned) == 2
            assert owned[0] == stored[0]
            assert owned[1] == stored[2]
            stub.assert_called_once()
|
|
|
|
|
|
class TestFirestoreImageRepository:
    """Tests for FirestoreImageRepository wiring and its image lookup helpers.

    ``get_all`` is mocked on the instance, so each lookup test checks only the
    subset of images the helper selects.
    """

    @pytest.fixture
    def repository(self):
        """Provide a fresh FirestoreImageRepository for each test."""
        return FirestoreImageRepository()

    def test_init(self, repository):
        """The repository targets the images collection with ImageModel."""
        assert repository.collection_name == "images"
        assert repository.model_class == ImageModel

    @pytest.mark.asyncio
    async def test_get_by_team_id(self, repository):
        """get_by_team_id keeps only the images of the requested team, in order."""
        teams = ("team1", "team2", "team1")
        images = [Mock(team_id=team) for team in teams]

        with patch.object(repository, 'get_all') as fake_get_all:
            fake_get_all.return_value = images

            selected = await repository.get_by_team_id("team1")

            assert len(selected) == 2
            assert selected[0] == images[0]
            assert selected[1] == images[2]
            fake_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_uploader_id(self, repository):
        """get_by_uploader_id keeps only the images uploaded by the requested user."""
        uploaders = ("user1", "user2", "user1")
        images = [Mock(uploader_id=uploader) for uploader in uploaders]

        with patch.object(repository, 'get_all') as fake_get_all:
            fake_get_all.return_value = images

            selected = await repository.get_by_uploader_id("user1")

            assert len(selected) == 2
            assert selected[0] == images[0]
            assert selected[1] == images[2]
            fake_get_all.assert_called_once()