This commit is contained in:
johnpccd 2025-05-24 12:27:32 +02:00
parent 1bfd3b7d69
commit 1010ed8d4e
6 changed files with 2677 additions and 0 deletions

View File

@ -0,0 +1,620 @@
import pytest
from fastapi.testclient import TestClient
from bson import ObjectId
from datetime import datetime
from src.models.team import TeamModel
from src.models.user import UserModel
from src.models.image import ImageModel
class CollectionModel:
"""Mock collection model for testing"""
def __init__(self, **kwargs):
self.id = kwargs.get('id', ObjectId())
self.name = kwargs.get('name')
self.description = kwargs.get('description')
self.team_id = kwargs.get('team_id')
self.created_by = kwargs.get('created_by')
self.created_at = kwargs.get('created_at', datetime.utcnow())
self.updated_at = kwargs.get('updated_at', datetime.utcnow())
self.metadata = kwargs.get('metadata', {})
self.image_count = kwargs.get('image_count', 0)
@pytest.mark.asyncio
async def test_create_collection(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
"""Test creating a new image collection"""
raw_key, _ = admin_api_key
# Set up the headers with the admin API key
headers = {"X-API-Key": raw_key}
# Create a new collection
response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Test Collection",
"description": "A collection for testing images",
"metadata": {
"category": "test",
"project": "sereact"
}
}
)
# Check response
assert response.status_code == 201
data = response.json()
assert "id" in data
assert data["name"] == "Test Collection"
assert data["description"] == "A collection for testing images"
assert data["team_id"] == str(test_team.id)
assert "created_at" in data
assert data["image_count"] == 0
@pytest.mark.asyncio
async def test_create_collection_non_admin(client: TestClient, user_api_key: tuple):
"""Test that regular users can create collections"""
raw_key, _ = user_api_key
# Set up the headers with a regular user API key
headers = {"X-API-Key": raw_key}
# Create a new collection
response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "User Collection",
"description": "A collection created by a regular user"
}
)
# Check response - regular users should be able to create collections
assert response.status_code == 201
data = response.json()
assert data["name"] == "User Collection"
@pytest.mark.asyncio
async def test_list_collections(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
"""Test listing collections for a team"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# List collections
response = client.get(
"/api/v1/collections",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "collections" in data
assert "total" in data
assert isinstance(data["collections"], list)
@pytest.mark.asyncio
async def test_get_collection_by_id(client: TestClient, admin_api_key: tuple):
"""Test getting a specific collection by ID"""
raw_key, _ = admin_api_key
# First create a collection
headers = {"X-API-Key": raw_key}
create_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Get Test Collection",
"description": "Collection for get testing"
}
)
assert create_response.status_code == 201
collection_id = create_response.json()["id"]
# Get the collection
response = client.get(
f"/api/v1/collections/{collection_id}",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["id"] == collection_id
assert data["name"] == "Get Test Collection"
assert data["description"] == "Collection for get testing"
@pytest.mark.asyncio
async def test_update_collection(client: TestClient, admin_api_key: tuple):
"""Test updating a collection"""
raw_key, _ = admin_api_key
# First create a collection
headers = {"X-API-Key": raw_key}
create_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Update Test Collection",
"description": "Collection for update testing"
}
)
assert create_response.status_code == 201
collection_id = create_response.json()["id"]
# Update the collection
response = client.put(
f"/api/v1/collections/{collection_id}",
headers=headers,
json={
"name": "Updated Collection Name",
"description": "This collection has been updated",
"metadata": {
"updated": True,
"version": 2
}
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["id"] == collection_id
assert data["name"] == "Updated Collection Name"
assert data["description"] == "This collection has been updated"
assert data["metadata"]["updated"] is True
assert "updated_at" in data
@pytest.mark.asyncio
async def test_delete_collection(client: TestClient, admin_api_key: tuple):
"""Test deleting a collection"""
raw_key, _ = admin_api_key
# First create a collection
headers = {"X-API-Key": raw_key}
create_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Delete Test Collection",
"description": "Collection for delete testing"
}
)
assert create_response.status_code == 201
collection_id = create_response.json()["id"]
# Delete the collection
response = client.delete(
f"/api/v1/collections/{collection_id}",
headers=headers
)
# Check response
assert response.status_code == 204
# Verify collection is deleted
get_response = client.get(
f"/api/v1/collections/{collection_id}",
headers=headers
)
assert get_response.status_code == 404
@pytest.mark.asyncio
async def test_add_image_to_collection(client: TestClient, admin_api_key: tuple):
"""Test adding an image to a collection"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Image Collection",
"description": "Collection for image testing"
}
)
collection_id = collection_response.json()["id"]
# Mock image ID (in real implementation, this would be from uploaded image)
image_id = str(ObjectId())
# Add image to collection
response = client.post(
f"/api/v1/collections/{collection_id}/images",
headers=headers,
json={
"image_id": image_id
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert "success" in data
assert data["success"] is True
@pytest.mark.asyncio
async def test_remove_image_from_collection(client: TestClient, admin_api_key: tuple):
"""Test removing an image from a collection"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Remove Image Collection",
"description": "Collection for image removal testing"
}
)
collection_id = collection_response.json()["id"]
# Mock image ID
image_id = str(ObjectId())
# Add image to collection first
add_response = client.post(
f"/api/v1/collections/{collection_id}/images",
headers=headers,
json={"image_id": image_id}
)
assert add_response.status_code == 200
# Remove image from collection
response = client.delete(
f"/api/v1/collections/{collection_id}/images/{image_id}",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["success"] is True
@pytest.mark.asyncio
async def test_list_images_in_collection(client: TestClient, admin_api_key: tuple):
"""Test listing images in a collection"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "List Images Collection",
"description": "Collection for listing images"
}
)
collection_id = collection_response.json()["id"]
# List images in collection
response = client.get(
f"/api/v1/collections/{collection_id}/images",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "images" in data
assert "total" in data
assert isinstance(data["images"], list)
@pytest.mark.asyncio
async def test_collection_access_control(client: TestClient, user_api_key: tuple):
"""Test that users can only access collections from their team"""
raw_key, _ = user_api_key
headers = {"X-API-Key": raw_key}
# Try to access a collection from another team
other_collection_id = str(ObjectId())
response = client.get(
f"/api/v1/collections/{other_collection_id}",
headers=headers
)
# Should return 403 or 404 depending on implementation
assert response.status_code in [403, 404]
@pytest.mark.asyncio
async def test_collection_pagination(client: TestClient, admin_api_key: tuple):
"""Test collection listing with pagination"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Test pagination
response = client.get(
"/api/v1/collections?page=1&limit=10",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "collections" in data
assert "pagination" in data
assert "total" in data["pagination"]
assert "page" in data["pagination"]
assert "pages" in data["pagination"]
assert len(data["collections"]) <= 10
@pytest.mark.asyncio
async def test_collection_search(client: TestClient, admin_api_key: tuple):
"""Test searching collections by name or description"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection with searchable content
create_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Searchable Collection",
"description": "This collection contains vacation photos"
}
)
assert create_response.status_code == 201
# Search for collections
response = client.get(
"/api/v1/collections/search?query=vacation",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "collections" in data
assert "total" in data
# Results should contain collections matching the query
for collection in data["collections"]:
assert ("vacation" in collection["name"].lower() or
"vacation" in collection["description"].lower())
@pytest.mark.asyncio
async def test_collection_statistics(client: TestClient, admin_api_key: tuple):
"""Test getting collection statistics"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Stats Collection",
"description": "Collection for statistics testing"
}
)
collection_id = collection_response.json()["id"]
# Get collection statistics
response = client.get(
f"/api/v1/collections/{collection_id}/stats",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "image_count" in data
assert "total_size" in data
assert "created_at" in data
assert "last_updated" in data
@pytest.mark.asyncio
async def test_bulk_add_images_to_collection(client: TestClient, admin_api_key: tuple):
"""Test adding multiple images to a collection at once"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Bulk Add Collection",
"description": "Collection for bulk operations"
}
)
collection_id = collection_response.json()["id"]
# Mock multiple image IDs
image_ids = [str(ObjectId()) for _ in range(5)]
# Bulk add images to collection
response = client.post(
f"/api/v1/collections/{collection_id}/images/bulk",
headers=headers,
json={
"image_ids": image_ids
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert "added_count" in data
assert data["added_count"] == 5
@pytest.mark.asyncio
async def test_collection_export(client: TestClient, admin_api_key: tuple):
"""Test exporting collection metadata"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Export Collection",
"description": "Collection for export testing",
"metadata": {"category": "test", "project": "sereact"}
}
)
collection_id = collection_response.json()["id"]
# Export collection
response = client.get(
f"/api/v1/collections/{collection_id}/export",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "collection" in data
assert "images" in data
assert data["collection"]["name"] == "Export Collection"
@pytest.mark.asyncio
async def test_collection_duplicate_name_validation(client: TestClient, admin_api_key: tuple):
"""Test that duplicate collection names within a team are handled"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create first collection
response1 = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Duplicate Name Collection",
"description": "First collection with this name"
}
)
assert response1.status_code == 201
# Try to create second collection with same name
response2 = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Duplicate Name Collection",
"description": "Second collection with same name"
}
)
# Should either allow it or reject it depending on business rules
# For now, let's assume it's allowed but with different handling
assert response2.status_code in [201, 400]
@pytest.mark.asyncio
async def test_collection_metadata_validation(client: TestClient, admin_api_key: tuple):
"""Test validation of collection metadata"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Test with invalid metadata structure
response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Invalid Metadata Collection",
"description": "Collection with invalid metadata",
"metadata": "invalid_metadata_type" # Should be dict
}
)
# Should validate metadata structure
assert response.status_code == 400
@pytest.mark.asyncio
async def test_collection_ownership_transfer(client: TestClient, admin_api_key: tuple):
"""Test transferring collection ownership (admin only)"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Create a collection
collection_response = client.post(
"/api/v1/collections",
headers=headers,
json={
"name": "Transfer Collection",
"description": "Collection for ownership transfer"
}
)
collection_id = collection_response.json()["id"]
# Transfer ownership to another user
new_owner_id = str(ObjectId())
response = client.put(
f"/api/v1/collections/{collection_id}/owner",
headers=headers,
json={
"new_owner_id": new_owner_id
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["created_by"] == new_owner_id
@pytest.mark.asyncio
async def test_invalid_collection_id(client: TestClient, admin_api_key: tuple):
"""Test handling of invalid collection IDs"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Try to get collection with invalid ID
response = client.get(
"/api/v1/collections/invalid-id",
headers=headers
)
# Check that the request returns 400 Bad Request
assert response.status_code == 400
assert "detail" in response.json()
@pytest.mark.asyncio
async def test_nonexistent_collection(client: TestClient, admin_api_key: tuple):
"""Test handling of nonexistent collection IDs"""
raw_key, _ = admin_api_key
headers = {"X-API-Key": raw_key}
# Try to get nonexistent collection
nonexistent_id = str(ObjectId())
response = client.get(
f"/api/v1/collections/{nonexistent_id}",
headers=headers
)
# Check that the request returns 404 Not Found
assert response.status_code == 404
assert "detail" in response.json()

549
tests/api/test_users.py Normal file
View File

@ -0,0 +1,549 @@
import pytest
from fastapi.testclient import TestClient
from bson import ObjectId
from src.models.user import UserModel
from src.models.team import TeamModel
from src.db.repositories.user_repository import user_repository
from src.db.repositories.team_repository import team_repository
@pytest.mark.asyncio
async def test_create_user(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
"""Test creating a new user (admin only)"""
raw_key, _ = admin_api_key
# Set up the headers with the admin API key
headers = {"X-API-Key": raw_key}
# Create a new user
response = client.post(
"/api/v1/users",
headers=headers,
json={
"email": "newuser@example.com",
"name": "New User",
"team_id": str(test_team.id),
"is_admin": False
}
)
# Check response
assert response.status_code == 201
data = response.json()
assert "id" in data
assert data["email"] == "newuser@example.com"
assert data["name"] == "New User"
assert data["team_id"] == str(test_team.id)
assert data["is_admin"] is False
assert "created_at" in data
@pytest.mark.asyncio
async def test_create_user_non_admin(client: TestClient, user_api_key: tuple, test_team: TeamModel):
"""Test that non-admin users cannot create users"""
raw_key, _ = user_api_key
# Set up the headers with a non-admin API key
headers = {"X-API-Key": raw_key}
# Try to create a new user
response = client.post(
"/api/v1/users",
headers=headers,
json={
"email": "unauthorized@example.com",
"name": "Unauthorized User",
"team_id": str(test_team.id),
"is_admin": False
}
)
# Check that the request is forbidden
assert response.status_code == 403
assert "detail" in response.json()
assert "admin" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_list_users_in_team(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
"""Test listing users in a team"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# List users in the team
response = client.get(
f"/api/v1/users?team_id={test_team.id}",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "users" in data
assert "total" in data
assert data["total"] >= 1 # Should include at least the admin user
@pytest.mark.asyncio
async def test_list_all_users_admin_only(client: TestClient, admin_api_key: tuple):
"""Test that only admins can list all users across teams"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# List all users (no team filter)
response = client.get(
"/api/v1/users",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "users" in data
assert "total" in data
@pytest.mark.asyncio
async def test_list_users_non_admin_restricted(client: TestClient, user_api_key: tuple):
"""Test that non-admin users can only see their own team"""
raw_key, _ = user_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to list all users
response = client.get(
"/api/v1/users",
headers=headers
)
# Should only return users from their own team
assert response.status_code == 200
data = response.json()
assert "users" in data
# All returned users should be from the same team
if data["users"]:
user_team_ids = set(user["team_id"] for user in data["users"])
assert len(user_team_ids) == 1 # Only one team represented
@pytest.mark.asyncio
async def test_get_user_by_id(client: TestClient, admin_api_key: tuple, admin_user: UserModel):
"""Test getting a specific user by ID"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Get the user
response = client.get(
f"/api/v1/users/{admin_user.id}",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["id"] == str(admin_user.id)
assert data["email"] == admin_user.email
assert data["name"] == admin_user.name
assert data["team_id"] == str(admin_user.team_id)
assert data["is_admin"] == admin_user.is_admin
@pytest.mark.asyncio
async def test_get_user_self(client: TestClient, user_api_key: tuple, regular_user: UserModel):
"""Test that users can get their own information"""
raw_key, _ = user_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Get own user information
response = client.get(
f"/api/v1/users/{regular_user.id}",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["id"] == str(regular_user.id)
assert data["email"] == regular_user.email
@pytest.mark.asyncio
async def test_get_user_other_team_forbidden(client: TestClient, user_api_key: tuple):
"""Test that users cannot access users from other teams"""
raw_key, _ = user_api_key
# Create a user in another team
other_team = TeamModel(
name="Other Team",
description="Another team for testing"
)
created_team = await team_repository.create(other_team)
other_user = UserModel(
email="other@example.com",
name="Other User",
team_id=created_team.id,
is_admin=False
)
created_user = await user_repository.create(other_user)
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to get the other user
response = client.get(
f"/api/v1/users/{created_user.id}",
headers=headers
)
# Check that the request is forbidden
assert response.status_code == 403
assert "detail" in response.json()
@pytest.mark.asyncio
async def test_update_user(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
"""Test updating a user (admin only)"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Update the user
response = client.put(
f"/api/v1/users/{regular_user.id}",
headers=headers,
json={
"name": "Updated User Name",
"email": "updated@example.com",
"is_admin": True
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["id"] == str(regular_user.id)
assert data["name"] == "Updated User Name"
assert data["email"] == "updated@example.com"
assert data["is_admin"] is True
assert "updated_at" in data
@pytest.mark.asyncio
async def test_update_user_self(client: TestClient, user_api_key: tuple, regular_user: UserModel):
"""Test that users can update their own basic information"""
raw_key, _ = user_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Update own information (limited fields)
response = client.put(
f"/api/v1/users/{regular_user.id}",
headers=headers,
json={
"name": "Self Updated Name",
"email": "self_updated@example.com"
# Note: is_admin should not be updatable by self
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["name"] == "Self Updated Name"
assert data["email"] == "self_updated@example.com"
# Admin status should remain unchanged
assert data["is_admin"] == regular_user.is_admin
@pytest.mark.asyncio
async def test_update_user_non_admin_forbidden(client: TestClient, user_api_key: tuple, admin_user: UserModel):
"""Test that non-admin users cannot update other users"""
raw_key, _ = user_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to update another user
response = client.put(
f"/api/v1/users/{admin_user.id}",
headers=headers,
json={
"name": "Unauthorized Update",
"is_admin": False
}
)
# Check that the request is forbidden
assert response.status_code == 403
assert "detail" in response.json()
@pytest.mark.asyncio
async def test_delete_user(client: TestClient, admin_api_key: tuple):
"""Test deleting a user (admin only)"""
raw_key, _ = admin_api_key
# Create a user to delete
user_to_delete = UserModel(
email="delete@example.com",
name="User to Delete",
team_id=ObjectId(),
is_admin=False
)
created_user = await user_repository.create(user_to_delete)
# Set up the headers
headers = {"X-API-Key": raw_key}
# Delete the user
response = client.delete(
f"/api/v1/users/{created_user.id}",
headers=headers
)
# Check response
assert response.status_code == 204
# Verify the user has been deleted
deleted_user = await user_repository.get_by_id(created_user.id)
assert deleted_user is None
@pytest.mark.asyncio
async def test_delete_user_non_admin(client: TestClient, user_api_key: tuple, admin_user: UserModel):
"""Test that non-admin users cannot delete users"""
raw_key, _ = user_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to delete a user
response = client.delete(
f"/api/v1/users/{admin_user.id}",
headers=headers
)
# Check that the request is forbidden
assert response.status_code == 403
assert "detail" in response.json()
assert "admin" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_get_user_activity(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
"""Test getting user activity/statistics"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Get user activity
response = client.get(
f"/api/v1/users/{regular_user.id}/activity",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "user_id" in data
assert "images_uploaded" in data
assert "last_login" in data
assert "api_key_usage" in data
@pytest.mark.asyncio
async def test_change_user_team(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
"""Test moving a user to a different team (admin only)"""
raw_key, _ = admin_api_key
# Create another team
new_team = TeamModel(
name="New Team",
description="A new team for testing"
)
created_team = await team_repository.create(new_team)
# Set up the headers
headers = {"X-API-Key": raw_key}
# Move user to new team
response = client.put(
f"/api/v1/users/{regular_user.id}/team",
headers=headers,
json={
"team_id": str(created_team.id)
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["team_id"] == str(created_team.id)
@pytest.mark.asyncio
async def test_user_search(client: TestClient, admin_api_key: tuple):
"""Test searching for users by email or name"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Search for users
response = client.get(
"/api/v1/users/search?query=admin",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "users" in data
assert "total" in data
# Results should contain users matching the query
for user in data["users"]:
assert "admin" in user["email"].lower() or "admin" in user["name"].lower()
@pytest.mark.asyncio
async def test_user_pagination(client: TestClient, admin_api_key: tuple):
"""Test user listing with pagination"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Test pagination
response = client.get(
"/api/v1/users?page=1&limit=10",
headers=headers
)
# Check response
assert response.status_code == 200
data = response.json()
assert "users" in data
assert "pagination" in data
assert "total" in data["pagination"]
assert "page" in data["pagination"]
assert "pages" in data["pagination"]
assert len(data["users"]) <= 10
@pytest.mark.asyncio
async def test_invalid_user_id(client: TestClient, admin_api_key: tuple):
"""Test handling of invalid user IDs"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to get user with invalid ID
response = client.get(
"/api/v1/users/invalid-id",
headers=headers
)
# Check that the request returns 400 Bad Request
assert response.status_code == 400
assert "detail" in response.json()
assert "invalid" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_nonexistent_user(client: TestClient, admin_api_key: tuple):
"""Test handling of nonexistent user IDs"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to get nonexistent user
nonexistent_id = str(ObjectId())
response = client.get(
f"/api/v1/users/{nonexistent_id}",
headers=headers
)
# Check that the request returns 404 Not Found
assert response.status_code == 404
assert "detail" in response.json()
assert "not found" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_duplicate_email_validation(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
"""Test that duplicate emails are not allowed"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Try to create user with existing email
response = client.post(
"/api/v1/users",
headers=headers,
json={
"email": regular_user.email, # Duplicate email
"name": "Duplicate Email User",
"team_id": str(regular_user.team_id),
"is_admin": False
}
)
# Check that the request is rejected
assert response.status_code == 400
assert "detail" in response.json()
assert "email" in response.json()["detail"].lower()
@pytest.mark.asyncio
async def test_user_role_management(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
"""Test managing user roles and permissions"""
raw_key, _ = admin_api_key
# Set up the headers
headers = {"X-API-Key": raw_key}
# Promote user to admin
response = client.put(
f"/api/v1/users/{regular_user.id}/role",
headers=headers,
json={
"is_admin": True
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["is_admin"] is True
# Demote user back to regular
response = client.put(
f"/api/v1/users/{regular_user.id}/role",
headers=headers,
json={
"is_admin": False
}
)
# Check response
assert response.status_code == 200
data = response.json()
assert data["is_admin"] is False

268
tests/auth/test_security.py Normal file
View File

@ -0,0 +1,268 @@
import pytest
import hashlib
from datetime import datetime, timedelta
from bson import ObjectId
from src.auth.security import (
generate_api_key,
hash_api_key,
verify_api_key,
create_access_token,
verify_token
)
from src.models.api_key import ApiKeyModel
from src.models.user import UserModel
from src.models.team import TeamModel
class TestApiKeySecurity:
"""Test API key generation and validation security"""
def test_generate_api_key(self):
"""Test API key generation produces unique, secure keys"""
team_id = str(ObjectId())
user_id = str(ObjectId())
# Generate multiple keys
key1, hash1 = generate_api_key(team_id, user_id)
key2, hash2 = generate_api_key(team_id, user_id)
# Keys should be different
assert key1 != key2
assert hash1 != hash2
# Keys should be sufficiently long
assert len(key1) >= 32
assert len(hash1) >= 32
# Keys should contain team and user info
assert team_id in key1 or user_id in key1
def test_hash_api_key_consistency(self):
"""Test that hashing the same key produces the same hash"""
key = "test-api-key-123"
hash1 = hash_api_key(key)
hash2 = hash_api_key(key)
assert hash1 == hash2
assert len(hash1) >= 32 # Should be a proper hash length
def test_verify_api_key_valid(self):
"""Test verifying a valid API key"""
team_id = str(ObjectId())
user_id = str(ObjectId())
raw_key, key_hash = generate_api_key(team_id, user_id)
# Verification should succeed
assert verify_api_key(raw_key, key_hash) is True
def test_verify_api_key_invalid(self):
"""Test verifying an invalid API key"""
team_id = str(ObjectId())
user_id = str(ObjectId())
raw_key, key_hash = generate_api_key(team_id, user_id)
# Wrong key should fail
assert verify_api_key("wrong-key", key_hash) is False
# Wrong hash should fail
assert verify_api_key(raw_key, "wrong-hash") is False
def test_api_key_format(self):
"""Test that generated API keys follow expected format"""
team_id = str(ObjectId())
user_id = str(ObjectId())
raw_key, key_hash = generate_api_key(team_id, user_id)
# Key should have expected structure (prefix.hash format)
assert "." in raw_key
parts = raw_key.split(".")
assert len(parts) == 2
# First part should be readable prefix
prefix = parts[0]
assert len(prefix) >= 8
# Second part should be hash-like
hash_part = parts[1]
assert len(hash_part) >= 32
class TestTokenSecurity:
"""Test JWT token generation and validation"""
def test_create_access_token(self):
"""Test creating access tokens"""
user_id = str(ObjectId())
team_id = str(ObjectId())
token = create_access_token(
data={"user_id": user_id, "team_id": team_id}
)
assert token is not None
assert isinstance(token, str)
assert len(token) > 50 # JWT tokens are typically long
def test_verify_token_valid(self):
"""Test verifying a valid token"""
user_id = str(ObjectId())
team_id = str(ObjectId())
token = create_access_token(
data={"user_id": user_id, "team_id": team_id}
)
payload = verify_token(token)
assert payload is not None
assert payload["user_id"] == user_id
assert payload["team_id"] == team_id
def test_verify_token_invalid(self):
"""Test verifying an invalid token"""
# Invalid token should return None
assert verify_token("invalid-token") is None
assert verify_token("") is None
assert verify_token(None) is None
def test_token_expiration(self):
"""Test token expiration handling"""
user_id = str(ObjectId())
# Create token with very short expiration
token = create_access_token(
data={"user_id": user_id},
expires_delta=timedelta(seconds=-1) # Already expired
)
# Should fail verification due to expiration
payload = verify_token(token)
assert payload is None
class TestSecurityValidation:
"""Test security validation functions"""
def test_validate_team_access(self):
"""Test team access validation"""
team_id = ObjectId()
user_team_id = ObjectId()
# User should have access to their own team
from src.auth.security import validate_team_access
assert validate_team_access(str(team_id), str(team_id)) is True
# User should not have access to other teams
assert validate_team_access(str(user_team_id), str(team_id)) is False
def test_validate_admin_permissions(self):
"""Test admin permission validation"""
from src.auth.security import validate_admin_permissions
admin_user = UserModel(
email="admin@test.com",
name="Admin User",
team_id=ObjectId(),
is_admin=True
)
regular_user = UserModel(
email="user@test.com",
name="Regular User",
team_id=ObjectId(),
is_admin=False
)
assert validate_admin_permissions(admin_user) is True
assert validate_admin_permissions(regular_user) is False
def test_rate_limiting_validation(self):
"""Test rate limiting for API keys"""
# This would test rate limiting functionality
# Implementation depends on the actual rate limiting strategy
pass
def test_api_key_expiration_check(self):
"""Test API key expiration validation"""
team_id = ObjectId()
user_id = ObjectId()
# Create expired API key
expired_key = ApiKeyModel(
key_hash="test-hash",
user_id=user_id,
team_id=team_id,
name="Expired Key",
expiry_date=datetime.utcnow() - timedelta(days=1),
is_active=True
)
# Create valid API key
valid_key = ApiKeyModel(
key_hash="test-hash-2",
user_id=user_id,
team_id=team_id,
name="Valid Key",
expiry_date=datetime.utcnow() + timedelta(days=30),
is_active=True
)
from src.auth.security import is_api_key_valid
assert is_api_key_valid(expired_key) is False
assert is_api_key_valid(valid_key) is True
def test_inactive_api_key_check(self):
"""Test inactive API key validation"""
team_id = ObjectId()
user_id = ObjectId()
# Create inactive API key
inactive_key = ApiKeyModel(
key_hash="test-hash",
user_id=user_id,
team_id=team_id,
name="Inactive Key",
expiry_date=datetime.utcnow() + timedelta(days=30),
is_active=False
)
from src.auth.security import is_api_key_valid
assert is_api_key_valid(inactive_key) is False
class TestSecurityHeaders:
"""Test security headers and middleware"""
def test_cors_headers(self):
"""Test CORS header configuration"""
# This would test CORS configuration
pass
def test_security_headers(self):
"""Test security headers like X-Frame-Options, etc."""
# This would test security headers
pass
def test_https_enforcement(self):
"""Test HTTPS enforcement in production"""
# This would test HTTPS redirect functionality
pass
class TestPasswordSecurity:
"""Test password hashing and validation if implemented"""
def test_password_hashing(self):
"""Test password hashing functionality"""
# If password authentication is implemented
pass
def test_password_validation(self):
"""Test password strength validation"""
# If password authentication is implemented
pass

View File

@ -0,0 +1,360 @@
import pytest
import numpy as np
from unittest.mock import patch, MagicMock, AsyncMock
from bson import ObjectId
from io import BytesIO
from src.services.embedding_service import EmbeddingService
from src.models.image import ImageModel
class TestEmbeddingService:
"""Test embedding generation for images"""
@pytest.fixture
def mock_vision_client(self):
"""Mock Google Cloud Vision client"""
mock_client = MagicMock()
mock_response = MagicMock()
mock_response.image_properties_annotation.dominant_colors.colors = []
mock_response.label_annotations = []
mock_response.object_localizations = []
mock_response.text_annotations = []
mock_client.annotate_image.return_value = mock_response
return mock_client
@pytest.fixture
def embedding_service(self, mock_vision_client):
"""Create embedding service with mocked dependencies"""
with patch('src.services.embedding_service.vision') as mock_vision:
mock_vision.ImageAnnotatorClient.return_value = mock_vision_client
service = EmbeddingService()
service.client = mock_vision_client
return service
@pytest.fixture
def sample_image_data(self):
"""Create sample image data"""
# Create a simple test image (1x1 pixel PNG)
image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xdd\x8d\xb4\x1c\x00\x00\x00\x00IEND\xaeB`\x82'
return BytesIO(image_data)
@pytest.fixture
def sample_image_model(self):
"""Create a sample image model"""
return ImageModel(
filename="test-image.jpg",
original_filename="test_image.jpg",
file_size=1024,
content_type="image/jpeg",
storage_path="images/test-image.jpg",
team_id=ObjectId(),
uploader_id=ObjectId()
)
def test_generate_embedding_from_image(self, embedding_service, sample_image_data):
"""Test generating embeddings from image data"""
# Mock the embedding generation
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = mock_embedding
# Generate embedding
embedding = embedding_service.generate_embedding(sample_image_data)
# Verify embedding was generated
assert embedding is not None
assert len(embedding) == 512
assert isinstance(embedding, list)
assert all(isinstance(x, (int, float)) for x in embedding)
def test_extract_image_features(self, embedding_service, sample_image_data):
"""Test extracting features from images using Vision API"""
# Mock Vision API response
mock_response = MagicMock()
mock_response.label_annotations = [
MagicMock(description="cat", score=0.95),
MagicMock(description="animal", score=0.87),
MagicMock(description="pet", score=0.82)
]
mock_response.object_localizations = [
MagicMock(name="Cat", score=0.9)
]
mock_response.image_properties_annotation.dominant_colors.colors = [
MagicMock(color=MagicMock(red=255, green=100, blue=50), score=0.8)
]
embedding_service.client.annotate_image.return_value = mock_response
# Extract features
features = embedding_service.extract_image_features(sample_image_data)
# Verify features were extracted
assert 'labels' in features
assert 'objects' in features
assert 'colors' in features
assert len(features['labels']) == 3
assert features['labels'][0]['description'] == "cat"
assert features['labels'][0]['score'] == 0.95
def test_generate_embedding_with_metadata(self, embedding_service, sample_image_data, sample_image_model):
"""Test generating embeddings with image metadata"""
mock_embedding = np.random.rand(512).tolist()
mock_features = {
'labels': [{'description': 'cat', 'score': 0.95}],
'objects': [{'name': 'Cat', 'score': 0.9}],
'colors': [{'red': 255, 'green': 100, 'blue': 50, 'score': 0.8}]
}
with patch.object(embedding_service, '_extract_features') as mock_extract_features, \
patch.object(embedding_service, 'extract_image_features') as mock_extract_metadata:
mock_extract_features.return_value = mock_embedding
mock_extract_metadata.return_value = mock_features
# Generate embedding with metadata
result = embedding_service.generate_embedding_with_metadata(
sample_image_data, sample_image_model
)
# Verify result structure
assert 'embedding' in result
assert 'metadata' in result
assert 'model' in result
assert len(result['embedding']) == 512
assert result['metadata']['labels'][0]['description'] == 'cat'
assert result['model'] == 'clip' # or whatever model is used
def test_batch_generate_embeddings(self, embedding_service):
"""Test generating embeddings for multiple images in batch"""
# Create multiple image data samples
image_batch = []
for i in range(3):
image_data = BytesIO(b'fake_image_data_' + str(i).encode())
image_model = ImageModel(
filename=f"image{i}.jpg",
original_filename=f"image{i}.jpg",
file_size=1024,
content_type="image/jpeg",
storage_path=f"images/image{i}.jpg",
team_id=ObjectId(),
uploader_id=ObjectId()
)
image_batch.append((image_data, image_model))
# Mock embedding generation
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate:
mock_generate.return_value = {
'embedding': mock_embedding,
'metadata': {'labels': []},
'model': 'clip'
}
# Generate batch embeddings
results = embedding_service.batch_generate_embeddings(image_batch)
# Verify batch results
assert len(results) == 3
assert all('embedding' in result for result in results)
assert all('metadata' in result for result in results)
def test_embedding_model_consistency(self, embedding_service, sample_image_data):
"""Test that the same image produces consistent embeddings"""
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = mock_embedding
# Generate embedding twice
embedding1 = embedding_service.generate_embedding(sample_image_data)
sample_image_data.seek(0) # Reset stream position
embedding2 = embedding_service.generate_embedding(sample_image_data)
# Embeddings should be identical for the same image
assert embedding1 == embedding2
def test_embedding_dimension_validation(self, embedding_service):
"""Test that embeddings have the correct dimensions"""
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = mock_embedding
# Validate embedding dimensions
assert embedding_service.validate_embedding_dimensions(mock_embedding) is True
# Test wrong dimensions
wrong_embedding = np.random.rand(256).tolist()
assert embedding_service.validate_embedding_dimensions(wrong_embedding) is False
def test_handle_unsupported_image_format(self, embedding_service):
"""Test handling of unsupported image formats"""
# Create invalid image data
invalid_data = BytesIO(b'not_an_image')
# Should raise appropriate exception
with pytest.raises(ValueError):
embedding_service.generate_embedding(invalid_data)
def test_handle_corrupted_image(self, embedding_service):
"""Test handling of corrupted image data"""
# Create corrupted image data
corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')
# Should handle gracefully
with pytest.raises(Exception):
embedding_service.generate_embedding(corrupted_data)
def test_vision_api_error_handling(self, embedding_service, sample_image_data):
"""Test handling of Vision API errors"""
# Mock Vision API error
embedding_service.client.annotate_image.side_effect = Exception("Vision API error")
# Should handle the error gracefully
with pytest.raises(Exception):
embedding_service.extract_image_features(sample_image_data)
def test_embedding_caching(self, embedding_service, sample_image_data):
"""Test caching of embeddings for the same image"""
# This would test caching functionality if implemented
pass
def test_embedding_quality_metrics(self, embedding_service, sample_image_data):
"""Test quality metrics for generated embeddings"""
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = mock_embedding
# Generate embedding
embedding = embedding_service.generate_embedding(sample_image_data)
# Check embedding quality metrics
quality_score = embedding_service.calculate_embedding_quality(embedding)
assert 0 <= quality_score <= 1
def test_different_image_types(self, embedding_service):
"""Test embedding generation for different image types"""
image_types = [
('image/jpeg', b'fake_jpeg_data'),
('image/png', b'fake_png_data'),
('image/webp', b'fake_webp_data')
]
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = mock_embedding
for content_type, data in image_types:
image_data = BytesIO(data)
# Should handle different image types
embedding = embedding_service.generate_embedding(image_data)
assert len(embedding) == 512
def test_large_image_handling(self, embedding_service):
"""Test handling of large images"""
# Create large image data (simulated)
large_image_data = BytesIO(b'x' * (10 * 1024 * 1024)) # 10MB
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = mock_embedding
# Should handle large images
embedding = embedding_service.generate_embedding(large_image_data)
assert len(embedding) == 512
def test_embedding_normalization(self, embedding_service, sample_image_data):
"""Test that embeddings are properly normalized"""
# Generate raw embedding
raw_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
mock_extract.return_value = raw_embedding
# Generate normalized embedding
embedding = embedding_service.generate_embedding(sample_image_data, normalize=True)
# Check if embedding is normalized (L2 norm should be 1)
norm = np.linalg.norm(embedding)
assert abs(norm - 1.0) < 0.001 # Allow small floating point errors
class TestEmbeddingServiceIntegration:
"""Integration tests for embedding service with other components"""
def test_embedding_to_vector_store_integration(self, embedding_service, sample_image_data, sample_image_model):
"""Test integration with vector store service"""
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate, \
patch('src.services.vector_store.VectorStoreService') as mock_vector_store:
mock_generate.return_value = {
'embedding': mock_embedding,
'metadata': {'labels': [{'description': 'cat', 'score': 0.95}]},
'model': 'clip'
}
mock_store = mock_vector_store.return_value
mock_store.store_embedding.return_value = 'embedding_id_123'
# Process image and store embedding
result = embedding_service.process_and_store_image(
sample_image_data, sample_image_model
)
# Verify integration
assert result['embedding_id'] == 'embedding_id_123'
mock_store.store_embedding.assert_called_once()
def test_pubsub_trigger_integration(self, embedding_service):
"""Test integration with Pub/Sub message processing"""
# Mock Pub/Sub message
mock_message = {
'image_id': str(ObjectId()),
'storage_path': 'images/test.jpg',
'team_id': str(ObjectId())
}
with patch.object(embedding_service, 'process_image_from_storage') as mock_process:
mock_process.return_value = {'embedding_id': 'emb123'}
# Process Pub/Sub message
result = embedding_service.handle_pubsub_message(mock_message)
# Verify message processing
assert result['embedding_id'] == 'emb123'
mock_process.assert_called_once_with(
mock_message['storage_path'],
mock_message['image_id'],
mock_message['team_id']
)
def test_cloud_function_deployment(self, embedding_service):
"""Test Cloud Function deployment compatibility"""
# Test that the service can be initialized in a Cloud Function environment
# This would test environment variable loading, authentication, etc.
pass
def test_error_recovery_and_retry(self, embedding_service, sample_image_data):
"""Test error recovery and retry mechanisms"""
# Mock transient error followed by success
mock_embedding = np.random.rand(512).tolist()
with patch.object(embedding_service, '_extract_features') as mock_extract:
# First call fails, second succeeds
mock_extract.side_effect = [Exception("Transient error"), mock_embedding]
# Should retry and succeed
embedding = embedding_service.generate_embedding_with_retry(
sample_image_data, max_retries=2
)
assert len(embedding) == 512
assert mock_extract.call_count == 2

View File

@ -0,0 +1,489 @@
import pytest
import numpy as np
from unittest.mock import patch, MagicMock
from bson import ObjectId
from io import BytesIO
from PIL import Image
from src.services.image_processor import ImageProcessor
from src.models.image import ImageModel
class TestImageProcessor:
"""Test image processing functionality"""
@pytest.fixture
def image_processor(self):
"""Create image processor instance"""
return ImageProcessor()
@pytest.fixture
def sample_image_data(self):
"""Create sample image data"""
# Create a simple test image using PIL
img = Image.new('RGB', (800, 600), color='red')
img_bytes = BytesIO()
img.save(img_bytes, format='JPEG')
img_bytes.seek(0)
return img_bytes
@pytest.fixture
def sample_png_image(self):
"""Create sample PNG image data"""
img = Image.new('RGBA', (400, 300), color=(255, 0, 0, 128))
img_bytes = BytesIO()
img.save(img_bytes, format='PNG')
img_bytes.seek(0)
return img_bytes
@pytest.fixture
def sample_image_model(self):
"""Create a sample image model"""
return ImageModel(
filename="test-image.jpg",
original_filename="test_image.jpg",
file_size=1024,
content_type="image/jpeg",
storage_path="images/test-image.jpg",
team_id=ObjectId(),
uploader_id=ObjectId()
)
def test_extract_image_metadata(self, image_processor, sample_image_data):
"""Test extracting basic image metadata"""
# Extract metadata
metadata = image_processor.extract_metadata(sample_image_data)
# Verify metadata extraction
assert 'width' in metadata
assert 'height' in metadata
assert 'format' in metadata
assert 'mode' in metadata
assert metadata['width'] == 800
assert metadata['height'] == 600
assert metadata['format'] == 'JPEG'
def test_extract_exif_data(self, image_processor):
"""Test extracting EXIF data from images"""
# Create image with EXIF data (simulated)
img = Image.new('RGB', (100, 100), color='blue')
img_bytes = BytesIO()
img.save(img_bytes, format='JPEG')
img_bytes.seek(0)
# Extract EXIF data
exif_data = image_processor.extract_exif_data(img_bytes)
# Verify EXIF extraction (may be empty for generated images)
assert isinstance(exif_data, dict)
def test_resize_image(self, image_processor, sample_image_data):
"""Test resizing images while maintaining aspect ratio"""
# Resize image
resized_data = image_processor.resize_image(
sample_image_data,
max_width=400,
max_height=300
)
# Verify resized image
assert resized_data is not None
# Check new dimensions
resized_img = Image.open(resized_data)
assert resized_img.width <= 400
assert resized_img.height <= 300
# Aspect ratio should be maintained
original_ratio = 800 / 600
new_ratio = resized_img.width / resized_img.height
assert abs(original_ratio - new_ratio) < 0.01
def test_generate_thumbnail(self, image_processor, sample_image_data):
"""Test generating image thumbnails"""
# Generate thumbnail
thumbnail_data = image_processor.generate_thumbnail(
sample_image_data,
size=(150, 150)
)
# Verify thumbnail
assert thumbnail_data is not None
# Check thumbnail dimensions
thumbnail_img = Image.open(thumbnail_data)
assert thumbnail_img.width <= 150
assert thumbnail_img.height <= 150
def test_optimize_image_quality(self, image_processor, sample_image_data):
"""Test optimizing image quality and file size"""
# Get original size
original_size = len(sample_image_data.getvalue())
# Optimize image
optimized_data = image_processor.optimize_image(
sample_image_data,
quality=85,
optimize=True
)
# Verify optimization
assert optimized_data is not None
optimized_size = len(optimized_data.getvalue())
# Optimized image should typically be smaller or similar size
assert optimized_size <= original_size * 1.1 # Allow 10% tolerance
def test_convert_image_format(self, image_processor, sample_png_image):
"""Test converting between image formats"""
# Convert PNG to JPEG
jpeg_data = image_processor.convert_format(
sample_png_image,
target_format='JPEG'
)
# Verify conversion
assert jpeg_data is not None
# Check converted image
converted_img = Image.open(jpeg_data)
assert converted_img.format == 'JPEG'
def test_detect_image_colors(self, image_processor, sample_image_data):
"""Test detecting dominant colors in images"""
# Detect colors
colors = image_processor.detect_dominant_colors(
sample_image_data,
num_colors=5
)
# Verify color detection
assert isinstance(colors, list)
assert len(colors) <= 5
# Each color should have RGB values and percentage
for color in colors:
assert 'rgb' in color
assert 'percentage' in color
assert len(color['rgb']) == 3
assert 0 <= color['percentage'] <= 100
def test_validate_image_format(self, image_processor, sample_image_data):
"""Test validating supported image formats"""
# Valid image should pass validation
is_valid = image_processor.validate_image_format(sample_image_data)
assert is_valid is True
# Invalid data should fail validation
invalid_data = BytesIO(b'not_an_image')
is_valid = image_processor.validate_image_format(invalid_data)
assert is_valid is False
def test_calculate_image_hash(self, image_processor, sample_image_data):
"""Test calculating perceptual hash for duplicate detection"""
# Calculate hash
image_hash = image_processor.calculate_perceptual_hash(sample_image_data)
# Verify hash
assert image_hash is not None
assert isinstance(image_hash, str)
assert len(image_hash) > 0
# Same image should produce same hash
sample_image_data.seek(0)
hash2 = image_processor.calculate_perceptual_hash(sample_image_data)
assert image_hash == hash2
def test_detect_image_orientation(self, image_processor, sample_image_data):
"""Test detecting and correcting image orientation"""
# Detect orientation
orientation = image_processor.detect_orientation(sample_image_data)
# Verify orientation detection
assert orientation in [0, 90, 180, 270]
# Auto-correct orientation if needed
corrected_data = image_processor.auto_correct_orientation(sample_image_data)
assert corrected_data is not None
def test_extract_text_from_image(self, image_processor):
"""Test OCR text extraction from images"""
# Create image with text (simulated)
img = Image.new('RGB', (200, 100), color='white')
img_bytes = BytesIO()
img.save(img_bytes, format='JPEG')
img_bytes.seek(0)
with patch('src.services.image_processor.pytesseract') as mock_ocr:
mock_ocr.image_to_string.return_value = "Sample text"
# Extract text
text = image_processor.extract_text(img_bytes)
# Verify text extraction
assert text == "Sample text"
mock_ocr.image_to_string.assert_called_once()
def test_batch_process_images(self, image_processor):
"""Test batch processing multiple images"""
# Create batch of images
image_batch = []
for i in range(3):
img = Image.new('RGB', (100, 100), color=(i*80, 0, 0))
img_bytes = BytesIO()
img.save(img_bytes, format='JPEG')
img_bytes.seek(0)
image_batch.append(img_bytes)
# Process batch
results = image_processor.batch_process(
image_batch,
operations=['resize', 'thumbnail', 'metadata']
)
# Verify batch processing
assert len(results) == 3
for result in results:
assert 'metadata' in result
assert 'resized' in result
assert 'thumbnail' in result
def test_image_quality_assessment(self, image_processor, sample_image_data):
"""Test assessing image quality metrics"""
# Assess quality
quality_metrics = image_processor.assess_quality(sample_image_data)
# Verify quality metrics
assert 'sharpness' in quality_metrics
assert 'brightness' in quality_metrics
assert 'contrast' in quality_metrics
assert 'overall_score' in quality_metrics
# Scores should be in valid ranges
assert 0 <= quality_metrics['overall_score'] <= 100
def test_watermark_addition(self, image_processor, sample_image_data):
"""Test adding watermarks to images"""
# Add text watermark
watermarked_data = image_processor.add_watermark(
sample_image_data,
watermark_text="SEREACT",
position="bottom-right",
opacity=0.5
)
# Verify watermark addition
assert watermarked_data is not None
# Check that image is still valid
watermarked_img = Image.open(watermarked_data)
assert watermarked_img.format == 'JPEG'
def test_image_compression_levels(self, image_processor, sample_image_data):
"""Test different compression levels"""
original_size = len(sample_image_data.getvalue())
# Test different quality levels
for quality in [95, 85, 75, 60]:
compressed_data = image_processor.compress_image(
sample_image_data,
quality=quality
)
compressed_size = len(compressed_data.getvalue())
# Lower quality should generally result in smaller files
if quality < 95:
assert compressed_size <= original_size
# Reset stream position
sample_image_data.seek(0)
def test_handle_corrupted_image(self, image_processor):
"""Test handling of corrupted image data"""
# Create corrupted image data
corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')
# Should handle gracefully
with pytest.raises(Exception):
image_processor.extract_metadata(corrupted_data)
def test_large_image_processing(self, image_processor):
"""Test processing very large images"""
# Create large image (simulated)
large_img = Image.new('RGB', (4000, 3000), color='green')
img_bytes = BytesIO()
large_img.save(img_bytes, format='JPEG', quality=95)
img_bytes.seek(0)
# Process large image
metadata = image_processor.extract_metadata(img_bytes)
# Verify processing
assert metadata['width'] == 4000
assert metadata['height'] == 3000
# Test resizing large image
img_bytes.seek(0)
resized_data = image_processor.resize_image(
img_bytes,
max_width=1920,
max_height=1080
)
resized_img = Image.open(resized_data)
assert resized_img.width <= 1920
assert resized_img.height <= 1080
def test_progressive_jpeg_support(self, image_processor, sample_image_data):
"""Test support for progressive JPEG format"""
# Convert to progressive JPEG
progressive_data = image_processor.convert_to_progressive_jpeg(
sample_image_data
)
# Verify progressive format
assert progressive_data is not None
# Check that it's still a valid JPEG
progressive_img = Image.open(progressive_data)
assert progressive_img.format == 'JPEG'
class TestImageProcessorIntegration:
"""Integration tests for image processor with other services"""
def test_integration_with_storage_service(self, image_processor, sample_image_data):
"""Test integration with storage service"""
with patch('src.services.storage.StorageService') as mock_storage:
mock_storage_instance = mock_storage.return_value
mock_storage_instance.upload_file.return_value = (
'images/processed.jpg', 'image/jpeg', 1024, {}
)
# Process and upload image
result = image_processor.process_and_upload(
sample_image_data,
operations=['resize', 'optimize'],
team_id=str(ObjectId())
)
# Verify integration
assert 'storage_path' in result
mock_storage_instance.upload_file.assert_called_once()
def test_integration_with_embedding_service(self, image_processor, sample_image_data):
"""Test integration with embedding service"""
with patch('src.services.embedding_service.EmbeddingService') as mock_embedding:
mock_embedding_instance = mock_embedding.return_value
mock_embedding_instance.generate_embedding.return_value = [0.1] * 512
# Process image and generate embedding
result = image_processor.process_for_embedding(sample_image_data)
# Verify integration
assert 'processed_image' in result
assert 'embedding' in result
mock_embedding_instance.generate_embedding.assert_called_once()
def test_pubsub_message_processing(self, image_processor):
"""Test processing images from Pub/Sub messages"""
# Mock Pub/Sub message
message_data = {
'image_id': str(ObjectId()),
'storage_path': 'images/raw/test.jpg',
'operations': ['resize', 'thumbnail', 'optimize']
}
with patch.object(image_processor, 'process_from_storage') as mock_process:
mock_process.return_value = {
'processed_path': 'images/processed/test.jpg',
'thumbnail_path': 'images/thumbnails/test.jpg'
}
# Process message
result = image_processor.handle_processing_message(message_data)
# Verify message processing
assert 'processed_path' in result
mock_process.assert_called_once()
def test_error_handling_and_retry(self, image_processor, sample_image_data):
"""Test error handling and retry mechanisms"""
# Mock transient error followed by success
with patch.object(image_processor, 'extract_metadata') as mock_extract:
# First call fails, second succeeds
mock_extract.side_effect = [
Exception("Transient error"),
{'width': 800, 'height': 600, 'format': 'JPEG'}
]
# Should retry and succeed
metadata = image_processor.extract_metadata_with_retry(
sample_image_data,
max_retries=2
)
assert metadata['width'] == 800
assert mock_extract.call_count == 2
class TestImageProcessorPerformance:
"""Performance tests for image processing"""
def test_processing_speed_benchmarks(self, image_processor):
"""Test processing speed for different image sizes"""
import time
sizes = [(100, 100), (500, 500), (1000, 1000)]
for width, height in sizes:
# Create test image
img = Image.new('RGB', (width, height), color='blue')
img_bytes = BytesIO()
img.save(img_bytes, format='JPEG')
img_bytes.seek(0)
# Measure processing time
start_time = time.time()
metadata = image_processor.extract_metadata(img_bytes)
processing_time = time.time() - start_time
# Verify reasonable processing time (adjust thresholds as needed)
assert processing_time < 5.0 # Should process within 5 seconds
assert metadata['width'] == width
assert metadata['height'] == height
def test_memory_usage_optimization(self, image_processor):
"""Test memory usage during image processing"""
# This would test memory usage patterns
# Implementation depends on memory profiling tools
pass
def test_concurrent_processing(self, image_processor):
"""Test concurrent image processing"""
import concurrent.futures
# Create multiple test images
images = []
for i in range(5):
img = Image.new('RGB', (200, 200), color=(i*50, 0, 0))
img_bytes = BytesIO()
img.save(img_bytes, format='JPEG')
img_bytes.seek(0)
images.append(img_bytes)
# Process concurrently
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
futures = [
executor.submit(image_processor.extract_metadata, img)
for img in images
]
results = [future.result() for future in futures]
# Verify all processed successfully
assert len(results) == 5
for result in results:
assert 'width' in result
assert 'height' in result

View File

@ -0,0 +1,391 @@
import pytest
import numpy as np
from unittest.mock import patch, MagicMock, AsyncMock
from bson import ObjectId
from src.services.vector_store import VectorStoreService
from src.models.image import ImageModel
class TestVectorStoreService:
"""Test vector store operations for semantic search"""
@pytest.fixture
def mock_pinecone_index(self):
"""Mock Pinecone index for testing"""
mock_index = MagicMock()
mock_index.upsert = MagicMock()
mock_index.query = MagicMock()
mock_index.delete = MagicMock()
mock_index.describe_index_stats = MagicMock()
return mock_index
@pytest.fixture
def vector_store_service(self, mock_pinecone_index):
"""Create vector store service with mocked dependencies"""
with patch('src.services.vector_store.pinecone') as mock_pinecone:
mock_pinecone.Index.return_value = mock_pinecone_index
service = VectorStoreService()
service.index = mock_pinecone_index
return service
@pytest.fixture
def sample_embedding(self):
"""Generate a sample embedding vector"""
return np.random.rand(512).tolist() # 512-dimensional vector
@pytest.fixture
def sample_image(self):
"""Create a sample image model"""
return ImageModel(
filename="test-image.jpg",
original_filename="test_image.jpg",
file_size=1024,
content_type="image/jpeg",
storage_path="images/test-image.jpg",
team_id=ObjectId(),
uploader_id=ObjectId(),
description="A test image",
tags=["test", "image"]
)
def test_store_embedding(self, vector_store_service, sample_embedding, sample_image):
"""Test storing an embedding in the vector database"""
# Store the embedding
embedding_id = vector_store_service.store_embedding(
image_id=str(sample_image.id),
embedding=sample_embedding,
metadata={
"filename": sample_image.filename,
"team_id": str(sample_image.team_id),
"tags": sample_image.tags,
"description": sample_image.description
}
)
# Verify the embedding was stored
assert embedding_id is not None
vector_store_service.index.upsert.assert_called_once()
# Check the upsert call arguments
call_args = vector_store_service.index.upsert.call_args
vectors = call_args[1]['vectors']
assert len(vectors) == 1
assert vectors[0]['id'] == embedding_id
assert len(vectors[0]['values']) == len(sample_embedding)
assert 'metadata' in vectors[0]
def test_search_similar_images(self, vector_store_service, sample_embedding):
"""Test searching for similar images using vector similarity"""
# Mock search results
mock_results = {
'matches': [
{
'id': 'embedding1',
'score': 0.95,
'metadata': {
'image_id': str(ObjectId()),
'filename': 'similar1.jpg',
'team_id': str(ObjectId()),
'tags': ['cat', 'animal']
}
},
{
'id': 'embedding2',
'score': 0.87,
'metadata': {
'image_id': str(ObjectId()),
'filename': 'similar2.jpg',
'team_id': str(ObjectId()),
'tags': ['dog', 'animal']
}
}
]
}
vector_store_service.index.query.return_value = mock_results
# Perform search
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(ObjectId()),
top_k=10,
score_threshold=0.8
)
# Verify search was performed
vector_store_service.index.query.assert_called_once()
# Check results
assert len(results) == 2
assert results[0]['score'] == 0.95
assert results[1]['score'] == 0.87
assert all('image_id' in result for result in results)
def test_search_with_filters(self, vector_store_service, sample_embedding):
"""Test searching with metadata filters"""
team_id = str(ObjectId())
# Perform search with team filter
vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=team_id,
top_k=5,
filters={"tags": {"$in": ["cat", "dog"]}}
)
# Verify filter was applied
call_args = vector_store_service.index.query.call_args
assert 'filter' in call_args[1]
assert call_args[1]['filter']['team_id'] == team_id
def test_delete_embedding(self, vector_store_service):
"""Test deleting an embedding from the vector database"""
embedding_id = "test-embedding-123"
# Delete the embedding
success = vector_store_service.delete_embedding(embedding_id)
# Verify deletion was attempted
vector_store_service.index.delete.assert_called_once_with(ids=[embedding_id])
assert success is True
def test_batch_store_embeddings(self, vector_store_service, sample_embedding):
"""Test storing multiple embeddings in batch"""
# Create batch data
batch_data = []
for i in range(5):
batch_data.append({
'image_id': str(ObjectId()),
'embedding': sample_embedding,
'metadata': {
'filename': f'image{i}.jpg',
'team_id': str(ObjectId()),
'tags': [f'tag{i}']
}
})
# Store batch
embedding_ids = vector_store_service.batch_store_embeddings(batch_data)
# Verify batch storage
assert len(embedding_ids) == 5
vector_store_service.index.upsert.assert_called_once()
# Check batch upsert call
call_args = vector_store_service.index.upsert.call_args
vectors = call_args[1]['vectors']
assert len(vectors) == 5
def test_get_index_stats(self, vector_store_service):
"""Test getting vector database statistics"""
# Mock stats response
mock_stats = {
'total_vector_count': 1000,
'dimension': 512,
'index_fullness': 0.1
}
vector_store_service.index.describe_index_stats.return_value = mock_stats
# Get stats
stats = vector_store_service.get_index_stats()
# Verify stats retrieval
vector_store_service.index.describe_index_stats.assert_called_once()
assert stats['total_vector_count'] == 1000
assert stats['dimension'] == 512
def test_search_with_score_threshold(self, vector_store_service, sample_embedding):
"""Test filtering search results by score threshold"""
# Mock results with varying scores
mock_results = {
'matches': [
{'id': 'emb1', 'score': 0.95, 'metadata': {'image_id': '1'}},
{'id': 'emb2', 'score': 0.75, 'metadata': {'image_id': '2'}},
{'id': 'emb3', 'score': 0.65, 'metadata': {'image_id': '3'}},
{'id': 'emb4', 'score': 0.45, 'metadata': {'image_id': '4'}}
]
}
vector_store_service.index.query.return_value = mock_results
# Search with score threshold
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(ObjectId()),
top_k=10,
score_threshold=0.7
)
# Only results above threshold should be returned
assert len(results) == 2
assert all(result['score'] >= 0.7 for result in results)
def test_update_embedding_metadata(self, vector_store_service):
"""Test updating metadata for an existing embedding"""
embedding_id = "test-embedding-123"
new_metadata = {
'tags': ['updated', 'tag'],
'description': 'Updated description'
}
# Update metadata
success = vector_store_service.update_embedding_metadata(
embedding_id, new_metadata
)
# Verify update was attempted
# This would depend on the actual implementation
assert success is True
def test_search_by_image_id(self, vector_store_service):
"""Test searching for a specific image's embedding"""
image_id = str(ObjectId())
# Mock search by metadata
mock_results = {
'matches': [
{
'id': 'embedding1',
'score': 1.0,
'metadata': {
'image_id': image_id,
'filename': 'target.jpg'
}
}
]
}
vector_store_service.index.query.return_value = mock_results
# Search by image ID
result = vector_store_service.get_embedding_by_image_id(image_id)
# Verify result
assert result is not None
assert result['metadata']['image_id'] == image_id
def test_bulk_delete_embeddings(self, vector_store_service):
"""Test deleting multiple embeddings"""
embedding_ids = ['emb1', 'emb2', 'emb3']
# Delete multiple embeddings
success = vector_store_service.bulk_delete_embeddings(embedding_ids)
# Verify bulk deletion
vector_store_service.index.delete.assert_called_once_with(ids=embedding_ids)
assert success is True
def test_search_pagination(self, vector_store_service, sample_embedding):
"""Test paginated search results"""
# This would test pagination if implemented
# Implementation depends on how pagination is handled in the vector store
pass
def test_vector_dimension_validation(self, vector_store_service):
"""Test validation of embedding dimensions"""
# Test with wrong dimension
wrong_dimension_embedding = np.random.rand(256).tolist() # Wrong size
with pytest.raises(ValueError):
vector_store_service.store_embedding(
image_id=str(ObjectId()),
embedding=wrong_dimension_embedding,
metadata={}
)
def test_connection_error_handling(self, vector_store_service):
"""Test handling of connection errors"""
# Mock connection error
vector_store_service.index.query.side_effect = Exception("Connection failed")
# Search should handle the error gracefully
with pytest.raises(Exception):
vector_store_service.search_similar(
query_embedding=[0.1] * 512,
team_id=str(ObjectId()),
top_k=10
)
def test_empty_search_results(self, vector_store_service, sample_embedding):
"""Test handling of empty search results"""
# Mock empty results
vector_store_service.index.query.return_value = {'matches': []}
# Search should return empty list
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(ObjectId()),
top_k=10
)
assert results == []
class TestVectorStoreIntegration:
"""Integration tests for vector store with other services"""
def test_embedding_lifecycle(self, vector_store_service, sample_embedding, sample_image):
"""Test complete embedding lifecycle: store, search, update, delete"""
# Store embedding
embedding_id = vector_store_service.store_embedding(
image_id=str(sample_image.id),
embedding=sample_embedding,
metadata={'filename': sample_image.filename}
)
# Search for similar embeddings
mock_results = {
'matches': [
{
'id': embedding_id,
'score': 1.0,
'metadata': {'image_id': str(sample_image.id)}
}
]
}
vector_store_service.index.query.return_value = mock_results
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=str(sample_image.team_id),
top_k=1
)
assert len(results) == 1
assert results[0]['id'] == embedding_id
# Delete embedding
success = vector_store_service.delete_embedding(embedding_id)
assert success is True
def test_team_isolation(self, vector_store_service, sample_embedding):
"""Test that team data is properly isolated"""
team1_id = str(ObjectId())
team2_id = str(ObjectId())
# Mock search results that should be filtered by team
mock_results = {
'matches': [
{
'id': 'emb1',
'score': 0.9,
'metadata': {'image_id': '1', 'team_id': team1_id}
},
{
'id': 'emb2',
'score': 0.8,
'metadata': {'image_id': '2', 'team_id': team2_id}
}
]
}
vector_store_service.index.query.return_value = mock_results
# Search for team1 should only return team1 results
results = vector_store_service.search_similar(
query_embedding=sample_embedding,
team_id=team1_id,
top_k=10
)
# Verify team filter was applied in the query
call_args = vector_store_service.index.query.call_args
assert 'filter' in call_args[1]
assert call_args[1]['filter']['team_id'] == team1_id