diff --git a/tests/api/test_collections.py b/tests/api/test_collections.py new file mode 100644 index 0000000..740640f --- /dev/null +++ b/tests/api/test_collections.py @@ -0,0 +1,620 @@ +import pytest +from fastapi.testclient import TestClient +from bson import ObjectId +from datetime import datetime + +from src.models.team import TeamModel +from src.models.user import UserModel +from src.models.image import ImageModel + + +class CollectionModel: + """Mock collection model for testing""" + def __init__(self, **kwargs): + self.id = kwargs.get('id', ObjectId()) + self.name = kwargs.get('name') + self.description = kwargs.get('description') + self.team_id = kwargs.get('team_id') + self.created_by = kwargs.get('created_by') + self.created_at = kwargs.get('created_at', datetime.utcnow()) + self.updated_at = kwargs.get('updated_at', datetime.utcnow()) + self.metadata = kwargs.get('metadata', {}) + self.image_count = kwargs.get('image_count', 0) + + +@pytest.mark.asyncio +async def test_create_collection(client: TestClient, admin_api_key: tuple, test_team: TeamModel): + """Test creating a new image collection""" + raw_key, _ = admin_api_key + + # Set up the headers with the admin API key + headers = {"X-API-Key": raw_key} + + # Create a new collection + response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Test Collection", + "description": "A collection for testing images", + "metadata": { + "category": "test", + "project": "sereact" + } + } + ) + + # Check response + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert data["name"] == "Test Collection" + assert data["description"] == "A collection for testing images" + assert data["team_id"] == str(test_team.id) + assert "created_at" in data + assert data["image_count"] == 0 + + +@pytest.mark.asyncio +async def test_create_collection_non_admin(client: TestClient, user_api_key: tuple): + """Test that regular users can create collections""" + raw_key, _ = user_api_key + + # Set up the headers with a regular user API key + headers = {"X-API-Key": raw_key} + + # Create a new collection + response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "User Collection", + "description": "A collection created by a regular user" + } + ) + + # Check response - regular users should be able to create collections + assert response.status_code == 201 + data = response.json() + assert data["name"] == "User Collection" + + +@pytest.mark.asyncio +async def test_list_collections(client: TestClient, admin_api_key: tuple, test_team: TeamModel): + """Test listing collections for a team""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # List collections + response = client.get( + "/api/v1/collections", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "collections" in data + assert "total" in data + assert isinstance(data["collections"], list) + + +@pytest.mark.asyncio +async def test_get_collection_by_id(client: TestClient, admin_api_key: tuple): + """Test getting a specific collection by ID""" + raw_key, _ = admin_api_key + + # First create a collection + headers = {"X-API-Key": raw_key} + + create_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Get Test Collection", + "description": "Collection for get testing" + } + ) + + assert create_response.status_code == 201 + collection_id = create_response.json()["id"] + + # 
Get the collection + response = client.get( + f"/api/v1/collections/{collection_id}", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["id"] == collection_id + assert data["name"] == "Get Test Collection" + assert data["description"] == "Collection for get testing" + + +@pytest.mark.asyncio +async def test_update_collection(client: TestClient, admin_api_key: tuple): + """Test updating a collection""" + raw_key, _ = admin_api_key + + # First create a collection + headers = {"X-API-Key": raw_key} + + create_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Update Test Collection", + "description": "Collection for update testing" + } + ) + + assert create_response.status_code == 201 + collection_id = create_response.json()["id"] + + # Update the collection + response = client.put( + f"/api/v1/collections/{collection_id}", + headers=headers, + json={ + "name": "Updated Collection Name", + "description": "This collection has been updated", + "metadata": { + "updated": True, + "version": 2 + } + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["id"] == collection_id + assert data["name"] == "Updated Collection Name" + assert data["description"] == "This collection has been updated" + assert data["metadata"]["updated"] is True + assert "updated_at" in data + + +@pytest.mark.asyncio +async def test_delete_collection(client: TestClient, admin_api_key: tuple): + """Test deleting a collection""" + raw_key, _ = admin_api_key + + # First create a collection + headers = {"X-API-Key": raw_key} + + create_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Delete Test Collection", + "description": "Collection for delete testing" + } + ) + + assert create_response.status_code == 201 + collection_id = create_response.json()["id"] + + # Delete the collection + response = client.delete( + f"/api/v1/collections/{collection_id}", + headers=headers + ) + + # Check response + assert response.status_code == 204 + + # Verify collection is deleted + get_response = client.get( + f"/api/v1/collections/{collection_id}", + headers=headers + ) + assert get_response.status_code == 404 + + +@pytest.mark.asyncio +async def test_add_image_to_collection(client: TestClient, admin_api_key: tuple): + """Test adding an image to a collection""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Image Collection", + "description": "Collection for image testing" + } + ) + collection_id = collection_response.json()["id"] + + # Mock image ID (in real implementation, this would be from uploaded image) + image_id = str(ObjectId()) + + # Add image to collection + response = client.post( + f"/api/v1/collections/{collection_id}/images", + headers=headers, + json={ + "image_id": image_id + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "success" in data + assert data["success"] is True + + +@pytest.mark.asyncio +async def test_remove_image_from_collection(client: TestClient, admin_api_key: tuple): + """Test removing an image from a collection""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Remove Image Collection", 
+ "description": "Collection for image removal testing" + } + ) + collection_id = collection_response.json()["id"] + + # Mock image ID + image_id = str(ObjectId()) + + # Add image to collection first + add_response = client.post( + f"/api/v1/collections/{collection_id}/images", + headers=headers, + json={"image_id": image_id} + ) + assert add_response.status_code == 200 + + # Remove image from collection + response = client.delete( + f"/api/v1/collections/{collection_id}/images/{image_id}", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + +@pytest.mark.asyncio +async def test_list_images_in_collection(client: TestClient, admin_api_key: tuple): + """Test listing images in a collection""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "List Images Collection", + "description": "Collection for listing images" + } + ) + collection_id = collection_response.json()["id"] + + # List images in collection + response = client.get( + f"/api/v1/collections/{collection_id}/images", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "images" in data + assert "total" in data + assert isinstance(data["images"], list) + + +@pytest.mark.asyncio +async def test_collection_access_control(client: TestClient, user_api_key: tuple): + """Test that users can only access collections from their team""" + raw_key, _ = user_api_key + headers = {"X-API-Key": raw_key} + + # Try to access a collection from another team + other_collection_id = str(ObjectId()) + + response = client.get( + f"/api/v1/collections/{other_collection_id}", + headers=headers + ) + + # Should return 403 or 404 depending on implementation + assert response.status_code in [403, 404] + + +@pytest.mark.asyncio +async def test_collection_pagination(client: TestClient, admin_api_key: tuple): + """Test collection listing with pagination""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Test pagination + response = client.get( + "/api/v1/collections?page=1&limit=10", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "collections" in data + assert "pagination" in data + assert "total" in data["pagination"] + assert "page" in data["pagination"] + assert "pages" in data["pagination"] + assert len(data["collections"]) <= 10 + + +@pytest.mark.asyncio +async def test_collection_search(client: TestClient, admin_api_key: tuple): + """Test searching collections by name or description""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection with searchable content + create_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Searchable Collection", + "description": "This collection contains vacation photos" + } + ) + assert create_response.status_code == 201 + + # Search for collections + response = client.get( + "/api/v1/collections/search?query=vacation", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "collections" in data + assert "total" in data + + # Results should contain collections matching the query + for collection in data["collections"]: + assert ("vacation" in collection["name"].lower() or + "vacation" in collection["description"].lower()) 
+ + +@pytest.mark.asyncio +async def test_collection_statistics(client: TestClient, admin_api_key: tuple): + """Test getting collection statistics""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Stats Collection", + "description": "Collection for statistics testing" + } + ) + collection_id = collection_response.json()["id"] + + # Get collection statistics + response = client.get( + f"/api/v1/collections/{collection_id}/stats", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "image_count" in data + assert "total_size" in data + assert "created_at" in data + assert "last_updated" in data + + +@pytest.mark.asyncio +async def test_bulk_add_images_to_collection(client: TestClient, admin_api_key: tuple): + """Test adding multiple images to a collection at once""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Bulk Add Collection", + "description": "Collection for bulk operations" + } + ) + collection_id = collection_response.json()["id"] + + # Mock multiple image IDs + image_ids = [str(ObjectId()) for _ in range(5)] + + # Bulk add images to collection + response = client.post( + f"/api/v1/collections/{collection_id}/images/bulk", + headers=headers, + json={ + "image_ids": image_ids + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "added_count" in data + assert data["added_count"] == 5 + + +@pytest.mark.asyncio +async def test_collection_export(client: TestClient, admin_api_key: tuple): + """Test exporting collection metadata""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Export Collection", + "description": "Collection for export testing", + "metadata": {"category": "test", "project": "sereact"} + } + ) + collection_id = collection_response.json()["id"] + + # Export collection + response = client.get( + f"/api/v1/collections/{collection_id}/export", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "collection" in data + assert "images" in data + assert data["collection"]["name"] == "Export Collection" + + +@pytest.mark.asyncio +async def test_collection_duplicate_name_validation(client: TestClient, admin_api_key: tuple): + """Test that duplicate collection names within a team are handled""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create first collection + response1 = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Duplicate Name Collection", + "description": "First collection with this name" + } + ) + assert response1.status_code == 201 + + # Try to create second collection with same name + response2 = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Duplicate Name Collection", + "description": "Second collection with same name" + } + ) + + # Should either allow it or reject it depending on business rules + # For now, let's assume it's allowed but with different handling + assert response2.status_code in [201, 400] + + +@pytest.mark.asyncio +async def test_collection_metadata_validation(client: 
TestClient, admin_api_key: tuple): + """Test validation of collection metadata""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Test with invalid metadata structure + response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Invalid Metadata Collection", + "description": "Collection with invalid metadata", + "metadata": "invalid_metadata_type" # Should be dict + } + ) + + # Should validate metadata structure + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_collection_ownership_transfer(client: TestClient, admin_api_key: tuple): + """Test transferring collection ownership (admin only)""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Create a collection + collection_response = client.post( + "/api/v1/collections", + headers=headers, + json={ + "name": "Transfer Collection", + "description": "Collection for ownership transfer" + } + ) + collection_id = collection_response.json()["id"] + + # Transfer ownership to another user + new_owner_id = str(ObjectId()) + response = client.put( + f"/api/v1/collections/{collection_id}/owner", + headers=headers, + json={ + "new_owner_id": new_owner_id + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["created_by"] == new_owner_id + + +@pytest.mark.asyncio +async def test_invalid_collection_id(client: TestClient, admin_api_key: tuple): + """Test handling of invalid collection IDs""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Try to get collection with invalid ID + response = client.get( + "/api/v1/collections/invalid-id", + headers=headers + ) + + # Check that the request returns 400 Bad Request + assert response.status_code == 400 + assert "detail" in response.json() + + +@pytest.mark.asyncio +async def test_nonexistent_collection(client: TestClient, admin_api_key: tuple): + """Test handling of nonexistent collection IDs""" + raw_key, _ = admin_api_key + headers = {"X-API-Key": raw_key} + + # Try to get nonexistent collection + nonexistent_id = str(ObjectId()) + response = client.get( + f"/api/v1/collections/{nonexistent_id}", + headers=headers + ) + + # Check that the request returns 404 Not Found + assert response.status_code == 404 + assert "detail" in response.json() \ No newline at end of file diff --git a/tests/api/test_users.py b/tests/api/test_users.py new file mode 100644 index 0000000..8d6010f --- /dev/null +++ b/tests/api/test_users.py @@ -0,0 +1,549 @@ +import pytest +from fastapi.testclient import TestClient +from bson import ObjectId + +from src.models.user import UserModel +from src.models.team import TeamModel +from src.db.repositories.user_repository import user_repository +from src.db.repositories.team_repository import team_repository + + +@pytest.mark.asyncio +async def test_create_user(client: TestClient, admin_api_key: tuple, test_team: TeamModel): + """Test creating a new user (admin only)""" + raw_key, _ = admin_api_key + + # Set up the headers with the admin API key + headers = {"X-API-Key": raw_key} + + # Create a new user + response = client.post( + "/api/v1/users", + headers=headers, + json={ + "email": "newuser@example.com", + "name": "New User", + "team_id": str(test_team.id), + "is_admin": False + } + ) + + # Check response + assert response.status_code == 201 + data = response.json() + assert "id" in data + assert data["email"] == "newuser@example.com" + assert data["name"] == "New User" + assert data["team_id"] == str(test_team.id) + 
assert data["is_admin"] is False + assert "created_at" in data + + +@pytest.mark.asyncio +async def test_create_user_non_admin(client: TestClient, user_api_key: tuple, test_team: TeamModel): + """Test that non-admin users cannot create users""" + raw_key, _ = user_api_key + + # Set up the headers with a non-admin API key + headers = {"X-API-Key": raw_key} + + # Try to create a new user + response = client.post( + "/api/v1/users", + headers=headers, + json={ + "email": "unauthorized@example.com", + "name": "Unauthorized User", + "team_id": str(test_team.id), + "is_admin": False + } + ) + + # Check that the request is forbidden + assert response.status_code == 403 + assert "detail" in response.json() + assert "admin" in response.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_list_users_in_team(client: TestClient, admin_api_key: tuple, test_team: TeamModel): + """Test listing users in a team""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # List users in the team + response = client.get( + f"/api/v1/users?team_id={test_team.id}", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "users" in data + assert "total" in data + assert data["total"] >= 1 # Should include at least the admin user + + +@pytest.mark.asyncio +async def test_list_all_users_admin_only(client: TestClient, admin_api_key: tuple): + """Test that only admins can list all users across teams""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # List all users (no team filter) + response = client.get( + "/api/v1/users", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "users" in data + assert "total" in data + + +@pytest.mark.asyncio +async def test_list_users_non_admin_restricted(client: TestClient, user_api_key: tuple): + """Test that non-admin users can only see their own team""" + raw_key, _ = user_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to list all users + response = client.get( + "/api/v1/users", + headers=headers + ) + + # Should only return users from their own team + assert response.status_code == 200 + data = response.json() + assert "users" in data + # All returned users should be from the same team + if data["users"]: + user_team_ids = set(user["team_id"] for user in data["users"]) + assert len(user_team_ids) == 1 # Only one team represented + + +@pytest.mark.asyncio +async def test_get_user_by_id(client: TestClient, admin_api_key: tuple, admin_user: UserModel): + """Test getting a specific user by ID""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Get the user + response = client.get( + f"/api/v1/users/{admin_user.id}", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(admin_user.id) + assert data["email"] == admin_user.email + assert data["name"] == admin_user.name + assert data["team_id"] == str(admin_user.team_id) + assert data["is_admin"] == admin_user.is_admin + + +@pytest.mark.asyncio +async def test_get_user_self(client: TestClient, user_api_key: tuple, regular_user: UserModel): + """Test that users can get their own information""" + raw_key, _ = user_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Get own user information + response = client.get( + 
f"/api/v1/users/{regular_user.id}", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(regular_user.id) + assert data["email"] == regular_user.email + + +@pytest.mark.asyncio +async def test_get_user_other_team_forbidden(client: TestClient, user_api_key: tuple): + """Test that users cannot access users from other teams""" + raw_key, _ = user_api_key + + # Create a user in another team + other_team = TeamModel( + name="Other Team", + description="Another team for testing" + ) + created_team = await team_repository.create(other_team) + + other_user = UserModel( + email="other@example.com", + name="Other User", + team_id=created_team.id, + is_admin=False + ) + created_user = await user_repository.create(other_user) + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to get the other user + response = client.get( + f"/api/v1/users/{created_user.id}", + headers=headers + ) + + # Check that the request is forbidden + assert response.status_code == 403 + assert "detail" in response.json() + + +@pytest.mark.asyncio +async def test_update_user(client: TestClient, admin_api_key: tuple, regular_user: UserModel): + """Test updating a user (admin only)""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Update the user + response = client.put( + f"/api/v1/users/{regular_user.id}", + headers=headers, + json={ + "name": "Updated User Name", + "email": "updated@example.com", + "is_admin": True + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["id"] == str(regular_user.id) + assert data["name"] == "Updated User Name" + assert data["email"] == "updated@example.com" + assert data["is_admin"] is True + assert "updated_at" in data + + +@pytest.mark.asyncio +async def test_update_user_self(client: TestClient, user_api_key: tuple, regular_user: UserModel): + """Test that users can update their own basic information""" + raw_key, _ = user_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Update own information (limited fields) + response = client.put( + f"/api/v1/users/{regular_user.id}", + headers=headers, + json={ + "name": "Self Updated Name", + "email": "self_updated@example.com" + # Note: is_admin should not be updatable by self + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["name"] == "Self Updated Name" + assert data["email"] == "self_updated@example.com" + # Admin status should remain unchanged + assert data["is_admin"] == regular_user.is_admin + + +@pytest.mark.asyncio +async def test_update_user_non_admin_forbidden(client: TestClient, user_api_key: tuple, admin_user: UserModel): + """Test that non-admin users cannot update other users""" + raw_key, _ = user_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to update another user + response = client.put( + f"/api/v1/users/{admin_user.id}", + headers=headers, + json={ + "name": "Unauthorized Update", + "is_admin": False + } + ) + + # Check that the request is forbidden + assert response.status_code == 403 + assert "detail" in response.json() + + +@pytest.mark.asyncio +async def test_delete_user(client: TestClient, admin_api_key: tuple): + """Test deleting a user (admin only)""" + raw_key, _ = admin_api_key + + # Create a user to delete + user_to_delete = UserModel( + email="delete@example.com", + name="User to Delete", + team_id=ObjectId(), + 
is_admin=False + ) + created_user = await user_repository.create(user_to_delete) + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Delete the user + response = client.delete( + f"/api/v1/users/{created_user.id}", + headers=headers + ) + + # Check response + assert response.status_code == 204 + + # Verify the user has been deleted + deleted_user = await user_repository.get_by_id(created_user.id) + assert deleted_user is None + + +@pytest.mark.asyncio +async def test_delete_user_non_admin(client: TestClient, user_api_key: tuple, admin_user: UserModel): + """Test that non-admin users cannot delete users""" + raw_key, _ = user_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to delete a user + response = client.delete( + f"/api/v1/users/{admin_user.id}", + headers=headers + ) + + # Check that the request is forbidden + assert response.status_code == 403 + assert "detail" in response.json() + assert "admin" in response.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_get_user_activity(client: TestClient, admin_api_key: tuple, regular_user: UserModel): + """Test getting user activity/statistics""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Get user activity + response = client.get( + f"/api/v1/users/{regular_user.id}/activity", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "user_id" in data + assert "images_uploaded" in data + assert "last_login" in data + assert "api_key_usage" in data + + +@pytest.mark.asyncio +async def test_change_user_team(client: TestClient, admin_api_key: tuple, regular_user: UserModel): + """Test moving a user to a different team (admin only)""" + raw_key, _ = admin_api_key + + # Create another team + new_team = TeamModel( + name="New Team", + description="A new team for testing" + ) + created_team = await team_repository.create(new_team) + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Move user to new team + response = client.put( + f"/api/v1/users/{regular_user.id}/team", + headers=headers, + json={ + "team_id": str(created_team.id) + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["team_id"] == str(created_team.id) + + +@pytest.mark.asyncio +async def test_user_search(client: TestClient, admin_api_key: tuple): + """Test searching for users by email or name""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Search for users + response = client.get( + "/api/v1/users/search?query=admin", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "users" in data + assert "total" in data + + # Results should contain users matching the query + for user in data["users"]: + assert "admin" in user["email"].lower() or "admin" in user["name"].lower() + + +@pytest.mark.asyncio +async def test_user_pagination(client: TestClient, admin_api_key: tuple): + """Test user listing with pagination""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Test pagination + response = client.get( + "/api/v1/users?page=1&limit=10", + headers=headers + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert "users" in data + assert "pagination" in data + assert "total" in data["pagination"] + assert "page" in data["pagination"] + assert "pages" in 
data["pagination"] + assert len(data["users"]) <= 10 + + +@pytest.mark.asyncio +async def test_invalid_user_id(client: TestClient, admin_api_key: tuple): + """Test handling of invalid user IDs""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to get user with invalid ID + response = client.get( + "/api/v1/users/invalid-id", + headers=headers + ) + + # Check that the request returns 400 Bad Request + assert response.status_code == 400 + assert "detail" in response.json() + assert "invalid" in response.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_nonexistent_user(client: TestClient, admin_api_key: tuple): + """Test handling of nonexistent user IDs""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to get nonexistent user + nonexistent_id = str(ObjectId()) + response = client.get( + f"/api/v1/users/{nonexistent_id}", + headers=headers + ) + + # Check that the request returns 404 Not Found + assert response.status_code == 404 + assert "detail" in response.json() + assert "not found" in response.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_duplicate_email_validation(client: TestClient, admin_api_key: tuple, regular_user: UserModel): + """Test that duplicate emails are not allowed""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Try to create user with existing email + response = client.post( + "/api/v1/users", + headers=headers, + json={ + "email": regular_user.email, # Duplicate email + "name": "Duplicate Email User", + "team_id": str(regular_user.team_id), + "is_admin": False + } + ) + + # Check that the request is rejected + assert response.status_code == 400 + assert "detail" in response.json() + assert "email" in response.json()["detail"].lower() + + +@pytest.mark.asyncio +async def test_user_role_management(client: TestClient, admin_api_key: tuple, regular_user: UserModel): + """Test managing user roles and permissions""" + raw_key, _ = admin_api_key + + # Set up the headers + headers = {"X-API-Key": raw_key} + + # Promote user to admin + response = client.put( + f"/api/v1/users/{regular_user.id}/role", + headers=headers, + json={ + "is_admin": True + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["is_admin"] is True + + # Demote user back to regular + response = client.put( + f"/api/v1/users/{regular_user.id}/role", + headers=headers, + json={ + "is_admin": False + } + ) + + # Check response + assert response.status_code == 200 + data = response.json() + assert data["is_admin"] is False \ No newline at end of file diff --git a/tests/auth/test_security.py b/tests/auth/test_security.py new file mode 100644 index 0000000..4bf1925 --- /dev/null +++ b/tests/auth/test_security.py @@ -0,0 +1,268 @@ +import pytest +import hashlib +from datetime import datetime, timedelta +from bson import ObjectId + +from src.auth.security import ( + generate_api_key, + hash_api_key, + verify_api_key, + create_access_token, + verify_token +) +from src.models.api_key import ApiKeyModel +from src.models.user import UserModel +from src.models.team import TeamModel + + +class TestApiKeySecurity: + """Test API key generation and validation security""" + + def test_generate_api_key(self): + """Test API key generation produces unique, secure keys""" + team_id = str(ObjectId()) + user_id = str(ObjectId()) + + # Generate multiple keys + key1, hash1 = 
generate_api_key(team_id, user_id) + key2, hash2 = generate_api_key(team_id, user_id) + + # Keys should be different + assert key1 != key2 + assert hash1 != hash2 + + # Keys should be sufficiently long + assert len(key1) >= 32 + assert len(hash1) >= 32 + + # Keys should contain team and user info + assert team_id in key1 or user_id in key1 + + def test_hash_api_key_consistency(self): + """Test that hashing the same key produces the same hash""" + key = "test-api-key-123" + + hash1 = hash_api_key(key) + hash2 = hash_api_key(key) + + assert hash1 == hash2 + assert len(hash1) >= 32 # Should be a proper hash length + + def test_verify_api_key_valid(self): + """Test verifying a valid API key""" + team_id = str(ObjectId()) + user_id = str(ObjectId()) + + raw_key, key_hash = generate_api_key(team_id, user_id) + + # Verification should succeed + assert verify_api_key(raw_key, key_hash) is True + + def test_verify_api_key_invalid(self): + """Test verifying an invalid API key""" + team_id = str(ObjectId()) + user_id = str(ObjectId()) + + raw_key, key_hash = generate_api_key(team_id, user_id) + + # Wrong key should fail + assert verify_api_key("wrong-key", key_hash) is False + + # Wrong hash should fail + assert verify_api_key(raw_key, "wrong-hash") is False + + def test_api_key_format(self): + """Test that generated API keys follow expected format""" + team_id = str(ObjectId()) + user_id = str(ObjectId()) + + raw_key, key_hash = generate_api_key(team_id, user_id) + + # Key should have expected structure (prefix.hash format) + assert "." in raw_key + parts = raw_key.split(".") + assert len(parts) == 2 + + # First part should be readable prefix + prefix = parts[0] + assert len(prefix) >= 8 + + # Second part should be hash-like + hash_part = parts[1] + assert len(hash_part) >= 32 + + +class TestTokenSecurity: + """Test JWT token generation and validation""" + + def test_create_access_token(self): + """Test creating access tokens""" + user_id = str(ObjectId()) + team_id = str(ObjectId()) + + token = create_access_token( + data={"user_id": user_id, "team_id": team_id} + ) + + assert token is not None + assert isinstance(token, str) + assert len(token) > 50 # JWT tokens are typically long + + def test_verify_token_valid(self): + """Test verifying a valid token""" + user_id = str(ObjectId()) + team_id = str(ObjectId()) + + token = create_access_token( + data={"user_id": user_id, "team_id": team_id} + ) + + payload = verify_token(token) + assert payload is not None + assert payload["user_id"] == user_id + assert payload["team_id"] == team_id + + def test_verify_token_invalid(self): + """Test verifying an invalid token""" + # Invalid token should return None + assert verify_token("invalid-token") is None + assert verify_token("") is None + assert verify_token(None) is None + + def test_token_expiration(self): + """Test token expiration handling""" + user_id = str(ObjectId()) + + # Create token with very short expiration + token = create_access_token( + data={"user_id": user_id}, + expires_delta=timedelta(seconds=-1) # Already expired + ) + + # Should fail verification due to expiration + payload = verify_token(token) + assert payload is None + + +class TestSecurityValidation: + """Test security validation functions""" + + def test_validate_team_access(self): + """Test team access validation""" + team_id = ObjectId() + user_team_id = ObjectId() + + # User should have access to their own team + from src.auth.security import validate_team_access + assert validate_team_access(str(team_id), str(team_id)) is True 
+ + # User should not have access to other teams + assert validate_team_access(str(user_team_id), str(team_id)) is False + + def test_validate_admin_permissions(self): + """Test admin permission validation""" + from src.auth.security import validate_admin_permissions + + admin_user = UserModel( + email="admin@test.com", + name="Admin User", + team_id=ObjectId(), + is_admin=True + ) + + regular_user = UserModel( + email="user@test.com", + name="Regular User", + team_id=ObjectId(), + is_admin=False + ) + + assert validate_admin_permissions(admin_user) is True + assert validate_admin_permissions(regular_user) is False + + def test_rate_limiting_validation(self): + """Test rate limiting for API keys""" + # This would test rate limiting functionality + # Implementation depends on the actual rate limiting strategy + pass + + def test_api_key_expiration_check(self): + """Test API key expiration validation""" + team_id = ObjectId() + user_id = ObjectId() + + # Create expired API key + expired_key = ApiKeyModel( + key_hash="test-hash", + user_id=user_id, + team_id=team_id, + name="Expired Key", + expiry_date=datetime.utcnow() - timedelta(days=1), + is_active=True + ) + + # Create valid API key + valid_key = ApiKeyModel( + key_hash="test-hash-2", + user_id=user_id, + team_id=team_id, + name="Valid Key", + expiry_date=datetime.utcnow() + timedelta(days=30), + is_active=True + ) + + from src.auth.security import is_api_key_valid + + assert is_api_key_valid(expired_key) is False + assert is_api_key_valid(valid_key) is True + + def test_inactive_api_key_check(self): + """Test inactive API key validation""" + team_id = ObjectId() + user_id = ObjectId() + + # Create inactive API key + inactive_key = ApiKeyModel( + key_hash="test-hash", + user_id=user_id, + team_id=team_id, + name="Inactive Key", + expiry_date=datetime.utcnow() + timedelta(days=30), + is_active=False + ) + + from src.auth.security import is_api_key_valid + assert is_api_key_valid(inactive_key) is False + + +class TestSecurityHeaders: + """Test security headers and middleware""" + + def test_cors_headers(self): + """Test CORS header configuration""" + # This would test CORS configuration + pass + + def test_security_headers(self): + """Test security headers like X-Frame-Options, etc.""" + # This would test security headers + pass + + def test_https_enforcement(self): + """Test HTTPS enforcement in production""" + # This would test HTTPS redirect functionality + pass + + +class TestPasswordSecurity: + """Test password hashing and validation if implemented""" + + def test_password_hashing(self): + """Test password hashing functionality""" + # If password authentication is implemented + pass + + def test_password_validation(self): + """Test password strength validation""" + # If password authentication is implemented + pass \ No newline at end of file diff --git a/tests/services/test_embedding_service.py b/tests/services/test_embedding_service.py new file mode 100644 index 0000000..b102f74 --- /dev/null +++ b/tests/services/test_embedding_service.py @@ -0,0 +1,360 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock, AsyncMock +from bson import ObjectId +from io import BytesIO + +from src.services.embedding_service import EmbeddingService +from src.models.image import ImageModel + + +class TestEmbeddingService: + """Test embedding generation for images""" + + @pytest.fixture + def mock_vision_client(self): + """Mock Google Cloud Vision client""" + mock_client = MagicMock() + mock_response = MagicMock() + 
mock_response.image_properties_annotation.dominant_colors.colors = [] + mock_response.label_annotations = [] + mock_response.object_localizations = [] + mock_response.text_annotations = [] + mock_client.annotate_image.return_value = mock_response + return mock_client + + @pytest.fixture + def embedding_service(self, mock_vision_client): + """Create embedding service with mocked dependencies""" + with patch('src.services.embedding_service.vision') as mock_vision: + mock_vision.ImageAnnotatorClient.return_value = mock_vision_client + service = EmbeddingService() + service.client = mock_vision_client + return service + + @pytest.fixture + def sample_image_data(self): + """Create sample image data""" + # Create a simple test image (1x1 pixel PNG) + image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xdd\x8d\xb4\x1c\x00\x00\x00\x00IEND\xaeB`\x82' + return BytesIO(image_data) + + @pytest.fixture + def sample_image_model(self): + """Create a sample image model""" + return ImageModel( + filename="test-image.jpg", + original_filename="test_image.jpg", + file_size=1024, + content_type="image/jpeg", + storage_path="images/test-image.jpg", + team_id=ObjectId(), + uploader_id=ObjectId() + ) + + def test_generate_embedding_from_image(self, embedding_service, sample_image_data): + """Test generating embeddings from image data""" + # Mock the embedding generation + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = mock_embedding + + # Generate embedding + embedding = embedding_service.generate_embedding(sample_image_data) + + # Verify embedding was generated + assert embedding is not None + assert len(embedding) == 512 + assert isinstance(embedding, list) + assert all(isinstance(x, (int, float)) for x in embedding) + + def test_extract_image_features(self, embedding_service, sample_image_data): + """Test extracting features from images using Vision API""" + # Mock Vision API response + mock_response = MagicMock() + mock_response.label_annotations = [ + MagicMock(description="cat", score=0.95), + MagicMock(description="animal", score=0.87), + MagicMock(description="pet", score=0.82) + ] + mock_response.object_localizations = [ + MagicMock(name="Cat", score=0.9) + ] + mock_response.image_properties_annotation.dominant_colors.colors = [ + MagicMock(color=MagicMock(red=255, green=100, blue=50), score=0.8) + ] + + embedding_service.client.annotate_image.return_value = mock_response + + # Extract features + features = embedding_service.extract_image_features(sample_image_data) + + # Verify features were extracted + assert 'labels' in features + assert 'objects' in features + assert 'colors' in features + assert len(features['labels']) == 3 + assert features['labels'][0]['description'] == "cat" + assert features['labels'][0]['score'] == 0.95 + + def test_generate_embedding_with_metadata(self, embedding_service, sample_image_data, sample_image_model): + """Test generating embeddings with image metadata""" + mock_embedding = np.random.rand(512).tolist() + mock_features = { + 'labels': [{'description': 'cat', 'score': 0.95}], + 'objects': [{'name': 'Cat', 'score': 0.9}], + 'colors': [{'red': 255, 'green': 100, 'blue': 50, 'score': 0.8}] + } + + with patch.object(embedding_service, '_extract_features') as mock_extract_features, \ + 
patch.object(embedding_service, 'extract_image_features') as mock_extract_metadata: + + mock_extract_features.return_value = mock_embedding + mock_extract_metadata.return_value = mock_features + + # Generate embedding with metadata + result = embedding_service.generate_embedding_with_metadata( + sample_image_data, sample_image_model + ) + + # Verify result structure + assert 'embedding' in result + assert 'metadata' in result + assert 'model' in result + assert len(result['embedding']) == 512 + assert result['metadata']['labels'][0]['description'] == 'cat' + assert result['model'] == 'clip' # or whatever model is used + + def test_batch_generate_embeddings(self, embedding_service): + """Test generating embeddings for multiple images in batch""" + # Create multiple image data samples + image_batch = [] + for i in range(3): + image_data = BytesIO(b'fake_image_data_' + str(i).encode()) + image_model = ImageModel( + filename=f"image{i}.jpg", + original_filename=f"image{i}.jpg", + file_size=1024, + content_type="image/jpeg", + storage_path=f"images/image{i}.jpg", + team_id=ObjectId(), + uploader_id=ObjectId() + ) + image_batch.append((image_data, image_model)) + + # Mock embedding generation + mock_embedding = np.random.rand(512).tolist() + with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate: + mock_generate.return_value = { + 'embedding': mock_embedding, + 'metadata': {'labels': []}, + 'model': 'clip' + } + + # Generate batch embeddings + results = embedding_service.batch_generate_embeddings(image_batch) + + # Verify batch results + assert len(results) == 3 + assert all('embedding' in result for result in results) + assert all('metadata' in result for result in results) + + def test_embedding_model_consistency(self, embedding_service, sample_image_data): + """Test that the same image produces consistent embeddings""" + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = mock_embedding + + # Generate embedding twice + embedding1 = embedding_service.generate_embedding(sample_image_data) + sample_image_data.seek(0) # Reset stream position + embedding2 = embedding_service.generate_embedding(sample_image_data) + + # Embeddings should be identical for the same image + assert embedding1 == embedding2 + + def test_embedding_dimension_validation(self, embedding_service): + """Test that embeddings have the correct dimensions""" + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = mock_embedding + + # Validate embedding dimensions + assert embedding_service.validate_embedding_dimensions(mock_embedding) is True + + # Test wrong dimensions + wrong_embedding = np.random.rand(256).tolist() + assert embedding_service.validate_embedding_dimensions(wrong_embedding) is False + + def test_handle_unsupported_image_format(self, embedding_service): + """Test handling of unsupported image formats""" + # Create invalid image data + invalid_data = BytesIO(b'not_an_image') + + # Should raise appropriate exception + with pytest.raises(ValueError): + embedding_service.generate_embedding(invalid_data) + + def test_handle_corrupted_image(self, embedding_service): + """Test handling of corrupted image data""" + # Create corrupted image data + corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted') + + # Should handle gracefully + with pytest.raises(Exception): + 
embedding_service.generate_embedding(corrupted_data) + + def test_vision_api_error_handling(self, embedding_service, sample_image_data): + """Test handling of Vision API errors""" + # Mock Vision API error + embedding_service.client.annotate_image.side_effect = Exception("Vision API error") + + # Should handle the error gracefully + with pytest.raises(Exception): + embedding_service.extract_image_features(sample_image_data) + + def test_embedding_caching(self, embedding_service, sample_image_data): + """Test caching of embeddings for the same image""" + # This would test caching functionality if implemented + pass + + def test_embedding_quality_metrics(self, embedding_service, sample_image_data): + """Test quality metrics for generated embeddings""" + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = mock_embedding + + # Generate embedding + embedding = embedding_service.generate_embedding(sample_image_data) + + # Check embedding quality metrics + quality_score = embedding_service.calculate_embedding_quality(embedding) + assert 0 <= quality_score <= 1 + + def test_different_image_types(self, embedding_service): + """Test embedding generation for different image types""" + image_types = [ + ('image/jpeg', b'fake_jpeg_data'), + ('image/png', b'fake_png_data'), + ('image/webp', b'fake_webp_data') + ] + + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = mock_embedding + + for content_type, data in image_types: + image_data = BytesIO(data) + + # Should handle different image types + embedding = embedding_service.generate_embedding(image_data) + assert len(embedding) == 512 + + def test_large_image_handling(self, embedding_service): + """Test handling of large images""" + # Create large image data (simulated) + large_image_data = BytesIO(b'x' * (10 * 1024 * 1024)) # 10MB + + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = mock_embedding + + # Should handle large images + embedding = embedding_service.generate_embedding(large_image_data) + assert len(embedding) == 512 + + def test_embedding_normalization(self, embedding_service, sample_image_data): + """Test that embeddings are properly normalized""" + # Generate raw embedding + raw_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + mock_extract.return_value = raw_embedding + + # Generate normalized embedding + embedding = embedding_service.generate_embedding(sample_image_data, normalize=True) + + # Check if embedding is normalized (L2 norm should be 1) + norm = np.linalg.norm(embedding) + assert abs(norm - 1.0) < 0.001 # Allow small floating point errors + + +class TestEmbeddingServiceIntegration: + """Integration tests for embedding service with other components""" + + def test_embedding_to_vector_store_integration(self, embedding_service, sample_image_data, sample_image_model): + """Test integration with vector store service""" + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate, \ + patch('src.services.vector_store.VectorStoreService') as mock_vector_store: + + mock_generate.return_value = { + 'embedding': mock_embedding, + 'metadata': {'labels': [{'description': 'cat', 
'score': 0.95}]}, + 'model': 'clip' + } + + mock_store = mock_vector_store.return_value + mock_store.store_embedding.return_value = 'embedding_id_123' + + # Process image and store embedding + result = embedding_service.process_and_store_image( + sample_image_data, sample_image_model + ) + + # Verify integration + assert result['embedding_id'] == 'embedding_id_123' + mock_store.store_embedding.assert_called_once() + + def test_pubsub_trigger_integration(self, embedding_service): + """Test integration with Pub/Sub message processing""" + # Mock Pub/Sub message + mock_message = { + 'image_id': str(ObjectId()), + 'storage_path': 'images/test.jpg', + 'team_id': str(ObjectId()) + } + + with patch.object(embedding_service, 'process_image_from_storage') as mock_process: + mock_process.return_value = {'embedding_id': 'emb123'} + + # Process Pub/Sub message + result = embedding_service.handle_pubsub_message(mock_message) + + # Verify message processing + assert result['embedding_id'] == 'emb123' + mock_process.assert_called_once_with( + mock_message['storage_path'], + mock_message['image_id'], + mock_message['team_id'] + ) + + def test_cloud_function_deployment(self, embedding_service): + """Test Cloud Function deployment compatibility""" + # Test that the service can be initialized in a Cloud Function environment + # This would test environment variable loading, authentication, etc. + pass + + def test_error_recovery_and_retry(self, embedding_service, sample_image_data): + """Test error recovery and retry mechanisms""" + # Mock transient error followed by success + mock_embedding = np.random.rand(512).tolist() + + with patch.object(embedding_service, '_extract_features') as mock_extract: + # First call fails, second succeeds + mock_extract.side_effect = [Exception("Transient error"), mock_embedding] + + # Should retry and succeed + embedding = embedding_service.generate_embedding_with_retry( + sample_image_data, max_retries=2 + ) + + assert len(embedding) == 512 + assert mock_extract.call_count == 2 \ No newline at end of file diff --git a/tests/services/test_image_processor.py b/tests/services/test_image_processor.py new file mode 100644 index 0000000..de822d8 --- /dev/null +++ b/tests/services/test_image_processor.py @@ -0,0 +1,489 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock +from bson import ObjectId +from io import BytesIO +from PIL import Image + +from src.services.image_processor import ImageProcessor +from src.models.image import ImageModel + + +class TestImageProcessor: + """Test image processing functionality""" + + @pytest.fixture + def image_processor(self): + """Create image processor instance""" + return ImageProcessor() + + @pytest.fixture + def sample_image_data(self): + """Create sample image data""" + # Create a simple test image using PIL + img = Image.new('RGB', (800, 600), color='red') + img_bytes = BytesIO() + img.save(img_bytes, format='JPEG') + img_bytes.seek(0) + return img_bytes + + @pytest.fixture + def sample_png_image(self): + """Create sample PNG image data""" + img = Image.new('RGBA', (400, 300), color=(255, 0, 0, 128)) + img_bytes = BytesIO() + img.save(img_bytes, format='PNG') + img_bytes.seek(0) + return img_bytes + + @pytest.fixture + def sample_image_model(self): + """Create a sample image model""" + return ImageModel( + filename="test-image.jpg", + original_filename="test_image.jpg", + file_size=1024, + content_type="image/jpeg", + storage_path="images/test-image.jpg", + team_id=ObjectId(), + uploader_id=ObjectId() + ) + + 
def test_extract_image_metadata(self, image_processor, sample_image_data): + """Test extracting basic image metadata""" + # Extract metadata + metadata = image_processor.extract_metadata(sample_image_data) + + # Verify metadata extraction + assert 'width' in metadata + assert 'height' in metadata + assert 'format' in metadata + assert 'mode' in metadata + assert metadata['width'] == 800 + assert metadata['height'] == 600 + assert metadata['format'] == 'JPEG' + + def test_extract_exif_data(self, image_processor): + """Test extracting EXIF data from images""" + # Create image with EXIF data (simulated) + img = Image.new('RGB', (100, 100), color='blue') + img_bytes = BytesIO() + img.save(img_bytes, format='JPEG') + img_bytes.seek(0) + + # Extract EXIF data + exif_data = image_processor.extract_exif_data(img_bytes) + + # Verify EXIF extraction (may be empty for generated images) + assert isinstance(exif_data, dict) + + def test_resize_image(self, image_processor, sample_image_data): + """Test resizing images while maintaining aspect ratio""" + # Resize image + resized_data = image_processor.resize_image( + sample_image_data, + max_width=400, + max_height=300 + ) + + # Verify resized image + assert resized_data is not None + + # Check new dimensions + resized_img = Image.open(resized_data) + assert resized_img.width <= 400 + assert resized_img.height <= 300 + + # Aspect ratio should be maintained + original_ratio = 800 / 600 + new_ratio = resized_img.width / resized_img.height + assert abs(original_ratio - new_ratio) < 0.01 + + def test_generate_thumbnail(self, image_processor, sample_image_data): + """Test generating image thumbnails""" + # Generate thumbnail + thumbnail_data = image_processor.generate_thumbnail( + sample_image_data, + size=(150, 150) + ) + + # Verify thumbnail + assert thumbnail_data is not None + + # Check thumbnail dimensions + thumbnail_img = Image.open(thumbnail_data) + assert thumbnail_img.width <= 150 + assert thumbnail_img.height <= 150 + + def test_optimize_image_quality(self, image_processor, sample_image_data): + """Test optimizing image quality and file size""" + # Get original size + original_size = len(sample_image_data.getvalue()) + + # Optimize image + optimized_data = image_processor.optimize_image( + sample_image_data, + quality=85, + optimize=True + ) + + # Verify optimization + assert optimized_data is not None + optimized_size = len(optimized_data.getvalue()) + + # Optimized image should typically be smaller or similar size + assert optimized_size <= original_size * 1.1 # Allow 10% tolerance + + def test_convert_image_format(self, image_processor, sample_png_image): + """Test converting between image formats""" + # Convert PNG to JPEG + jpeg_data = image_processor.convert_format( + sample_png_image, + target_format='JPEG' + ) + + # Verify conversion + assert jpeg_data is not None + + # Check converted image + converted_img = Image.open(jpeg_data) + assert converted_img.format == 'JPEG' + + def test_detect_image_colors(self, image_processor, sample_image_data): + """Test detecting dominant colors in images""" + # Detect colors + colors = image_processor.detect_dominant_colors( + sample_image_data, + num_colors=5 + ) + + # Verify color detection + assert isinstance(colors, list) + assert len(colors) <= 5 + + # Each color should have RGB values and percentage + for color in colors: + assert 'rgb' in color + assert 'percentage' in color + assert len(color['rgb']) == 3 + assert 0 <= color['percentage'] <= 100 + + def test_validate_image_format(self, 
image_processor, sample_image_data): + """Test validating supported image formats""" + # Valid image should pass validation + is_valid = image_processor.validate_image_format(sample_image_data) + assert is_valid is True + + # Invalid data should fail validation + invalid_data = BytesIO(b'not_an_image') + is_valid = image_processor.validate_image_format(invalid_data) + assert is_valid is False + + def test_calculate_image_hash(self, image_processor, sample_image_data): + """Test calculating perceptual hash for duplicate detection""" + # Calculate hash + image_hash = image_processor.calculate_perceptual_hash(sample_image_data) + + # Verify hash + assert image_hash is not None + assert isinstance(image_hash, str) + assert len(image_hash) > 0 + + # Same image should produce same hash + sample_image_data.seek(0) + hash2 = image_processor.calculate_perceptual_hash(sample_image_data) + assert image_hash == hash2 + + def test_detect_image_orientation(self, image_processor, sample_image_data): + """Test detecting and correcting image orientation""" + # Detect orientation + orientation = image_processor.detect_orientation(sample_image_data) + + # Verify orientation detection + assert orientation in [0, 90, 180, 270] + + # Auto-correct orientation if needed + corrected_data = image_processor.auto_correct_orientation(sample_image_data) + assert corrected_data is not None + + def test_extract_text_from_image(self, image_processor): + """Test OCR text extraction from images""" + # Create image with text (simulated) + img = Image.new('RGB', (200, 100), color='white') + img_bytes = BytesIO() + img.save(img_bytes, format='JPEG') + img_bytes.seek(0) + + with patch('src.services.image_processor.pytesseract') as mock_ocr: + mock_ocr.image_to_string.return_value = "Sample text" + + # Extract text + text = image_processor.extract_text(img_bytes) + + # Verify text extraction + assert text == "Sample text" + mock_ocr.image_to_string.assert_called_once() + + def test_batch_process_images(self, image_processor): + """Test batch processing multiple images""" + # Create batch of images + image_batch = [] + for i in range(3): + img = Image.new('RGB', (100, 100), color=(i*80, 0, 0)) + img_bytes = BytesIO() + img.save(img_bytes, format='JPEG') + img_bytes.seek(0) + image_batch.append(img_bytes) + + # Process batch + results = image_processor.batch_process( + image_batch, + operations=['resize', 'thumbnail', 'metadata'] + ) + + # Verify batch processing + assert len(results) == 3 + for result in results: + assert 'metadata' in result + assert 'resized' in result + assert 'thumbnail' in result + + def test_image_quality_assessment(self, image_processor, sample_image_data): + """Test assessing image quality metrics""" + # Assess quality + quality_metrics = image_processor.assess_quality(sample_image_data) + + # Verify quality metrics + assert 'sharpness' in quality_metrics + assert 'brightness' in quality_metrics + assert 'contrast' in quality_metrics + assert 'overall_score' in quality_metrics + + # Scores should be in valid ranges + assert 0 <= quality_metrics['overall_score'] <= 100 + + def test_watermark_addition(self, image_processor, sample_image_data): + """Test adding watermarks to images""" + # Add text watermark + watermarked_data = image_processor.add_watermark( + sample_image_data, + watermark_text="SEREACT", + position="bottom-right", + opacity=0.5 + ) + + # Verify watermark addition + assert watermarked_data is not None + + # Check that image is still valid + watermarked_img = Image.open(watermarked_data) + 
assert watermarked_img.format == 'JPEG' + + def test_image_compression_levels(self, image_processor, sample_image_data): + """Test different compression levels""" + original_size = len(sample_image_data.getvalue()) + + # Test different quality levels + for quality in [95, 85, 75, 60]: + compressed_data = image_processor.compress_image( + sample_image_data, + quality=quality + ) + + compressed_size = len(compressed_data.getvalue()) + + # Lower quality should generally result in smaller files + if quality < 95: + assert compressed_size <= original_size + + # Reset stream position + sample_image_data.seek(0) + + def test_handle_corrupted_image(self, image_processor): + """Test handling of corrupted image data""" + # Create corrupted image data + corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted') + + # Should handle gracefully + with pytest.raises(Exception): + image_processor.extract_metadata(corrupted_data) + + def test_large_image_processing(self, image_processor): + """Test processing very large images""" + # Create large image (simulated) + large_img = Image.new('RGB', (4000, 3000), color='green') + img_bytes = BytesIO() + large_img.save(img_bytes, format='JPEG', quality=95) + img_bytes.seek(0) + + # Process large image + metadata = image_processor.extract_metadata(img_bytes) + + # Verify processing + assert metadata['width'] == 4000 + assert metadata['height'] == 3000 + + # Test resizing large image + img_bytes.seek(0) + resized_data = image_processor.resize_image( + img_bytes, + max_width=1920, + max_height=1080 + ) + + resized_img = Image.open(resized_data) + assert resized_img.width <= 1920 + assert resized_img.height <= 1080 + + def test_progressive_jpeg_support(self, image_processor, sample_image_data): + """Test support for progressive JPEG format""" + # Convert to progressive JPEG + progressive_data = image_processor.convert_to_progressive_jpeg( + sample_image_data + ) + + # Verify progressive format + assert progressive_data is not None + + # Check that it's still a valid JPEG + progressive_img = Image.open(progressive_data) + assert progressive_img.format == 'JPEG' + + +class TestImageProcessorIntegration: + """Integration tests for image processor with other services""" + + def test_integration_with_storage_service(self, image_processor, sample_image_data): + """Test integration with storage service""" + with patch('src.services.storage.StorageService') as mock_storage: + mock_storage_instance = mock_storage.return_value + mock_storage_instance.upload_file.return_value = ( + 'images/processed.jpg', 'image/jpeg', 1024, {} + ) + + # Process and upload image + result = image_processor.process_and_upload( + sample_image_data, + operations=['resize', 'optimize'], + team_id=str(ObjectId()) + ) + + # Verify integration + assert 'storage_path' in result + mock_storage_instance.upload_file.assert_called_once() + + def test_integration_with_embedding_service(self, image_processor, sample_image_data): + """Test integration with embedding service""" + with patch('src.services.embedding_service.EmbeddingService') as mock_embedding: + mock_embedding_instance = mock_embedding.return_value + mock_embedding_instance.generate_embedding.return_value = [0.1] * 512 + + # Process image and generate embedding + result = image_processor.process_for_embedding(sample_image_data) + + # Verify integration + assert 'processed_image' in result + assert 'embedding' in result + mock_embedding_instance.generate_embedding.assert_called_once() + + def test_pubsub_message_processing(self, 
image_processor): + """Test processing images from Pub/Sub messages""" + # Mock Pub/Sub message + message_data = { + 'image_id': str(ObjectId()), + 'storage_path': 'images/raw/test.jpg', + 'operations': ['resize', 'thumbnail', 'optimize'] + } + + with patch.object(image_processor, 'process_from_storage') as mock_process: + mock_process.return_value = { + 'processed_path': 'images/processed/test.jpg', + 'thumbnail_path': 'images/thumbnails/test.jpg' + } + + # Process message + result = image_processor.handle_processing_message(message_data) + + # Verify message processing + assert 'processed_path' in result + mock_process.assert_called_once() + + def test_error_handling_and_retry(self, image_processor, sample_image_data): + """Test error handling and retry mechanisms""" + # Mock transient error followed by success + with patch.object(image_processor, 'extract_metadata') as mock_extract: + # First call fails, second succeeds + mock_extract.side_effect = [ + Exception("Transient error"), + {'width': 800, 'height': 600, 'format': 'JPEG'} + ] + + # Should retry and succeed + metadata = image_processor.extract_metadata_with_retry( + sample_image_data, + max_retries=2 + ) + + assert metadata['width'] == 800 + assert mock_extract.call_count == 2 + + +class TestImageProcessorPerformance: + """Performance tests for image processing""" + + def test_processing_speed_benchmarks(self, image_processor): + """Test processing speed for different image sizes""" + import time + + sizes = [(100, 100), (500, 500), (1000, 1000)] + + for width, height in sizes: + # Create test image + img = Image.new('RGB', (width, height), color='blue') + img_bytes = BytesIO() + img.save(img_bytes, format='JPEG') + img_bytes.seek(0) + + # Measure processing time + start_time = time.time() + metadata = image_processor.extract_metadata(img_bytes) + processing_time = time.time() - start_time + + # Verify reasonable processing time (adjust thresholds as needed) + assert processing_time < 5.0 # Should process within 5 seconds + assert metadata['width'] == width + assert metadata['height'] == height + + def test_memory_usage_optimization(self, image_processor): + """Test memory usage during image processing""" + # This would test memory usage patterns + # Implementation depends on memory profiling tools + pass + + def test_concurrent_processing(self, image_processor): + """Test concurrent image processing""" + import concurrent.futures + + # Create multiple test images + images = [] + for i in range(5): + img = Image.new('RGB', (200, 200), color=(i*50, 0, 0)) + img_bytes = BytesIO() + img.save(img_bytes, format='JPEG') + img_bytes.seek(0) + images.append(img_bytes) + + # Process concurrently + with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor: + futures = [ + executor.submit(image_processor.extract_metadata, img) + for img in images + ] + + results = [future.result() for future in futures] + + # Verify all processed successfully + assert len(results) == 5 + for result in results: + assert 'width' in result + assert 'height' in result \ No newline at end of file diff --git a/tests/services/test_vector_store.py b/tests/services/test_vector_store.py new file mode 100644 index 0000000..881cd03 --- /dev/null +++ b/tests/services/test_vector_store.py @@ -0,0 +1,391 @@ +import pytest +import numpy as np +from unittest.mock import patch, MagicMock, AsyncMock +from bson import ObjectId + +from src.services.vector_store import VectorStoreService +from src.models.image import ImageModel + + +class TestVectorStoreService: + 
"""Test vector store operations for semantic search""" + + @pytest.fixture + def mock_pinecone_index(self): + """Mock Pinecone index for testing""" + mock_index = MagicMock() + mock_index.upsert = MagicMock() + mock_index.query = MagicMock() + mock_index.delete = MagicMock() + mock_index.describe_index_stats = MagicMock() + return mock_index + + @pytest.fixture + def vector_store_service(self, mock_pinecone_index): + """Create vector store service with mocked dependencies""" + with patch('src.services.vector_store.pinecone') as mock_pinecone: + mock_pinecone.Index.return_value = mock_pinecone_index + service = VectorStoreService() + service.index = mock_pinecone_index + return service + + @pytest.fixture + def sample_embedding(self): + """Generate a sample embedding vector""" + return np.random.rand(512).tolist() # 512-dimensional vector + + @pytest.fixture + def sample_image(self): + """Create a sample image model""" + return ImageModel( + filename="test-image.jpg", + original_filename="test_image.jpg", + file_size=1024, + content_type="image/jpeg", + storage_path="images/test-image.jpg", + team_id=ObjectId(), + uploader_id=ObjectId(), + description="A test image", + tags=["test", "image"] + ) + + def test_store_embedding(self, vector_store_service, sample_embedding, sample_image): + """Test storing an embedding in the vector database""" + # Store the embedding + embedding_id = vector_store_service.store_embedding( + image_id=str(sample_image.id), + embedding=sample_embedding, + metadata={ + "filename": sample_image.filename, + "team_id": str(sample_image.team_id), + "tags": sample_image.tags, + "description": sample_image.description + } + ) + + # Verify the embedding was stored + assert embedding_id is not None + vector_store_service.index.upsert.assert_called_once() + + # Check the upsert call arguments + call_args = vector_store_service.index.upsert.call_args + vectors = call_args[1]['vectors'] + assert len(vectors) == 1 + assert vectors[0]['id'] == embedding_id + assert len(vectors[0]['values']) == len(sample_embedding) + assert 'metadata' in vectors[0] + + def test_search_similar_images(self, vector_store_service, sample_embedding): + """Test searching for similar images using vector similarity""" + # Mock search results + mock_results = { + 'matches': [ + { + 'id': 'embedding1', + 'score': 0.95, + 'metadata': { + 'image_id': str(ObjectId()), + 'filename': 'similar1.jpg', + 'team_id': str(ObjectId()), + 'tags': ['cat', 'animal'] + } + }, + { + 'id': 'embedding2', + 'score': 0.87, + 'metadata': { + 'image_id': str(ObjectId()), + 'filename': 'similar2.jpg', + 'team_id': str(ObjectId()), + 'tags': ['dog', 'animal'] + } + } + ] + } + vector_store_service.index.query.return_value = mock_results + + # Perform search + results = vector_store_service.search_similar( + query_embedding=sample_embedding, + team_id=str(ObjectId()), + top_k=10, + score_threshold=0.8 + ) + + # Verify search was performed + vector_store_service.index.query.assert_called_once() + + # Check results + assert len(results) == 2 + assert results[0]['score'] == 0.95 + assert results[1]['score'] == 0.87 + assert all('image_id' in result for result in results) + + def test_search_with_filters(self, vector_store_service, sample_embedding): + """Test searching with metadata filters""" + team_id = str(ObjectId()) + + # Perform search with team filter + vector_store_service.search_similar( + query_embedding=sample_embedding, + team_id=team_id, + top_k=5, + filters={"tags": {"$in": ["cat", "dog"]}} + ) + + # Verify filter was 
applied + call_args = vector_store_service.index.query.call_args + assert 'filter' in call_args[1] + assert call_args[1]['filter']['team_id'] == team_id + + def test_delete_embedding(self, vector_store_service): + """Test deleting an embedding from the vector database""" + embedding_id = "test-embedding-123" + + # Delete the embedding + success = vector_store_service.delete_embedding(embedding_id) + + # Verify deletion was attempted + vector_store_service.index.delete.assert_called_once_with(ids=[embedding_id]) + assert success is True + + def test_batch_store_embeddings(self, vector_store_service, sample_embedding): + """Test storing multiple embeddings in batch""" + # Create batch data + batch_data = [] + for i in range(5): + batch_data.append({ + 'image_id': str(ObjectId()), + 'embedding': sample_embedding, + 'metadata': { + 'filename': f'image{i}.jpg', + 'team_id': str(ObjectId()), + 'tags': [f'tag{i}'] + } + }) + + # Store batch + embedding_ids = vector_store_service.batch_store_embeddings(batch_data) + + # Verify batch storage + assert len(embedding_ids) == 5 + vector_store_service.index.upsert.assert_called_once() + + # Check batch upsert call + call_args = vector_store_service.index.upsert.call_args + vectors = call_args[1]['vectors'] + assert len(vectors) == 5 + + def test_get_index_stats(self, vector_store_service): + """Test getting vector database statistics""" + # Mock stats response + mock_stats = { + 'total_vector_count': 1000, + 'dimension': 512, + 'index_fullness': 0.1 + } + vector_store_service.index.describe_index_stats.return_value = mock_stats + + # Get stats + stats = vector_store_service.get_index_stats() + + # Verify stats retrieval + vector_store_service.index.describe_index_stats.assert_called_once() + assert stats['total_vector_count'] == 1000 + assert stats['dimension'] == 512 + + def test_search_with_score_threshold(self, vector_store_service, sample_embedding): + """Test filtering search results by score threshold""" + # Mock results with varying scores + mock_results = { + 'matches': [ + {'id': 'emb1', 'score': 0.95, 'metadata': {'image_id': '1'}}, + {'id': 'emb2', 'score': 0.75, 'metadata': {'image_id': '2'}}, + {'id': 'emb3', 'score': 0.65, 'metadata': {'image_id': '3'}}, + {'id': 'emb4', 'score': 0.45, 'metadata': {'image_id': '4'}} + ] + } + vector_store_service.index.query.return_value = mock_results + + # Search with score threshold + results = vector_store_service.search_similar( + query_embedding=sample_embedding, + team_id=str(ObjectId()), + top_k=10, + score_threshold=0.7 + ) + + # Only results above threshold should be returned + assert len(results) == 2 + assert all(result['score'] >= 0.7 for result in results) + + def test_update_embedding_metadata(self, vector_store_service): + """Test updating metadata for an existing embedding""" + embedding_id = "test-embedding-123" + new_metadata = { + 'tags': ['updated', 'tag'], + 'description': 'Updated description' + } + + # Update metadata + success = vector_store_service.update_embedding_metadata( + embedding_id, new_metadata + ) + + # Verify update was attempted + # This would depend on the actual implementation + assert success is True + + def test_search_by_image_id(self, vector_store_service): + """Test searching for a specific image's embedding""" + image_id = str(ObjectId()) + + # Mock search by metadata + mock_results = { + 'matches': [ + { + 'id': 'embedding1', + 'score': 1.0, + 'metadata': { + 'image_id': image_id, + 'filename': 'target.jpg' + } + } + ] + } + 
vector_store_service.index.query.return_value = mock_results + + # Search by image ID + result = vector_store_service.get_embedding_by_image_id(image_id) + + # Verify result + assert result is not None + assert result['metadata']['image_id'] == image_id + + def test_bulk_delete_embeddings(self, vector_store_service): + """Test deleting multiple embeddings""" + embedding_ids = ['emb1', 'emb2', 'emb3'] + + # Delete multiple embeddings + success = vector_store_service.bulk_delete_embeddings(embedding_ids) + + # Verify bulk deletion + vector_store_service.index.delete.assert_called_once_with(ids=embedding_ids) + assert success is True + + def test_search_pagination(self, vector_store_service, sample_embedding): + """Test paginated search results""" + # This would test pagination if implemented + # Implementation depends on how pagination is handled in the vector store + pass + + def test_vector_dimension_validation(self, vector_store_service): + """Test validation of embedding dimensions""" + # Test with wrong dimension + wrong_dimension_embedding = np.random.rand(256).tolist() # Wrong size + + with pytest.raises(ValueError): + vector_store_service.store_embedding( + image_id=str(ObjectId()), + embedding=wrong_dimension_embedding, + metadata={} + ) + + def test_connection_error_handling(self, vector_store_service): + """Test handling of connection errors""" + # Mock connection error + vector_store_service.index.query.side_effect = Exception("Connection failed") + + # Search should handle the error gracefully + with pytest.raises(Exception): + vector_store_service.search_similar( + query_embedding=[0.1] * 512, + team_id=str(ObjectId()), + top_k=10 + ) + + def test_empty_search_results(self, vector_store_service, sample_embedding): + """Test handling of empty search results""" + # Mock empty results + vector_store_service.index.query.return_value = {'matches': []} + + # Search should return empty list + results = vector_store_service.search_similar( + query_embedding=sample_embedding, + team_id=str(ObjectId()), + top_k=10 + ) + + assert results == [] + + +class TestVectorStoreIntegration: + """Integration tests for vector store with other services""" + + def test_embedding_lifecycle(self, vector_store_service, sample_embedding, sample_image): + """Test complete embedding lifecycle: store, search, update, delete""" + # Store embedding + embedding_id = vector_store_service.store_embedding( + image_id=str(sample_image.id), + embedding=sample_embedding, + metadata={'filename': sample_image.filename} + ) + + # Search for similar embeddings + mock_results = { + 'matches': [ + { + 'id': embedding_id, + 'score': 1.0, + 'metadata': {'image_id': str(sample_image.id)} + } + ] + } + vector_store_service.index.query.return_value = mock_results + + results = vector_store_service.search_similar( + query_embedding=sample_embedding, + team_id=str(sample_image.team_id), + top_k=1 + ) + + assert len(results) == 1 + assert results[0]['id'] == embedding_id + + # Delete embedding + success = vector_store_service.delete_embedding(embedding_id) + assert success is True + + def test_team_isolation(self, vector_store_service, sample_embedding): + """Test that team data is properly isolated""" + team1_id = str(ObjectId()) + team2_id = str(ObjectId()) + + # Mock search results that should be filtered by team + mock_results = { + 'matches': [ + { + 'id': 'emb1', + 'score': 0.9, + 'metadata': {'image_id': '1', 'team_id': team1_id} + }, + { + 'id': 'emb2', + 'score': 0.8, + 'metadata': {'image_id': '2', 'team_id': team2_id} + 
} + ] + } + vector_store_service.index.query.return_value = mock_results + + # Search for team1 should only return team1 results + results = vector_store_service.search_similar( + query_embedding=sample_embedding, + team_id=team1_id, + top_k=10 + ) + + # Verify team filter was applied in the query + call_args = vector_store_service.index.query.call_args + assert 'filter' in call_args[1] + assert call_args[1]['filter']['team_id'] == team1_id \ No newline at end of file
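
Note: the VectorStoreService implementation these tests exercise is not part of this diff. As a rough sketch of the behaviour the tests assume — the VectorStoreSketch name, constructor injection of the index client, and dict-style query responses (as returned by the mocks above) are assumptions made here, not the repository's actual API — something like the following would satisfy the store / search / delete expectations:

# vector_store_sketch.py -- illustrative only; names and signatures are assumptions,
# not the actual src/services/vector_store.py shipped with this change.
import uuid
from typing import Any, Dict, List, Optional


class VectorStoreSketch:
    """Minimal wrapper over an index client exposing upsert/query/delete (Pinecone-style)."""

    def __init__(self, index: Any, dimension: int = 512):
        self.index = index          # injected so the sketch stays runnable without Pinecone
        self.dimension = dimension  # expected embedding size, enforced on write

    def store_embedding(self, image_id: str, embedding: List[float], metadata: Dict) -> str:
        # Reject vectors that do not match the index dimension (ValueError, as the tests expect).
        if len(embedding) != self.dimension:
            raise ValueError(f"expected a {self.dimension}-dim vector, got {len(embedding)}")
        embedding_id = str(uuid.uuid4())
        self.index.upsert(vectors=[{
            "id": embedding_id,
            "values": embedding,
            "metadata": {"image_id": image_id, **metadata},
        }])
        return embedding_id

    def search_similar(self, query_embedding: List[float], team_id: str, top_k: int = 10,
                       score_threshold: Optional[float] = None,
                       filters: Optional[Dict] = None) -> List[Dict]:
        # Scope every query to the caller's team so tenants never see each other's vectors.
        query_filter = {"team_id": team_id, **(filters or {})}
        response = self.index.query(
            vector=query_embedding,
            top_k=top_k,
            filter=query_filter,
            include_metadata=True,
        )
        results = []
        for match in response["matches"]:
            if score_threshold is not None and match["score"] < score_threshold:
                continue  # drop matches below the requested similarity threshold
            results.append({"id": match["id"], "score": match["score"], **match.get("metadata", {})})
        return results

    def delete_embedding(self, embedding_id: str) -> bool:
        self.index.delete(ids=[embedding_id])
        return True

Putting team_id into the query filter (rather than post-filtering in Python) is the behaviour test_search_with_filters and test_team_isolation check for: tenant isolation is enforced inside the vector database itself, and score_threshold is then applied to whatever the index returns.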
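
Similarly, test_error_handling_and_retry only passes if extract_metadata_with_retry re-invokes extract_metadata until it succeeds or the retry budget is exhausted. A minimal sketch of that behaviour, written here as a free function for readability — the real code presumably exposes it as a method on the processor, and the backoff_seconds parameter is purely hypothetical:

# Illustrative retry helper; not the actual ImageProcessor implementation.
import time
from typing import Any, Dict


def extract_metadata_with_retry(processor: Any, image_data: Any,
                                max_retries: int = 2,
                                backoff_seconds: float = 0.0) -> Dict:
    """Call processor.extract_metadata, retrying transient failures up to max_retries attempts."""
    last_error: Exception = RuntimeError("no attempts made")
    for _ in range(max_retries):
        try:
            return processor.extract_metadata(image_data)
        except Exception as exc:  # the test only distinguishes a failed attempt from success
            last_error = exc
            if backoff_seconds:
                time.sleep(backoff_seconds)  # fixed pause between attempts; exponential backoff would also work
    raise last_error

Used as in the test, extract_metadata_with_retry(image_processor, sample_image_data, max_retries=2) makes two attempts, so a single transient failure followed by a success yields the metadata and a call count of 2.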