2025-05-24 07:34:07 +02:00

266 lines
8.4 KiB
Python

import asyncio
import pytest
from datetime import datetime, timedelta
from typing import Dict, Any, Generator, List
from bson import ObjectId
from fastapi import FastAPI
from fastapi.testclient import TestClient
from src.db.models.team import TeamModel
from src.db.models.user import UserModel
from src.db.models.api_key import ApiKeyModel
from src.db.models.image import ImageModel
from src.core.security import generate_api_key
from src.db.repositories.team_repository import team_repository
from src.db.repositories.user_repository import user_repository
from src.db.repositories.api_key_repository import api_key_repository
# The image repository module may not exist yet in this codebase. Record
# whether it could be imported so the ``app`` fixture only patches it when it
# is actually available.
try:
    from src.db.repositories.image_repository import image_repository
except ImportError:
    image_repository_exists = False
else:
    image_repository_exists = True
# Mock repositories
class MockTeamRepository:
    """In-memory stand-in for the team repository.

    Teams are kept in a dict keyed by the string form of their id.
    """

    def __init__(self):
        # str(team.id) -> TeamModel
        self.teams = {}

    async def create(self, team: TeamModel) -> TeamModel:
        """Store *team* (assigning a new ObjectId if its id is falsy) and return it."""
        if not team.id:
            team.id = ObjectId()
        storage_key = str(team.id)
        self.teams[storage_key] = team
        return team

    async def get_by_id(self, id: ObjectId) -> TeamModel:
        """Return the team stored under *id*, or ``None`` when it is unknown."""
        storage_key = str(id)
        return self.teams.get(storage_key)

    async def get_all(self) -> List[TeamModel]:
        """Return every stored team, in insertion order."""
        return [*self.teams.values()]

    async def update(self, id: ObjectId, data: Dict[str, Any]) -> TeamModel:
        """Apply *data* as attribute assignments to the stored team.

        ``updated_at`` is refreshed with ``datetime.utcnow()``. Returns the
        updated team, or ``None`` when no team is stored under *id*.
        """
        storage_key = str(id)
        existing = self.teams.get(storage_key)
        if not existing:
            return None
        for field_name, new_value in data.items():
            setattr(existing, field_name, new_value)
        existing.updated_at = datetime.utcnow()
        self.teams[storage_key] = existing
        return existing

    async def delete(self, id: ObjectId) -> bool:
        """Remove the team stored under *id*; return whether one was removed."""
        storage_key = str(id)
        if storage_key not in self.teams:
            return False
        del self.teams[storage_key]
        return True
class MockUserRepository:
    """In-memory stand-in for the user repository, keyed by ``str(user.id)``."""

    def __init__(self):
        # str(user.id) -> UserModel
        self.users = {}

    async def create(self, user: UserModel) -> UserModel:
        """Store *user* (assigning a new ObjectId if its id is falsy) and return it."""
        if not user.id:
            user.id = ObjectId()
        self.users[str(user.id)] = user
        return user

    async def get_by_id(self, id: ObjectId) -> UserModel:
        """Return the user stored under *id*, or ``None`` when it is unknown."""
        return self.users.get(str(id))

    async def get_by_email(self, email: str) -> UserModel:
        """Return the first stored user whose ``email`` equals *email*, else ``None``."""
        return next(
            (candidate for candidate in self.users.values() if candidate.email == email),
            None,
        )

    async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
        """Return all users whose ``team_id`` matches *team_id* (compared as strings)."""
        wanted_team = str(team_id)
        return [member for member in self.users.values() if str(member.team_id) == wanted_team]
class MockApiKeyRepository:
    """In-memory stand-in for the API-key repository, keyed by ``str(api_key.id)``."""

    def __init__(self):
        # str(api_key.id) -> ApiKeyModel
        self.api_keys = {}

    async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
        """Store *api_key* (assigning a new ObjectId if its id is falsy) and return it."""
        if not api_key.id:
            api_key.id = ObjectId()
        self.api_keys[str(api_key.id)] = api_key
        return api_key

    async def get_by_id(self, id: ObjectId) -> ApiKeyModel:
        """Return the key stored under *id*, or ``None`` when it is unknown."""
        return self.api_keys.get(str(id))

    async def get_by_hash(self, key_hash: str) -> ApiKeyModel:
        """Return the first stored key whose ``key_hash`` equals *key_hash*, else ``None``."""
        return next(
            (stored for stored in self.api_keys.values() if stored.key_hash == key_hash),
            None,
        )

    async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]:
        """Return all keys whose ``user_id`` matches *user_id* (compared as strings)."""
        owner = str(user_id)
        return [stored for stored in self.api_keys.values() if str(stored.user_id) == owner]

    async def update_last_used(self, id: ObjectId) -> bool:
        """Set ``last_used`` to ``datetime.utcnow()``; ``False`` if *id* is unknown."""
        lookup = str(id)
        record = self.api_keys.get(lookup)
        if not record:
            return False
        record.last_used = datetime.utcnow()
        self.api_keys[lookup] = record
        return True

    async def deactivate(self, id: ObjectId) -> bool:
        """Set ``is_active`` to ``False``; return ``False`` if *id* is unknown."""
        lookup = str(id)
        record = self.api_keys.get(lookup)
        if not record:
            return False
        record.is_active = False
        self.api_keys[lookup] = record
        return True
class MockImageRepository:
    """In-memory stand-in for the image repository, keyed by ``str(image.id)``."""

    def __init__(self):
        # str(image.id) -> ImageModel
        self.images = {}

    async def create(self, image: ImageModel) -> ImageModel:
        """Store *image* (assigning a new ObjectId if its id is falsy) and return it."""
        if not image.id:
            image.id = ObjectId()
        self.images[str(image.id)] = image
        return image

    async def get_by_id(self, id: ObjectId) -> ImageModel:
        """Return the image stored under *id*, or ``None`` when it is unknown."""
        return self.images.get(str(id))

    async def get_by_team(self, team_id: ObjectId) -> List[ImageModel]:
        """Return all images whose ``team_id`` matches *team_id* (compared as strings)."""
        wanted_team = str(team_id)
        return [pic for pic in self.images.values() if str(pic.team_id) == wanted_team]

    async def update(self, id: ObjectId, data: Dict[str, Any]) -> ImageModel:
        """Apply *data* as attribute assignments; ``None`` if *id* is unknown."""
        lookup = str(id)
        pic = self.images.get(lookup)
        if not pic:
            return None
        for attr_name, attr_value in data.items():
            setattr(pic, attr_name, attr_value)
        self.images[lookup] = pic
        return pic

    async def delete(self, id: ObjectId) -> bool:
        """Remove the image stored under *id*; return whether one was removed."""
        lookup = str(id)
        if lookup not in self.images:
            return False
        del self.images[lookup]
        return True

    async def search(self, team_id: ObjectId, query: str = None, tags: List[str] = None) -> List[ImageModel]:
        """Return the team's images filtered by a text query and/or required tags.

        Filters are applied in stages: team, then *query*, then *tags*.
        *query* matches case-insensitively as a substring of ``description``
        (when set), ``filename`` or ``original_filename``. Every tag in *tags*
        must appear in the image's ``tags``. An empty or missing *query*/*tags*
        skips that stage.
        """

        def text_matches(pic, needle):
            # Description is optional, so only search it when it is truthy.
            if pic.description and needle in pic.description.lower():
                return True
            return needle in pic.filename.lower() or needle in pic.original_filename.lower()

        wanted_team = str(team_id)
        found = [pic for pic in self.images.values() if str(pic.team_id) == wanted_team]
        if query:
            needle = query.lower()
            found = [pic for pic in found if text_matches(pic, needle)]
        if tags:
            found = [pic for pic in found if all(tag in pic.tags for tag in tags)]
        return found
@pytest.fixture(scope="module")
def event_loop() -> Generator[asyncio.AbstractEventLoop, None, None]:
    """Provide a dedicated asyncio event loop shared by the tests of one module.

    A brand-new loop is created rather than using ``asyncio.get_event_loop()``.
    That call is deprecated when no loop is running. It also installs the loop
    it creates as the current loop, so closing it here would leave a closed
    loop behind for the next module ("Event loop is closed").
    """
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Always release the loop's resources, even if teardown is interrupted.
        loop.close()
@pytest.fixture(scope="module")
def app() -> FastAPI:
    """Return the FastAPI app with its repositories swapped for in-memory mocks.

    The shared repository instances are patched in place by reassigning their
    ``__class__``, so any code holding a reference to the same instance now
    runs the mock methods. Reassigning ``__class__`` does NOT run the mock's
    ``__init__``. The mocks' storage dicts (``teams``, ``users``,
    ``api_keys``, ``images``) would therefore be missing, and the first
    repository call would raise ``AttributeError``. Each mock is initialised
    explicitly after the swap, which also gives every module fresh, empty
    storage.
    """
    from main import app

    def _install_mock(instance, mock_cls) -> None:
        # Swap behaviour first, then create the storage the mock methods read.
        instance.__class__ = mock_cls
        mock_cls.__init__(instance)

    _install_mock(team_repository, MockTeamRepository)
    _install_mock(user_repository, MockUserRepository)
    _install_mock(api_key_repository, MockApiKeyRepository)
    # The image repository module is optional; patch it only when it imported.
    if image_repository_exists:
        from src.db.repositories.image_repository import image_repository
        _install_mock(image_repository, MockImageRepository)
    return app
@pytest.fixture(scope="module")
def client(app: FastAPI) -> Generator:
    """Yield a ``TestClient`` for the mocked *app*, shared across the module.

    The client is entered as a context manager for the module's duration and
    exited once its tests have finished.
    """
    with TestClient(app) as test_client:
        yield test_client
@pytest.fixture(scope="function")
async def test_team() -> TeamModel:
    """Create a fresh "Test Team" through ``team_repository`` and return it.

    NOTE(review): this fixture does not request ``app``, so ``team_repository``
    is only the in-memory mock if the ``app`` fixture has already run in this
    module. Confirm the intended ordering. As an ``async def`` under a plain
    ``@pytest.fixture``, it presumably relies on pytest-asyncio auto mode.
    """
    new_team = TeamModel(name="Test Team", description="A team for testing")
    return await team_repository.create(new_team)
@pytest.fixture(scope="function")
async def admin_user(test_team: TeamModel) -> UserModel:
    """Create and return an admin user (``is_admin=True``) in ``test_team``."""
    admin = UserModel(
        email="admin@example.com",
        name="Admin User",
        team_id=test_team.id,
        is_admin=True,
    )
    return await user_repository.create(admin)
@pytest.fixture(scope="function")
async def regular_user(test_team: TeamModel) -> UserModel:
    """Create and return a non-admin user (``is_admin=False``) in ``test_team``."""
    member = UserModel(
        email="user@example.com",
        name="Regular User",
        team_id=test_team.id,
        is_admin=False,
    )
    return await user_repository.create(member)
@pytest.fixture(scope="function")
async def admin_api_key(admin_user: UserModel) -> tuple:
    """Create an active API key for ``admin_user`` that expires in 30 days.

    Returns:
        ``(raw_key, created_key)``: the unhashed key from ``generate_api_key``
        (presumably what a client would send) and the stored ``ApiKeyModel``.
    """
    owner_team = str(admin_user.team_id)
    owner = str(admin_user.id)
    raw_key, hashed_key = generate_api_key(owner_team, owner)
    created_key = await api_key_repository.create(
        ApiKeyModel(
            key_hash=hashed_key,
            user_id=admin_user.id,
            team_id=admin_user.team_id,
            name="Admin API Key",
            description="API key for admin testing",
            expiry_date=datetime.utcnow() + timedelta(days=30),
            is_active=True,
        )
    )
    return raw_key, created_key
@pytest.fixture(scope="function")
async def user_api_key(regular_user: UserModel) -> tuple:
    """Create an active API key for ``regular_user`` that expires in 30 days.

    Returns:
        ``(raw_key, created_key)``: the unhashed key from ``generate_api_key``
        (presumably what a client would send) and the stored ``ApiKeyModel``.
    """
    owner_team = str(regular_user.team_id)
    owner = str(regular_user.id)
    raw_key, hashed_key = generate_api_key(owner_team, owner)
    created_key = await api_key_repository.create(
        ApiKeyModel(
            key_hash=hashed_key,
            user_id=regular_user.id,
            team_id=regular_user.team_id,
            name="User API Key",
            description="API key for user testing",
            expiry_date=datetime.utcnow() + timedelta(days=30),
            is_active=True,
        )
    )
    return raw_key, created_key