253 lines
7.8 KiB
Python
253 lines
7.8 KiB
Python
import asyncio
|
|
import pytest
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, Any, Generator, List
|
|
from bson import ObjectId
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
from unittest.mock import patch
|
|
|
|
from src.models.team import TeamModel
|
|
from src.models.user import UserModel
|
|
from src.models.api_key import ApiKeyModel
|
|
from src.models.image import ImageModel
|
|
from src.auth.security import generate_api_key
|
|
from src.db.repositories.team_repository import team_repository
|
|
from src.db.repositories.user_repository import user_repository
|
|
from src.db.repositories.api_key_repository import api_key_repository
|
|
|
|
# The image repository module may not exist yet. Record whether the import
# succeeded so tests can check image_repository_exists before relying on it.
try:
    from src.db.repositories.image_repository import image_repository
except ImportError:
    image_repository_exists = False
else:
    image_repository_exists = True
|
|
|
|
|
|
# Mock repositories
|
|
class MockTeamRepository:
    """In-memory stand-in for the team repository.

    Teams live in a plain dict keyed by the string form of their id.
    """

    def __init__(self):
        # str(team.id) -> stored team object
        self.teams = {}

    async def create(self, team: TeamModel) -> TeamModel:
        """Store ``team`` and return it; a falsy id is replaced by a new ObjectId."""
        if not team.id:
            team.id = ObjectId()
        storage_key = str(team.id)
        self.teams[storage_key] = team
        return team

    async def get_by_id(self, id: ObjectId) -> TeamModel:
        """Return the team stored under ``id``, or None when there is none."""
        lookup = str(id)
        return self.teams.get(lookup)

    async def get_all(self) -> List[TeamModel]:
        """Return every stored team as a new list."""
        return [*self.teams.values()]

    async def update(self, id: ObjectId, data: Dict[str, Any]) -> TeamModel:
        """Apply ``data`` as attribute assignments to the stored team.

        Returns the updated team (with ``updated_at`` refreshed), or None
        when no team is stored under ``id``.
        """
        lookup = str(id)
        stored = self.teams.get(lookup)
        if stored:
            for field, new_value in data.items():
                setattr(stored, field, new_value)
            stored.updated_at = datetime.utcnow()
            self.teams[lookup] = stored
            return stored
        return None

    async def delete(self, id: ObjectId) -> bool:
        """Remove the team stored under ``id``; report whether one was removed."""
        try:
            del self.teams[str(id)]
        except KeyError:
            return False
        return True
|
|
|
|
|
|
class MockUserRepository:
    """In-memory stand-in for the user repository.

    Users live in a plain dict keyed by the string form of their id.
    """

    def __init__(self):
        # str(user.id) -> stored user object
        self.users = {}

    async def create(self, user: UserModel) -> UserModel:
        """Store ``user`` and return it; a falsy id is replaced by a new ObjectId."""
        if not user.id:
            user.id = ObjectId()
        storage_key = str(user.id)
        self.users[storage_key] = user
        return user

    async def get_by_id(self, id: ObjectId) -> UserModel:
        """Return the user stored under ``id``, or None when there is none."""
        lookup = str(id)
        return self.users.get(lookup)

    async def get_by_email(self, email: str) -> UserModel:
        """Return the first stored user whose email equals ``email``, else None."""
        return next(
            (candidate for candidate in self.users.values() if candidate.email == email),
            None,
        )

    async def get_by_team(self, team_id: ObjectId) -> List[UserModel]:
        """Return all stored users whose team_id matches ``team_id`` as strings."""
        wanted = str(team_id)
        members = []
        for candidate in self.users.values():
            if str(candidate.team_id) == wanted:
                members.append(candidate)
        return members
|
|
|
|
|
|
class MockApiKeyRepository:
    """In-memory stand-in for the API key repository.

    Keys live in a plain dict keyed by the string form of their id.
    """

    def __init__(self):
        # str(api_key.id) -> stored API key object
        self.api_keys = {}

    async def create(self, api_key: ApiKeyModel) -> ApiKeyModel:
        """Store ``api_key`` and return it; a falsy id is replaced by a new ObjectId."""
        if not api_key.id:
            api_key.id = ObjectId()
        storage_key = str(api_key.id)
        self.api_keys[storage_key] = api_key
        return api_key

    async def get_by_id(self, id: ObjectId) -> ApiKeyModel:
        """Return the key stored under ``id``, or None when there is none."""
        lookup = str(id)
        return self.api_keys.get(lookup)

    async def get_by_hash(self, key_hash: str) -> ApiKeyModel:
        """Return the first stored key whose key_hash equals ``key_hash``, else None."""
        return next(
            (record for record in self.api_keys.values() if record.key_hash == key_hash),
            None,
        )

    async def get_by_user(self, user_id: ObjectId) -> List[ApiKeyModel]:
        """Return all stored keys whose user_id matches ``user_id`` as strings."""
        owner = str(user_id)
        owned = []
        for record in self.api_keys.values():
            if str(record.user_id) == owner:
                owned.append(record)
        return owned

    async def update_last_used(self, id: ObjectId) -> bool:
        """Stamp ``last_used`` with the current UTC time; False if the key is unknown."""
        lookup = str(id)
        record = self.api_keys.get(lookup)
        if record:
            record.last_used = datetime.utcnow()
            self.api_keys[lookup] = record
            return True
        return False

    async def deactivate(self, id: ObjectId) -> bool:
        """Set ``is_active`` to False on the stored key; False if the key is unknown."""
        lookup = str(id)
        record = self.api_keys.get(lookup)
        if record:
            record.is_active = False
            self.api_keys[lookup] = record
            return True
        return False
|
|
|
|
|
|
class MockImageRepository:
    """In-memory stand-in for the image repository.

    Images live in a plain dict keyed by the string form of their id.
    """

    def __init__(self):
        # str(image.id) -> stored image object
        self.images = {}

    async def create(self, image: ImageModel) -> ImageModel:
        """Store ``image`` and return it; a falsy id is replaced by a new ObjectId."""
        if not image.id:
            image.id = ObjectId()
        storage_key = str(image.id)
        self.images[storage_key] = image
        return image

    async def get_by_id(self, id: ObjectId) -> ImageModel:
        """Return the image stored under ``id``, or None when there is none."""
        lookup = str(id)
        return self.images.get(lookup)

    async def get_by_team(self, team_id: ObjectId) -> List[ImageModel]:
        """Return all stored images whose team_id matches ``team_id`` as strings."""
        wanted = str(team_id)
        owned = []
        for candidate in self.images.values():
            if str(candidate.team_id) == wanted:
                owned.append(candidate)
        return owned

    async def update(self, id: ObjectId, data: Dict[str, Any]) -> ImageModel:
        """Apply ``data`` as attribute assignments to the stored image.

        Returns the updated image, or None when no image is stored under ``id``.
        """
        lookup = str(id)
        stored = self.images.get(lookup)
        if stored:
            for field, new_value in data.items():
                setattr(stored, field, new_value)
            self.images[lookup] = stored
            return stored
        return None

    async def delete(self, id: ObjectId) -> bool:
        """Remove the image stored under ``id``; report whether one was removed."""
        try:
            del self.images[str(id)]
        except KeyError:
            return False
        return True

    async def search(self, team_id: ObjectId, query: str = None) -> List[ImageModel]:
        """Return the team's images, optionally filtered by a text query.

        With a non-empty ``query``, an image is kept when the lower-cased
        query is a substring of its description (if it has one), its
        filename, or its original filename, compared case-insensitively.
        """
        wanted = str(team_id)
        team_images = []
        for candidate in self.images.values():
            if str(candidate.team_id) == wanted:
                team_images.append(candidate)

        if not query:
            return team_images

        needle = query.lower()

        def matches(img) -> bool:
            # Checks run in this order and stop at the first hit, so later
            # fields are not touched once an earlier one matches.
            if img.description and needle in img.description.lower():
                return True
            if needle in img.filename.lower():
                return True
            return needle in img.original_filename.lower()

        return [img for img in team_images if matches(img)]
|
|
|
|
|
|
@pytest.fixture(scope="module")
def event_loop():
    """Provide one asyncio event loop per test module and close it afterwards.

    A dedicated loop is created with ``asyncio.new_event_loop()`` rather than
    fetched with ``asyncio.get_event_loop()``. The latter hands back the
    shared default loop; closing it at module teardown leaves that closed
    loop installed, so the next module requesting this fixture would get an
    unusable loop. Calling ``get_event_loop()`` with no running loop is also
    deprecated in recent Python versions.

    Yields:
        The newly created event loop.
    """
    loop = asyncio.new_event_loop()
    try:
        yield loop
    finally:
        # Always release the loop's resources, even if teardown is reached
        # through an error.
        loop.close()
|
|
|
|
|
|
@pytest.fixture(scope="module")
def app() -> FastAPI:
    """Return the ``app`` object exported by the ``main`` module.

    The import happens inside the fixture, so ``main`` is only loaded when a
    test actually requests the application.
    """
    from main import app as application

    return application
|
|
|
|
|
|
@pytest.fixture(scope="module")
def client(app: FastAPI) -> Generator:
    """Yield a ``TestClient`` bound to ``app``.

    The client is used as a context manager, so it is exited once the
    module's tests are done with it.
    """
    with TestClient(app) as test_client:
        yield test_client
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def test_team() -> TeamModel:
    """Create a team named "Test Team" through ``team_repository`` and return the result."""
    return await team_repository.create(
        TeamModel(name="Test Team", description="A team for testing")
    )
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def admin_user(test_team: TeamModel) -> UserModel:
    """Create an admin user in ``test_team`` through ``user_repository`` and return the result."""
    admin = UserModel(
        email="admin@example.com",
        name="Admin User",
        team_id=test_team.id,
        is_admin=True,
    )
    return await user_repository.create(admin)
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def regular_user(test_team: TeamModel) -> UserModel:
    """Create a non-admin user in ``test_team`` through ``user_repository`` and return the result."""
    member = UserModel(
        email="user@example.com",
        name="Regular User",
        team_id=test_team.id,
        is_admin=False,
    )
    return await user_repository.create(member)
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def admin_api_key(admin_user: UserModel) -> tuple:
    """Create an active API key for ``admin_user`` with a 30-day expiry.

    Returns:
        A ``(raw_key, created_key)`` tuple: the raw key from
        ``generate_api_key`` and the object returned by
        ``api_key_repository.create``.
    """
    team_ref = str(admin_user.team_id)
    user_ref = str(admin_user.id)
    raw_key, hashed_key = generate_api_key(team_ref, user_ref)

    key_fields = dict(
        key_hash=hashed_key,
        user_id=admin_user.id,
        team_id=admin_user.team_id,
        name="Admin API Key",
        description="API key for admin testing",
        expiry_date=datetime.utcnow() + timedelta(days=30),
        is_active=True,
    )
    stored_key = await api_key_repository.create(ApiKeyModel(**key_fields))
    return raw_key, stored_key
|
|
|
|
|
|
@pytest.fixture(scope="function")
async def user_api_key(regular_user: UserModel) -> tuple:
    """Create an active API key for ``regular_user`` with a 30-day expiry.

    Returns:
        A ``(raw_key, created_key)`` tuple: the raw key from
        ``generate_api_key`` and the object returned by
        ``api_key_repository.create``.
    """
    team_ref = str(regular_user.team_id)
    user_ref = str(regular_user.id)
    raw_key, hashed_key = generate_api_key(team_ref, user_ref)

    key_fields = dict(
        key_hash=hashed_key,
        user_id=regular_user.id,
        team_id=regular_user.team_id,
        name="User API Key",
        description="API key for user testing",
        expiry_date=datetime.utcnow() + timedelta(days=30),
        is_active=True,
    )
    stored_key = await api_key_repository.create(ApiKeyModel(**key_fields))
    return raw_key, stored_key