image_management_api/tests/db/test_firestore_repositories.py
2025-05-25 16:20:42 +02:00

406 lines
16 KiB
Python

import pytest
from unittest.mock import Mock, patch, AsyncMock
from pydantic import BaseModel
from src.db.repositories.firestore_repositories import (
FirestoreRepository,
FirestoreTeamRepository,
FirestoreUserRepository,
FirestoreApiKeyRepository,
FirestoreImageRepository
)
from src.models.team import TeamModel
from src.models.user import UserModel
from src.models.api_key import ApiKeyModel
from src.models.image import ImageModel
class TestFirestoreRepository:
    """Exercise the generic FirestoreRepository against a mocked database client."""

    # Dotted path of the database accessor that every async test replaces
    # with the mock client for the duration of the repository call.
    _DB_ACCESSOR = 'src.db.repositories.firestore_repositories.get_firestore_db'

    @pytest.fixture
    def mock_firestore_db(self):
        """Return a Mock client whose collection() always yields one shared collection mock."""
        document = Mock()
        collection = Mock()
        collection.document.return_value = document
        collection.stream.return_value = []
        database = Mock()
        database.collection.return_value = collection
        return database

    @pytest.fixture
    def test_model_class(self):
        """Return a minimal pydantic model (``name``: str, ``value``: int) used as the payload type."""
        class TestModel(BaseModel):
            name: str
            value: int

        return TestModel

    @pytest.fixture
    def repository(self, test_model_class):
        """Return a FirestoreRepository for "test_collection" and the throwaway model."""
        return FirestoreRepository("test_collection", test_model_class)

    def test_init(self, repository, test_model_class):
        """The constructor arguments are exposed as collection_name and model_class."""
        assert repository.collection_name == "test_collection"
        assert repository.model_class == test_model_class

    @pytest.mark.asyncio
    async def test_create(self, repository, test_model_class, mock_firestore_db):
        """create() hits the configured collection, calls add() once and returns the model data."""
        collection = mock_firestore_db.collection.return_value
        new_ref = Mock()
        new_ref.id = "test_id"
        # add() is stubbed with a pair whose second element is the new document reference.
        collection.add.return_value = (None, new_ref)
        payload = test_model_class(name="test", value=42)

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            created = await repository.create(payload)

        assert created.name == "test"
        assert created.value == 42
        mock_firestore_db.collection.assert_called_once_with("test_collection")
        collection.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_id_found(self, repository, test_model_class, mock_firestore_db):
        """get_by_id() builds a model from the snapshot data when the snapshot exists."""
        snapshot = Mock()
        snapshot.exists = True
        snapshot.to_dict.return_value = {"name": "test", "value": 42}
        snapshot.id = "test_id"
        doc_ref = Mock()
        doc_ref.get.return_value = snapshot
        mock_firestore_db.collection.return_value.document.return_value = doc_ref

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            found = await repository.get_by_id("test_id")

        assert found.name == "test"
        assert found.value == 42

    @pytest.mark.asyncio
    async def test_get_by_id_not_found(self, repository, mock_firestore_db):
        """get_by_id() yields None when the snapshot reports exists == False."""
        missing = Mock()
        missing.exists = False
        doc_ref = Mock()
        doc_ref.get.return_value = missing
        mock_firestore_db.collection.return_value.document.return_value = doc_ref

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            found = await repository.get_by_id("nonexistent_id")

        assert found is None

    @pytest.mark.asyncio
    async def test_get_all(self, repository, test_model_class, mock_firestore_db):
        """get_all() returns one model per streamed snapshot, in stream order."""
        first_doc = Mock(to_dict=lambda: {"name": "test1", "value": 1}, id="id1")
        second_doc = Mock(to_dict=lambda: {"name": "test2", "value": 2}, id="id2")
        mock_firestore_db.collection.return_value.stream.return_value = [first_doc, second_doc]

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            models = await repository.get_all()

        assert len(models) == 2
        assert models[0].name == "test1"
        assert models[1].name == "test2"

    @pytest.mark.asyncio
    async def test_update_success(self, repository, test_model_class, mock_firestore_db):
        """update() forwards the change set to the document and returns the re-read model."""
        doc_ref = Mock()
        doc_ref.update.return_value = None  # stubbed update() completes without raising
        mock_firestore_db.collection.return_value.document.return_value = doc_ref
        refreshed = test_model_class(name="updated", value=99)

        db_patch = patch(self._DB_ACCESSOR, return_value=mock_firestore_db)
        # get_by_id is stubbed so the value returned after the write is under test control.
        reread_patch = patch.object(repository, 'get_by_id', return_value=refreshed)
        with db_patch, reread_patch:
            outcome = await repository.update("test_id", {"name": "updated", "value": 99})

        assert outcome.name == "updated"
        assert outcome.value == 99
        doc_ref.update.assert_called_once_with({"name": "updated", "value": 99})

    @pytest.mark.asyncio
    async def test_update_failure(self, repository, mock_firestore_db):
        """update() lets an exception raised by the document's update() reach the caller."""
        doc_ref = Mock()
        doc_ref.update.side_effect = Exception("Document not found")
        mock_firestore_db.collection.return_value.document.return_value = doc_ref

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            with pytest.raises(Exception):
                await repository.update("nonexistent_id", {"name": "updated"})

    @pytest.mark.asyncio
    async def test_delete_success(self, repository, mock_firestore_db):
        """delete() returns True and calls the document's delete() exactly once."""
        doc_ref = Mock()
        doc_ref.delete.return_value = None  # stubbed delete() completes without raising
        mock_firestore_db.collection.return_value.document.return_value = doc_ref

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            removed = await repository.delete("test_id")

        assert removed is True
        doc_ref.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_failure(self, repository, mock_firestore_db):
        """delete() returns False when the document's delete() raises."""
        doc_ref = Mock()
        doc_ref.delete.side_effect = Exception("Document not found")
        mock_firestore_db.collection.return_value.document.return_value = doc_ref

        with patch(self._DB_ACCESSOR, return_value=mock_firestore_db):
            removed = await repository.delete("nonexistent_id")

        assert removed is False
class TestFirestoreTeamRepository:
    """Test cases for FirestoreTeamRepository.

    The async tests replace a method on the repository's first base class
    with a stub, call the same-named method on the repository, and assert
    that the stub received the original arguments and that its return value
    came back unchanged.
    """

    @pytest.fixture
    def repository(self):
        """Create a FirestoreTeamRepository instance for testing."""
        return FirestoreTeamRepository()

    def test_init(self, repository):
        """The repository targets the "teams" collection and TeamModel."""
        assert repository.collection_name == "teams"
        assert repository.model_class == TeamModel

    @pytest.mark.asyncio
    async def test_get_by_id(self, repository):
        """get_by_id() returns what the base-class get_by_id produced for the same ID."""
        # ``name`` is a reserved Mock constructor argument: it only labels the
        # mock's repr and does NOT create a ``.name`` attribute. Assign the
        # attribute after construction so ``result.name`` is the real string.
        team = Mock(id="team_id")
        team.name = "Test Team"

        with patch.object(type(repository).__bases__[0], 'get_by_id') as mock_get:
            mock_get.return_value = team
            result = await repository.get_by_id("team_id")

        assert result.id == "team_id"
        assert result.name == "Test Team"
        mock_get.assert_called_once_with("team_id")

    @pytest.mark.asyncio
    async def test_update(self, repository):
        """update() returns what the base-class update produced for the same arguments."""
        # Same pitfall as in test_get_by_id: Mock(name=...) would leave
        # ``.name`` as an auto-created child Mock, so set it explicitly.
        updated_team = Mock()
        updated_team.name = "Updated Team"

        with patch.object(type(repository).__bases__[0], 'update') as mock_update:
            mock_update.return_value = updated_team
            result = await repository.update("team_id", {"name": "Updated Team"})

        assert result.name == "Updated Team"
        mock_update.assert_called_once_with("team_id", {"name": "Updated Team"})

    @pytest.mark.asyncio
    async def test_delete(self, repository):
        """delete() returns what the base-class delete produced for the same ID."""
        with patch.object(type(repository).__bases__[0], 'delete') as mock_delete:
            mock_delete.return_value = True
            result = await repository.delete("team_id")

        assert result is True
        mock_delete.assert_called_once_with("team_id")
class TestFirestoreUserRepository:
    """Checks for FirestoreUserRepository defaults and its email / team lookups.

    The lookup tests stub the repository's own get_all() and verify the
    value the lookup returns for the stubbed records.
    """

    @pytest.fixture
    def repository(self):
        """Provide a fresh FirestoreUserRepository."""
        return FirestoreUserRepository()

    def test_init(self, repository):
        """The repository targets the "users" collection and UserModel."""
        assert repository.collection_name == "users"
        assert repository.model_class == UserModel

    @pytest.mark.asyncio
    async def test_get_by_email(self, repository):
        """get_by_email() returns the record whose email matches the argument."""
        decoy_a = Mock(email="test1@example.com")
        decoy_b = Mock(email="test2@example.com")
        wanted = Mock(email="target@example.com")

        with patch.object(repository, 'get_all', return_value=[decoy_a, decoy_b, wanted]) as mock_get_all:
            found = await repository.get_by_email("target@example.com")

        assert found == wanted
        mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_email_not_found(self, repository):
        """get_by_email() returns None when no record carries the requested email."""
        candidates = [Mock(email=address) for address in ("test1@example.com", "test2@example.com")]

        with patch.object(repository, 'get_all', return_value=candidates) as mock_get_all:
            found = await repository.get_by_email("notfound@example.com")

        assert found is None
        mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_team_id(self, repository):
        """get_by_team_id() keeps only the matching team's records, in their original order."""
        member_one = Mock(team_id="team1")
        outsider = Mock(team_id="team2")
        member_two = Mock(team_id="team1")

        with patch.object(repository, 'get_all', return_value=[member_one, outsider, member_two]) as mock_get_all:
            members = await repository.get_by_team_id("team1")

        assert len(members) == 2
        assert members[0] == member_one
        assert members[1] == member_two
        mock_get_all.assert_called_once()
class TestFirestoreApiKeyRepository:
    """Checks for FirestoreApiKeyRepository defaults and its hash / user lookups.

    The lookup tests stub the repository's own get_all() and verify the
    value the lookup returns for the stubbed records.
    """

    @pytest.fixture
    def repository(self):
        """Provide a fresh FirestoreApiKeyRepository."""
        return FirestoreApiKeyRepository()

    def test_init(self, repository):
        """The repository targets the "api_keys" collection and ApiKeyModel."""
        assert repository.collection_name == "api_keys"
        assert repository.model_class == ApiKeyModel

    @pytest.mark.asyncio
    async def test_get_by_key_hash(self, repository):
        """get_by_key_hash() returns the record whose key_hash matches the argument."""
        decoy_a = Mock(key_hash="hash1")
        decoy_b = Mock(key_hash="hash2")
        wanted = Mock(key_hash="target_hash")

        with patch.object(repository, 'get_all', return_value=[decoy_a, decoy_b, wanted]) as mock_get_all:
            found = await repository.get_by_key_hash("target_hash")

        assert found == wanted
        mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_user_id(self, repository):
        """get_by_user_id() keeps only the matching user's keys, in their original order."""
        owned_first = Mock(user_id="user1")
        foreign = Mock(user_id="user2")
        owned_second = Mock(user_id="user1")

        with patch.object(repository, 'get_all', return_value=[owned_first, foreign, owned_second]) as mock_get_all:
            keys = await repository.get_by_user_id("user1")

        assert len(keys) == 2
        assert keys[0] == owned_first
        assert keys[1] == owned_second
        mock_get_all.assert_called_once()
class TestFirestoreImageRepository:
    """Checks for FirestoreImageRepository defaults and its team / uploader lookups.

    The lookup tests stub the repository's own get_all() and verify the
    value the lookup returns for the stubbed records.
    """

    @pytest.fixture
    def repository(self):
        """Provide a fresh FirestoreImageRepository."""
        return FirestoreImageRepository()

    def test_init(self, repository):
        """The repository targets the "images" collection and ImageModel."""
        assert repository.collection_name == "images"
        assert repository.model_class == ImageModel

    @pytest.mark.asyncio
    async def test_get_by_team_id(self, repository):
        """get_by_team_id() keeps only the matching team's images, in their original order."""
        team_image_one = Mock(team_id="team1")
        other_team_image = Mock(team_id="team2")
        team_image_two = Mock(team_id="team1")
        stored = [team_image_one, other_team_image, team_image_two]

        with patch.object(repository, 'get_all', return_value=stored) as mock_get_all:
            images = await repository.get_by_team_id("team1")

        assert len(images) == 2
        assert images[0] == team_image_one
        assert images[1] == team_image_two
        mock_get_all.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_by_uploader_id(self, repository):
        """get_by_uploader_id() keeps only the matching uploader's images, in their original order."""
        uploaded_one = Mock(uploader_id="user1")
        someone_elses = Mock(uploader_id="user2")
        uploaded_two = Mock(uploader_id="user1")
        stored = [uploaded_one, someone_elses, uploaded_two]

        with patch.object(repository, 'get_all', return_value=stored) as mock_get_all:
            images = await repository.get_by_uploader_id("user1")

        assert len(images) == 2
        assert images[0] == uploaded_one
        assert images[1] == uploaded_two
        mock_get_all.assert_called_once()