"""Pytest fixtures and in-memory repository mocks for the API tests."""

import asyncio
from datetime import datetime, timedelta
from typing import Any, Dict, Generator, List, Optional

import pytest
from bson import ObjectId
from fastapi import FastAPI
from fastapi.testclient import TestClient

from src.models.team import TeamModel
from src.models.user import UserModel
from src.models.api_key import ApiKeyModel
from src.models.image import ImageModel
from src.auth.security import generate_api_key
from src.db.repositories.team_repository import team_repository
from src.db.repositories.user_repository import user_repository
from src.db.repositories.api_key_repository import api_key_repository

# The image repository may not exist yet; record whether the import worked so
# the fixtures below can skip it instead of failing at collection time.
try:
    from src.db.repositories.image_repository import image_repository
    image_repository_exists = True
except ImportError:
    image_repository = None
    image_repository_exists = False


# Mock repositories
class MockTeamRepository:
    """In-memory stand-in for the team repository, keyed by ``str(team.id)``."""

    def __init__(self):
        self.teams = {}

    async def create(self, team: TeamModel) -> TeamModel:
        """Store ``team``, assigning a fresh ``ObjectId`` when it has no id."""
        if not team.id:
            team.id = ObjectId()
        self.teams[str(team.id)] = team
        return team

    async def get_by_id(self, id: ObjectId) -> Optional[TeamModel]:
        """Return the team stored under ``id``, or ``None`` when absent."""
        return self.teams.get(str(id))

    async def get_all(self) -> List[TeamModel]:
        """Return every stored team as a list."""
        return list(self.teams.values())

    async def update(self, id: ObjectId, data: Dict[str, Any]) -> Optional[TeamModel]:
        """Apply ``data`` as attribute assignments and refresh ``updated_at``.

        Returns the updated team, or ``None`` when ``id`` is unknown.
        """
        team = self.teams.get(str(id))
        if not team:
            return None
        for key, value in data.items():
            setattr(team, key, value)
        team.updated_at = datetime.utcnow()
        self.teams[str(id)] = team
        return team

    async def delete(self, id: ObjectId) -> bool:
        """Remove the team stored under ``id``; return whether it existed."""
        if str(id) in self.teams:
            del self.teams[str(id)]
            return True
        return False


class MockUserRepository:
    """In-memory stand-in for the user repository, keyed by ``str(user.id)``."""

    def __init__(self):
        self.users = {}

    async def create(self, user: UserModel) -> UserModel:
        """Store ``user``, assigning a fresh ``ObjectId`` when it has no id."""
        if not user.id:
            user.id = ObjectId()
        self.users[str(user.id)] = user
        return user

    async def get_by_id(self, id: ObjectId) -> Optional[UserModel]:
        """Return the user stored under ``id``, or ``None`` when absent."""
        return self.users.get(str(id))

    async def get_by_email(self, email: str) -> Optional[UserModel]:
        """Return the first stored user whose email equals ``email``, else ``None``."""
        for user in self.users.values():
            if user.email == email:
                return user
        return None

    async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
        """Return the users whose ``team_id`` matches (compared as strings)."""
        return [u for u in self.users.values() if str(u.team_id) == str(team_id)]


class MockApiKeyRepository:
    """In-memory stand-in for the API key repository, keyed by ``str(api_key.id)``."""

    def __init__(self):
        self.api_keys = {}

    async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
        """Store ``api_key``, assigning a fresh ``ObjectId`` when it has no id."""
        if not api_key.id:
            api_key.id = ObjectId()
        self.api_keys[str(api_key.id)] = api_key
        return api_key

    async def get_by_id(self, id: ObjectId) -> Optional[ApiKeyModel]:
        """Return the key stored under ``id``, or ``None`` when absent."""
        return self.api_keys.get(str(id))

    async def get_by_hash(self, key_hash: str) -> Optional[ApiKeyModel]:
        """Return the first stored key whose ``key_hash`` equals ``key_hash``."""
        for key in self.api_keys.values():
            if key.key_hash == key_hash:
                return key
        return None

    async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]:
        """Return the keys whose ``user_id`` matches (compared as strings)."""
        return [k for k in self.api_keys.values() if str(k.user_id) == str(user_id)]

    async def update_last_used(self, id: ObjectId) -> bool:
        """Set ``last_used`` to the current UTC time; ``False`` if ``id`` is unknown."""
        key = self.api_keys.get(str(id))
        if not key:
            return False
        key.last_used = datetime.utcnow()
        self.api_keys[str(id)] = key
        return True

    async def deactivate(self, id: ObjectId) -> bool:
        """Set ``is_active`` to ``False``; return ``False`` if ``id`` is unknown."""
        key = self.api_keys.get(str(id))
        if not key:
            return False
        key.is_active = False
        self.api_keys[str(id)] = key
        return True


class MockImageRepository:
    """In-memory stand-in for the image repository, keyed by ``str(image.id)``."""

    def __init__(self):
        self.images = {}

    async def create(self, image: ImageModel) -> ImageModel:
        """Store ``image``, assigning a fresh ``ObjectId`` when it has no id."""
        if not image.id:
            image.id = ObjectId()
        self.images[str(image.id)] = image
        return image

    async def get_by_id(self, id: ObjectId) -> Optional[ImageModel]:
        """Return the image stored under ``id``, or ``None`` when absent."""
        return self.images.get(str(id))

    async def get_by_team(self, team_id: ObjectId) -> List[ImageModel]:
        """Return the images whose ``team_id`` matches (compared as strings)."""
        return [img for img in self.images.values() if str(img.team_id) == str(team_id)]

    async def update(self, id: ObjectId, data: Dict[str, Any]) -> Optional[ImageModel]:
        """Apply ``data`` as attribute assignments.

        Returns the updated image, or ``None`` when ``id`` is unknown.
        """
        image = self.images.get(str(id))
        if not image:
            return None
        for key, value in data.items():
            setattr(image, key, value)
        self.images[str(id)] = image
        return image

    async def delete(self, id: ObjectId) -> bool:
        """Remove the image stored under ``id``; return whether it existed."""
        if str(id) in self.images:
            del self.images[str(id)]
            return True
        return False

    async def search(self, team_id: ObjectId, query: Optional[str] = None) -> List[ImageModel]:
        """Return a team's images, optionally filtered by a substring.

        When ``query`` is non-empty it is matched case-insensitively against
        each image's description (if set), filename and original filename.
        """
        results = [img for img in self.images.values() if str(img.team_id) == str(team_id)]
        if query:
            needle = query.lower()
            results = [
                img
                for img in results
                if (img.description and needle in img.description.lower())
                or needle in img.filename.lower()
                or needle in img.original_filename.lower()
            ]
        return results


def _install_mock(repository: Any, mock_cls: type) -> None:
    """Turn ``repository`` into an initialised instance of ``mock_cls`` in place.

    Reassigning ``__class__`` only changes method lookup; it does not run the
    mock's ``__init__``, so the mock's backing dict would not exist. Calling
    ``__init__`` here creates that storage and clears anything left over from
    a previous installation.
    """
    repository.__class__ = mock_cls
    mock_cls.__init__(repository)


@pytest.fixture(scope="module")
def event_loop():
    """Provide one dedicated event loop per test module and close it afterwards."""
    # Create a private loop rather than calling asyncio.get_event_loop(), which
    # is deprecated when no loop is running and may hand back a shared loop
    # that this fixture would then close.
    loop = asyncio.new_event_loop()
    yield loop
    loop.close()


@pytest.fixture(scope="module")
def app() -> FastAPI:
    """Return the FastAPI app with the repository singletons replaced by mocks."""
    from main import app

    # Replace repositories with mocks
    _install_mock(team_repository, MockTeamRepository)
    _install_mock(user_repository, MockUserRepository)
    _install_mock(api_key_repository, MockApiKeyRepository)

    # Try to replace image_repository if it exists
    if image_repository_exists:
        _install_mock(image_repository, MockImageRepository)

    return app


@pytest.fixture(scope="module")
def client(app: FastAPI) -> Generator:
    """Yield a ``TestClient`` bound to ``app`` for the duration of the module."""
    with TestClient(app) as c:
        yield c


# NOTE(review): the fixtures below are coroutines declared with plain
# ``pytest.fixture``; this presumably relies on the async test plugin's
# configuration (e.g. pytest-asyncio auto mode) - confirm.
@pytest.fixture(scope="function")
async def test_team() -> TeamModel:
    """Create and return a team named "Test Team" through ``team_repository``."""
    team = TeamModel(
        name="Test Team",
        description="A team for testing"
    )
    created_team = await team_repository.create(team)
    return created_team


@pytest.fixture(scope="function")
async def admin_user(test_team: TeamModel) -> UserModel:
    """Create and return an admin user belonging to ``test_team``."""
    user = UserModel(
        email="admin@example.com",
        name="Admin User",
        team_id=test_team.id,
        is_admin=True
    )
    created_user = await user_repository.create(user)
    return created_user


@pytest.fixture(scope="function")
async def regular_user(test_team: TeamModel) -> UserModel:
    """Create and return a non-admin user belonging to ``test_team``."""
    user = UserModel(
        email="user@example.com",
        name="Regular User",
        team_id=test_team.id,
        is_admin=False
    )
    created_user = await user_repository.create(user)
    return created_user


@pytest.fixture(scope="function")
async def admin_api_key(admin_user: UserModel) -> tuple:
    """Create an active API key for ``admin_user`` expiring in 30 days.

    Returns ``(raw_key, created_key)``: the raw key from ``generate_api_key``
    and the stored ``ApiKeyModel`` holding its hash.
    """
    raw_key, hashed_key = generate_api_key(str(admin_user.team_id), str(admin_user.id))
    api_key = ApiKeyModel(
        key_hash=hashed_key,
        user_id=admin_user.id,
        team_id=admin_user.team_id,
        name="Admin API Key",
        description="API key for admin testing",
        expiry_date=datetime.utcnow() + timedelta(days=30),
        is_active=True
    )
    created_key = await api_key_repository.create(api_key)
    return raw_key, created_key


@pytest.fixture(scope="function")
async def user_api_key(regular_user: UserModel) -> tuple:
    """Create an active API key for ``regular_user`` expiring in 30 days.

    Returns ``(raw_key, created_key)``: the raw key from ``generate_api_key``
    and the stored ``ApiKeyModel`` holding its hash.
    """
    raw_key, hashed_key = generate_api_key(str(regular_user.team_id), str(regular_user.id))
    api_key = ApiKeyModel(
        key_hash=hashed_key,
        user_id=regular_user.id,
        team_id=regular_user.team_id,
        name="User API Key",
        description="API key for user testing",
        expiry_date=datetime.utcnow() + timedelta(days=30),
        is_active=True
    )
    created_key = await api_key_repository.create(api_key)
    return raw_key, created_key