cp
This commit is contained in:
parent
1bfd3b7d69
commit
1010ed8d4e
620
tests/api/test_collections.py
Normal file
620
tests/api/test_collections.py
Normal file
@ -0,0 +1,620 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from bson import ObjectId
|
||||
from datetime import datetime
|
||||
|
||||
from src.models.team import TeamModel
|
||||
from src.models.user import UserModel
|
||||
from src.models.image import ImageModel
|
||||
|
||||
|
||||
class CollectionModel:
    """Lightweight stand-in for the real collection model, used only by tests."""

    def __init__(self, **kwargs):
        # Identity and timestamps get fresh defaults when not supplied.
        self.id = kwargs.get('id', ObjectId())
        self.created_at = kwargs.get('created_at', datetime.utcnow())
        self.updated_at = kwargs.get('updated_at', datetime.utcnow())
        # Container-ish fields default to empty values.
        self.metadata = kwargs.get('metadata', {})
        self.image_count = kwargs.get('image_count', 0)
        # Plain descriptive fields simply default to None.
        for field in ('name', 'description', 'team_id', 'created_by'):
            setattr(self, field, kwargs.get(field))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_collection(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
    """An admin can create a collection; the response echoes its fields."""
    raw_key, _ = admin_api_key

    payload = {
        "name": "Test Collection",
        "description": "A collection for testing images",
        "metadata": {"category": "test", "project": "sereact"},
    }

    # Create the collection while authenticated with the admin key.
    response = client.post(
        "/api/v1/collections",
        headers={"X-API-Key": raw_key},
        json=payload,
    )

    # A freshly created collection is empty and scoped to the caller's team.
    assert response.status_code == 201
    body = response.json()
    assert "id" in body
    assert body["name"] == "Test Collection"
    assert body["description"] == "A collection for testing images"
    assert body["team_id"] == str(test_team.id)
    assert "created_at" in body
    assert body["image_count"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_collection_non_admin(client: TestClient, user_api_key: tuple):
    """Collection creation is not an admin-only operation."""
    raw_key, _ = user_api_key

    # Create a collection using a regular (non-admin) user's key.
    response = client.post(
        "/api/v1/collections",
        headers={"X-API-Key": raw_key},
        json={
            "name": "User Collection",
            "description": "A collection created by a regular user",
        },
    )

    # Regular users are allowed to create collections too.
    assert response.status_code == 201
    assert response.json()["name"] == "User Collection"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_collections(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
    """Listing collections returns a list plus a total count."""
    raw_key, _ = admin_api_key

    response = client.get(
        "/api/v1/collections",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    # The envelope carries the page of collections and the overall count.
    assert "collections" in body
    assert "total" in body
    assert isinstance(body["collections"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_collection_by_id(client: TestClient, admin_api_key: tuple):
    """A created collection can be fetched back by its id."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Arrange: create a collection to fetch.
    create_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Get Test Collection",
            "description": "Collection for get testing",
        },
    )
    assert create_response.status_code == 201
    collection_id = create_response.json()["id"]

    # Act: fetch it by id.
    response = client.get(f"/api/v1/collections/{collection_id}", headers=headers)

    # Assert: the round-trip preserves the fields we set.
    assert response.status_code == 200
    body = response.json()
    assert body["id"] == collection_id
    assert body["name"] == "Get Test Collection"
    assert body["description"] == "Collection for get testing"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_collection(client: TestClient, admin_api_key: tuple):
    """PUT on a collection replaces name, description and metadata."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Arrange: create a collection to update.
    create_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Update Test Collection",
            "description": "Collection for update testing",
        },
    )
    assert create_response.status_code == 201
    collection_id = create_response.json()["id"]

    # Act: update every mutable field at once.
    updated_payload = {
        "name": "Updated Collection Name",
        "description": "This collection has been updated",
        "metadata": {"updated": True, "version": 2},
    }
    response = client.put(
        f"/api/v1/collections/{collection_id}",
        headers=headers,
        json=updated_payload,
    )

    # Assert: the response reflects the new state and carries a timestamp.
    assert response.status_code == 200
    body = response.json()
    assert body["id"] == collection_id
    assert body["name"] == "Updated Collection Name"
    assert body["description"] == "This collection has been updated"
    assert body["metadata"]["updated"] is True
    assert "updated_at" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_collection(client: TestClient, admin_api_key: tuple):
    """Deleting a collection returns 204 and makes it unfetchable."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Arrange: create a collection we can delete.
    create_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Delete Test Collection",
            "description": "Collection for delete testing",
        },
    )
    assert create_response.status_code == 201
    collection_id = create_response.json()["id"]

    # Act: delete it.
    response = client.delete(f"/api/v1/collections/{collection_id}", headers=headers)
    assert response.status_code == 204

    # Assert: a subsequent fetch reports the collection as gone.
    get_response = client.get(f"/api/v1/collections/{collection_id}", headers=headers)
    assert get_response.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_image_to_collection(client: TestClient, admin_api_key: tuple):
    """Adding an image to a collection succeeds and reports success.

    Uses a synthetic ObjectId for the image; in a full integration test this
    would come from an actual upload.
    """
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection to hold the image.
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Image Collection",
            "description": "Collection for image testing"
        }
    )
    # Fail fast here: without this check a setup failure surfaces later as a
    # confusing KeyError on "id" instead of a clear status mismatch.
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # Mock image ID (in real implementation, this would be from uploaded image)
    image_id = str(ObjectId())

    # Add image to collection
    response = client.post(
        f"/api/v1/collections/{collection_id}/images",
        headers=headers,
        json={
            "image_id": image_id
        }
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert "success" in data
    assert data["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_remove_image_from_collection(client: TestClient, admin_api_key: tuple):
    """An image previously added to a collection can be removed again."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Remove Image Collection",
            "description": "Collection for image removal testing"
        }
    )
    # Assert setup succeeded before dereferencing the body (clearer failure).
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # Mock image ID
    image_id = str(ObjectId())

    # Add image to collection first
    add_response = client.post(
        f"/api/v1/collections/{collection_id}/images",
        headers=headers,
        json={"image_id": image_id}
    )
    assert add_response.status_code == 200

    # Remove image from collection
    response = client.delete(
        f"/api/v1/collections/{collection_id}/images/{image_id}",
        headers=headers
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert data["success"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_images_in_collection(client: TestClient, admin_api_key: tuple):
    """Listing images in a collection returns an images list and a total."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "List Images Collection",
            "description": "Collection for listing images"
        }
    )
    # Assert setup succeeded before dereferencing the body (clearer failure).
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # List images in collection
    response = client.get(
        f"/api/v1/collections/{collection_id}/images",
        headers=headers
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert "images" in data
    assert "total" in data
    assert isinstance(data["images"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_access_control(client: TestClient, user_api_key: tuple):
    """A user must not be able to read collections outside their team."""
    raw_key, _ = user_api_key

    # A random id stands in for a collection belonging to another team.
    other_collection_id = str(ObjectId())

    response = client.get(
        f"/api/v1/collections/{other_collection_id}",
        headers={"X-API-Key": raw_key},
    )

    # The API may either deny (403) or hide (404) the foreign collection.
    assert response.status_code in [403, 404]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_pagination(client: TestClient, admin_api_key: tuple):
    """Paginated listing honours the limit and reports pagination metadata."""
    raw_key, _ = admin_api_key

    response = client.get(
        "/api/v1/collections?page=1&limit=10",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    assert "collections" in body
    # The pagination envelope must describe the total, current page and page count.
    assert "pagination" in body
    pagination = body["pagination"]
    assert "total" in pagination
    assert "page" in pagination
    assert "pages" in pagination
    # No page may exceed the requested limit.
    assert len(body["collections"]) <= 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_search(client: TestClient, admin_api_key: tuple):
    """Search returns only collections whose name/description match the query."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Seed a collection that should match the query "vacation".
    create_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Searchable Collection",
            "description": "This collection contains vacation photos",
        },
    )
    assert create_response.status_code == 201

    # Run the search.
    response = client.get(
        "/api/v1/collections/search?query=vacation",
        headers=headers,
    )

    assert response.status_code == 200
    body = response.json()
    assert "collections" in body
    assert "total" in body

    # Every hit must mention the query term in its name or description.
    for hit in body["collections"]:
        haystack = (hit["name"] + " " + hit["description"]).lower()
        assert "vacation" in haystack
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_statistics(client: TestClient, admin_api_key: tuple):
    """The stats endpoint exposes counts, size and timestamps for a collection."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Stats Collection",
            "description": "Collection for statistics testing"
        }
    )
    # Assert setup succeeded before dereferencing the body (clearer failure).
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # Get collection statistics
    response = client.get(
        f"/api/v1/collections/{collection_id}/stats",
        headers=headers
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert "image_count" in data
    assert "total_size" in data
    assert "created_at" in data
    assert "last_updated" in data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_bulk_add_images_to_collection(client: TestClient, admin_api_key: tuple):
    """Bulk-adding five images reports five images added."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Bulk Add Collection",
            "description": "Collection for bulk operations"
        }
    )
    # Assert setup succeeded before dereferencing the body (clearer failure).
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # Mock multiple image IDs
    image_ids = [str(ObjectId()) for _ in range(5)]

    # Bulk add images to collection
    response = client.post(
        f"/api/v1/collections/{collection_id}/images/bulk",
        headers=headers,
        json={
            "image_ids": image_ids
        }
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert "added_count" in data
    assert data["added_count"] == 5
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_export(client: TestClient, admin_api_key: tuple):
    """Export bundles the collection metadata together with its images."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Export Collection",
            "description": "Collection for export testing",
            "metadata": {"category": "test", "project": "sereact"}
        }
    )
    # Assert setup succeeded before dereferencing the body (clearer failure).
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # Export collection
    response = client.get(
        f"/api/v1/collections/{collection_id}/export",
        headers=headers
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert "collection" in data
    assert "images" in data
    assert data["collection"]["name"] == "Export Collection"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_duplicate_name_validation(client: TestClient, admin_api_key: tuple):
    """Creating two collections with the same name is handled deterministically."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # First creation with the contested name must succeed.
    first = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Duplicate Name Collection",
            "description": "First collection with this name",
        },
    )
    assert first.status_code == 201

    # Second creation reuses the exact same name.
    second = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Duplicate Name Collection",
            "description": "Second collection with same name",
        },
    )

    # Business rules may allow duplicates (201) or reject them (400);
    # either is acceptable as long as the API answers cleanly.
    assert second.status_code in [201, 400]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_metadata_validation(client: TestClient, admin_api_key: tuple):
    """A non-dict metadata value must be rejected as a client error."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Test with invalid metadata structure
    response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Invalid Metadata Collection",
            "description": "Collection with invalid metadata",
            "metadata": "invalid_metadata_type"  # Should be dict
        }
    )

    # Should validate metadata structure. FastAPI/Pydantic emits 422 for
    # body-type validation errors by default; accept 400 as well in case the
    # app remaps validation failures.
    assert response.status_code in (400, 422)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_collection_ownership_transfer(client: TestClient, admin_api_key: tuple):
    """An admin can reassign a collection's owner; the response reflects it."""
    raw_key, _ = admin_api_key
    headers = {"X-API-Key": raw_key}

    # Create a collection
    collection_response = client.post(
        "/api/v1/collections",
        headers=headers,
        json={
            "name": "Transfer Collection",
            "description": "Collection for ownership transfer"
        }
    )
    # Assert setup succeeded before dereferencing the body (clearer failure).
    assert collection_response.status_code == 201
    collection_id = collection_response.json()["id"]

    # Transfer ownership to another user
    new_owner_id = str(ObjectId())
    response = client.put(
        f"/api/v1/collections/{collection_id}/owner",
        headers=headers,
        json={
            "new_owner_id": new_owner_id
        }
    )

    # Check response
    assert response.status_code == 200
    data = response.json()
    assert data["created_by"] == new_owner_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_collection_id(client: TestClient, admin_api_key: tuple):
    """A malformed (non-ObjectId) collection id yields 400 with a detail message."""
    raw_key, _ = admin_api_key

    # "invalid-id" cannot be parsed as an ObjectId.
    response = client.get(
        "/api/v1/collections/invalid-id",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 400
    assert "detail" in response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_nonexistent_collection(client: TestClient, admin_api_key: tuple):
    """A well-formed id that matches nothing yields 404 with a detail message."""
    raw_key, _ = admin_api_key

    # Fresh ObjectId — valid format, but no collection was created with it.
    missing_id = str(ObjectId())
    response = client.get(
        f"/api/v1/collections/{missing_id}",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 404
    assert "detail" in response.json()
|
||||
549
tests/api/test_users.py
Normal file
549
tests/api/test_users.py
Normal file
@ -0,0 +1,549 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from bson import ObjectId
|
||||
|
||||
from src.models.user import UserModel
|
||||
from src.models.team import TeamModel
|
||||
from src.db.repositories.user_repository import user_repository
|
||||
from src.db.repositories.team_repository import team_repository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_user(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
    """An admin can create a user; the response echoes the submitted fields."""
    raw_key, _ = admin_api_key

    payload = {
        "email": "newuser@example.com",
        "name": "New User",
        "team_id": str(test_team.id),
        "is_admin": False,
    }

    # Create the user while authenticated as an admin.
    response = client.post(
        "/api/v1/users",
        headers={"X-API-Key": raw_key},
        json=payload,
    )

    assert response.status_code == 201
    body = response.json()
    assert "id" in body
    assert body["email"] == "newuser@example.com"
    assert body["name"] == "New User"
    assert body["team_id"] == str(test_team.id)
    assert body["is_admin"] is False
    assert "created_at" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_create_user_non_admin(client: TestClient, user_api_key: tuple, test_team: TeamModel):
    """User creation is admin-only; a regular key gets a 403 mentioning admin."""
    raw_key, _ = user_api_key

    # Attempt the privileged operation with a regular user's key.
    response = client.post(
        "/api/v1/users",
        headers={"X-API-Key": raw_key},
        json={
            "email": "unauthorized@example.com",
            "name": "Unauthorized User",
            "team_id": str(test_team.id),
            "is_admin": False,
        },
    )

    assert response.status_code == 403
    body = response.json()
    assert "detail" in body
    # The error message should point at the missing admin privilege.
    assert "admin" in body["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_users_in_team(client: TestClient, admin_api_key: tuple, test_team: TeamModel):
    """Filtering by team_id returns that team's members with a total count."""
    raw_key, _ = admin_api_key

    response = client.get(
        f"/api/v1/users?team_id={test_team.id}",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    assert "users" in body
    assert "total" in body
    # The admin who made the request belongs to this team, so it is non-empty.
    assert body["total"] >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_all_users_admin_only(client: TestClient, admin_api_key: tuple):
    """An admin may list users without a team filter."""
    raw_key, _ = admin_api_key

    # No team_id query parameter — request the cross-team listing.
    response = client.get(
        "/api/v1/users",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    assert "users" in body
    assert "total" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_list_users_non_admin_restricted(client: TestClient, user_api_key: tuple):
    """A non-admin listing users only ever sees their own team's members."""
    raw_key, _ = user_api_key

    # Request the unfiltered listing with a regular user's key.
    response = client.get(
        "/api/v1/users",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    assert "users" in body

    # Every returned user must share one team id (the caller's own team).
    if body["users"]:
        distinct_teams = {member["team_id"] for member in body["users"]}
        assert len(distinct_teams) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_user_by_id(client: TestClient, admin_api_key: tuple, admin_user: UserModel):
    """Fetching a user by id returns all of their persisted fields."""
    raw_key, _ = admin_api_key

    response = client.get(
        f"/api/v1/users/{admin_user.id}",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    # Each field must round-trip the fixture's values.
    assert body["id"] == str(admin_user.id)
    assert body["email"] == admin_user.email
    assert body["name"] == admin_user.name
    assert body["team_id"] == str(admin_user.team_id)
    assert body["is_admin"] == admin_user.is_admin
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_user_self(client: TestClient, user_api_key: tuple, regular_user: UserModel):
    """A regular user may read their own record."""
    raw_key, _ = user_api_key

    # Fetch the record belonging to the key's owner.
    response = client.get(
        f"/api/v1/users/{regular_user.id}",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    assert body["id"] == str(regular_user.id)
    assert body["email"] == regular_user.email
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_user_other_team_forbidden(client: TestClient, user_api_key: tuple):
    """Reading a user who belongs to a different team is forbidden."""
    raw_key, _ = user_api_key

    # Seed a second team directly through the repositories.
    created_team = await team_repository.create(TeamModel(
        name="Other Team",
        description="Another team for testing",
    ))

    # ...and a member of that foreign team.
    created_user = await user_repository.create(UserModel(
        email="other@example.com",
        name="Other User",
        team_id=created_team.id,
        is_admin=False,
    ))

    # Cross-team read with a regular user's key must be rejected.
    response = client.get(
        f"/api/v1/users/{created_user.id}",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 403
    assert "detail" in response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_user(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
    """An admin can update another user's name, email and admin flag."""
    raw_key, _ = admin_api_key

    changes = {
        "name": "Updated User Name",
        "email": "updated@example.com",
        "is_admin": True,
    }

    response = client.put(
        f"/api/v1/users/{regular_user.id}",
        headers={"X-API-Key": raw_key},
        json=changes,
    )

    assert response.status_code == 200
    body = response.json()
    assert body["id"] == str(regular_user.id)
    assert body["name"] == "Updated User Name"
    assert body["email"] == "updated@example.com"
    assert body["is_admin"] is True
    # The update must be timestamped.
    assert "updated_at" in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_user_self(client: TestClient, user_api_key: tuple, regular_user: UserModel):
    """A user may edit their own name/email but cannot touch is_admin."""
    raw_key, _ = user_api_key

    # Deliberately omit is_admin: self-service updates are limited to
    # basic profile fields.
    response = client.put(
        f"/api/v1/users/{regular_user.id}",
        headers={"X-API-Key": raw_key},
        json={
            "name": "Self Updated Name",
            "email": "self_updated@example.com",
        },
    )

    assert response.status_code == 200
    body = response.json()
    assert body["name"] == "Self Updated Name"
    assert body["email"] == "self_updated@example.com"
    # The admin flag must survive the self-update untouched.
    assert body["is_admin"] == regular_user.is_admin
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_user_non_admin_forbidden(client: TestClient, user_api_key: tuple, admin_user: UserModel):
    """A regular user may not modify somebody else's record."""
    raw_key, _ = user_api_key

    # Attempt to edit the admin's record with a regular key.
    response = client.put(
        f"/api/v1/users/{admin_user.id}",
        headers={"X-API-Key": raw_key},
        json={
            "name": "Unauthorized Update",
            "is_admin": False,
        },
    )

    assert response.status_code == 403
    assert "detail" in response.json()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_user(client: TestClient, admin_api_key: tuple):
    """An admin can delete a user; the record disappears from the repository."""
    raw_key, _ = admin_api_key

    # Seed a throwaway user directly via the repository.
    created_user = await user_repository.create(UserModel(
        email="delete@example.com",
        name="User to Delete",
        team_id=ObjectId(),
        is_admin=False,
    ))

    # Delete through the API.
    response = client.delete(
        f"/api/v1/users/{created_user.id}",
        headers={"X-API-Key": raw_key},
    )
    assert response.status_code == 204

    # Confirm the deletion at the repository level as well.
    assert await user_repository.get_by_id(created_user.id) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_delete_user_non_admin(client: TestClient, user_api_key: tuple, admin_user: UserModel):
    """User deletion is admin-only; a regular key gets a 403 mentioning admin."""
    raw_key, _ = user_api_key

    response = client.delete(
        f"/api/v1/users/{admin_user.id}",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 403
    body = response.json()
    assert "detail" in body
    # The rejection should name the missing admin privilege.
    assert "admin" in body["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_user_activity(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
    """The activity endpoint reports usage statistics for a user."""
    raw_key, _ = admin_api_key

    response = client.get(
        f"/api/v1/users/{regular_user.id}/activity",
        headers={"X-API-Key": raw_key},
    )

    assert response.status_code == 200
    body = response.json()
    # The activity payload must carry all expected statistic fields.
    for field in ("user_id", "images_uploaded", "last_login", "api_key_usage"):
        assert field in body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_change_user_team(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
    """An admin can move a user into another team."""
    raw_key, _ = admin_api_key

    # Seed a destination team directly through the repository.
    created_team = await team_repository.create(TeamModel(
        name="New Team",
        description="A new team for testing",
    ))

    # Reassign the user's team through the API.
    response = client.put(
        f"/api/v1/users/{regular_user.id}/team",
        headers={"X-API-Key": raw_key},
        json={"team_id": str(created_team.id)},
    )

    assert response.status_code == 200
    assert response.json()["team_id"] == str(created_team.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_search(client: TestClient, admin_api_key: tuple):
    """Test searching for users by email or name."""
    raw_key, _ = admin_api_key
    auth_headers = {"X-API-Key": raw_key}

    # Run a search that should match the admin account(s).
    response = client.get("/api/v1/users/search?query=admin", headers=auth_headers)

    # The search endpoint must return a result list plus a total count.
    assert response.status_code == 200
    body = response.json()
    assert "users" in body
    assert "total" in body

    # Every returned user must match the query in either email or name.
    for matched in body["users"]:
        assert any("admin" in matched[field].lower() for field in ("email", "name"))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_pagination(client: TestClient, admin_api_key: tuple):
    """Test user listing with pagination."""
    raw_key, _ = admin_api_key
    auth_headers = {"X-API-Key": raw_key}

    # Request the first page, capped at 10 users.
    response = client.get("/api/v1/users?page=1&limit=10", headers=auth_headers)

    assert response.status_code == 200
    body = response.json()
    assert "users" in body
    assert "pagination" in body

    # The pagination envelope must report totals and page bookkeeping.
    for field in ("total", "page", "pages"):
        assert field in body["pagination"]

    # The requested page size must be honoured.
    assert len(body["users"]) <= 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_user_id(client: TestClient, admin_api_key: tuple):
    """Test handling of invalid user IDs."""
    raw_key, _ = admin_api_key
    auth_headers = {"X-API-Key": raw_key}

    # A syntactically malformed ObjectId must be rejected as a bad request.
    response = client.get("/api/v1/users/invalid-id", headers=auth_headers)

    assert response.status_code == 400
    body = response.json()
    assert "detail" in body
    assert "invalid" in body["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_nonexistent_user(client: TestClient, admin_api_key: tuple):
    """Test handling of nonexistent user IDs."""
    raw_key, _ = admin_api_key
    auth_headers = {"X-API-Key": raw_key}

    # A well-formed but unknown ObjectId should produce a 404.
    missing_id = str(ObjectId())
    response = client.get(f"/api/v1/users/{missing_id}", headers=auth_headers)

    assert response.status_code == 404
    body = response.json()
    assert "detail" in body
    assert "not found" in body["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_duplicate_email_validation(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
    """Test that duplicate emails are not allowed."""
    raw_key, _ = admin_api_key
    auth_headers = {"X-API-Key": raw_key}

    # Attempt to create a second account reusing an existing email address.
    new_user_payload = {
        "email": regular_user.email,  # Duplicate email
        "name": "Duplicate Email User",
        "team_id": str(regular_user.team_id),
        "is_admin": False,
    }
    response = client.post("/api/v1/users", headers=auth_headers, json=new_user_payload)

    # The API must reject the request and blame the email field.
    assert response.status_code == 400
    body = response.json()
    assert "detail" in body
    assert "email" in body["detail"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_role_management(client: TestClient, admin_api_key: tuple, regular_user: UserModel):
    """Test managing user roles and permissions."""
    raw_key, _ = admin_api_key
    auth_headers = {"X-API-Key": raw_key}
    role_url = f"/api/v1/users/{regular_user.id}/role"

    # Promote the user to admin, then demote back to regular, asserting the
    # returned flag after each transition.
    for target_role in (True, False):
        response = client.put(role_url, headers=auth_headers, json={"is_admin": target_role})
        assert response.status_code == 200
        assert response.json()["is_admin"] is target_role
|
||||
# ---------------------------------------------------------------------------
# Diff-viewer artifact (start of new file): tests/auth/test_security.py (268 lines)
# ---------------------------------------------------------------------------
|
||||
import pytest
|
||||
import hashlib
|
||||
from datetime import datetime, timedelta
|
||||
from bson import ObjectId
|
||||
|
||||
from src.auth.security import (
|
||||
generate_api_key,
|
||||
hash_api_key,
|
||||
verify_api_key,
|
||||
create_access_token,
|
||||
verify_token
|
||||
)
|
||||
from src.models.api_key import ApiKeyModel
|
||||
from src.models.user import UserModel
|
||||
from src.models.team import TeamModel
|
||||
|
||||
|
||||
class TestApiKeySecurity:
    """Test API key generation and validation security."""

    def test_generate_api_key(self):
        """Test API key generation produces unique, secure keys."""
        team_id = str(ObjectId())
        user_id = str(ObjectId())

        # Two keys minted for the same identity must never collide.
        first_key, first_hash = generate_api_key(team_id, user_id)
        second_key, second_hash = generate_api_key(team_id, user_id)
        assert first_key != second_key
        assert first_hash != second_hash

        # Both the raw key and its hash must be long enough to resist guessing.
        assert len(first_key) >= 32
        assert len(first_hash) >= 32

        # The raw key embeds the owning team or user identifier.
        assert team_id in first_key or user_id in first_key

    def test_hash_api_key_consistency(self):
        """Test that hashing the same key produces the same hash."""
        raw = "test-api-key-123"

        # Hashing is deterministic: identical input, identical digest.
        digest_a = hash_api_key(raw)
        digest_b = hash_api_key(raw)
        assert digest_a == digest_b
        assert len(digest_a) >= 32  # Should be a proper hash length

    def test_verify_api_key_valid(self):
        """Test verifying a valid API key."""
        raw_key, key_hash = generate_api_key(str(ObjectId()), str(ObjectId()))

        # Verification must succeed for the matching key/hash pair.
        assert verify_api_key(raw_key, key_hash) is True

    def test_verify_api_key_invalid(self):
        """Test verifying an invalid API key."""
        raw_key, key_hash = generate_api_key(str(ObjectId()), str(ObjectId()))

        # A mismatched key and a mismatched hash must both be rejected.
        assert verify_api_key("wrong-key", key_hash) is False
        assert verify_api_key(raw_key, "wrong-hash") is False

    def test_api_key_format(self):
        """Test that generated API keys follow the expected format."""
        raw_key, _ = generate_api_key(str(ObjectId()), str(ObjectId()))

        # Keys use a dotted "<prefix>.<secret>" layout with exactly one dot.
        assert "." in raw_key
        parts = raw_key.split(".")
        assert len(parts) == 2

        # The first half is a short readable prefix; the second is hash-sized.
        prefix, secret = parts
        assert len(prefix) >= 8
        assert len(secret) >= 32
|
||||
|
||||
|
||||
class TestTokenSecurity:
    """Test JWT token generation and validation."""

    def test_create_access_token(self):
        """Test creating access tokens."""
        claims = {"user_id": str(ObjectId()), "team_id": str(ObjectId())}
        token = create_access_token(data=claims)

        # A JWT should come back as a reasonably long string.
        assert token is not None
        assert isinstance(token, str)
        assert len(token) > 50  # JWT tokens are typically long

    def test_verify_token_valid(self):
        """Test verifying a valid token."""
        user_id = str(ObjectId())
        team_id = str(ObjectId())
        token = create_access_token(data={"user_id": user_id, "team_id": team_id})

        # Decoding must round-trip the claims we encoded.
        claims = verify_token(token)
        assert claims is not None
        assert claims["user_id"] == user_id
        assert claims["team_id"] == team_id

    def test_verify_token_invalid(self):
        """Test verifying an invalid token."""
        # Garbage, empty, and missing tokens all decode to None.
        for bad_token in ("invalid-token", "", None):
            assert verify_token(bad_token) is None

    def test_token_expiration(self):
        """Test token expiration handling."""
        # Mint a token that expired one second in the past.
        token = create_access_token(
            data={"user_id": str(ObjectId())},
            expires_delta=timedelta(seconds=-1)  # Already expired
        )

        # Expired tokens must fail verification.
        assert verify_token(token) is None
|
||||
|
||||
|
||||
class TestSecurityValidation:
    """Test security validation functions."""

    def test_validate_team_access(self):
        """Test team access validation."""
        from src.auth.security import validate_team_access

        own_team = ObjectId()
        other_team = ObjectId()

        # Access is granted only when the user's team matches the target team.
        assert validate_team_access(str(own_team), str(own_team)) is True
        assert validate_team_access(str(other_team), str(own_team)) is False

    def test_validate_admin_permissions(self):
        """Test admin permission validation."""
        from src.auth.security import validate_admin_permissions

        # Build one admin and one non-admin user.
        admin_user = UserModel(
            email="admin@test.com",
            name="Admin User",
            team_id=ObjectId(),
            is_admin=True
        )
        regular_user = UserModel(
            email="user@test.com",
            name="Regular User",
            team_id=ObjectId(),
            is_admin=False
        )

        # Only the admin account passes the permission check.
        assert validate_admin_permissions(admin_user) is True
        assert validate_admin_permissions(regular_user) is False

    def test_rate_limiting_validation(self):
        """Test rate limiting for API keys."""
        # This would test rate limiting functionality.
        # Implementation depends on the actual rate limiting strategy.
        pass

    def test_api_key_expiration_check(self):
        """Test API key expiration validation."""
        from src.auth.security import is_api_key_valid

        team_id = ObjectId()
        user_id = ObjectId()

        def make_key(name, key_hash, expiry_date):
            # Local factory keeps the two otherwise-identical fixtures in sync.
            return ApiKeyModel(
                key_hash=key_hash,
                user_id=user_id,
                team_id=team_id,
                name=name,
                expiry_date=expiry_date,
                is_active=True
            )

        expired_key = make_key("Expired Key", "test-hash",
                               datetime.utcnow() - timedelta(days=1))
        valid_key = make_key("Valid Key", "test-hash-2",
                             datetime.utcnow() + timedelta(days=30))

        # The expiry date alone decides validity for active keys.
        assert is_api_key_valid(expired_key) is False
        assert is_api_key_valid(valid_key) is True

    def test_inactive_api_key_check(self):
        """Test inactive API key validation."""
        from src.auth.security import is_api_key_valid

        # An unexpired key that has been deactivated must not validate.
        inactive_key = ApiKeyModel(
            key_hash="test-hash",
            user_id=ObjectId(),
            team_id=ObjectId(),
            name="Inactive Key",
            expiry_date=datetime.utcnow() + timedelta(days=30),
            is_active=False
        )
        assert is_api_key_valid(inactive_key) is False
|
||||
|
||||
|
||||
class TestSecurityHeaders:
    """Test security headers and middleware."""

    def test_cors_headers(self):
        """Test CORS header configuration."""
        # Placeholder: CORS configuration checks are not implemented yet.
        pass

    def test_security_headers(self):
        """Test security headers like X-Frame-Options, etc."""
        # Placeholder: security-header assertions are not implemented yet.
        pass

    def test_https_enforcement(self):
        """Test HTTPS enforcement in production."""
        # Placeholder: HTTPS redirect checks are not implemented yet.
        pass
|
||||
|
||||
|
||||
class TestPasswordSecurity:
    """Test password hashing and validation if implemented."""

    def test_password_hashing(self):
        """Test password hashing functionality."""
        # Placeholder: only relevant if password authentication is added.
        pass

    def test_password_validation(self):
        """Test password strength validation."""
        # Placeholder: only relevant if password authentication is added.
        pass
|
||||
# ---------------------------------------------------------------------------
# Diff-viewer artifact (start of new file): tests/services/test_embedding_service.py (360 lines)
# ---------------------------------------------------------------------------
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from bson import ObjectId
|
||||
from io import BytesIO
|
||||
|
||||
from src.services.embedding_service import EmbeddingService
|
||||
from src.models.image import ImageModel
|
||||
|
||||
|
||||
class TestEmbeddingService:
    """Test embedding generation for images.

    All external dependencies (Google Cloud Vision, the feature extractor)
    are mocked, so these tests cover the service's orchestration logic only.
    """

    @pytest.fixture
    def mock_vision_client(self):
        """Mock Google Cloud Vision client that returns an empty annotation."""
        mock_client = MagicMock()
        mock_response = MagicMock()
        # Default to an "empty" Vision response; individual tests override it.
        mock_response.image_properties_annotation.dominant_colors.colors = []
        mock_response.label_annotations = []
        mock_response.object_localizations = []
        mock_response.text_annotations = []
        mock_client.annotate_image.return_value = mock_response
        return mock_client

    @pytest.fixture
    def embedding_service(self, mock_vision_client):
        """Create embedding service with mocked dependencies."""
        with patch('src.services.embedding_service.vision') as mock_vision:
            mock_vision.ImageAnnotatorClient.return_value = mock_vision_client
            service = EmbeddingService()
            # Overwrite the client in case the service built one during init.
            service.client = mock_vision_client
            return service

    @pytest.fixture
    def sample_image_data(self):
        """Create sample image data (a minimal valid 1x1-pixel PNG)."""
        image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xdd\x8d\xb4\x1c\x00\x00\x00\x00IEND\xaeB`\x82'
        return BytesIO(image_data)

    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model with plausible upload metadata."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )

    def test_generate_embedding_from_image(self, embedding_service, sample_image_data):
        """Test generating embeddings from image data."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding = embedding_service.generate_embedding(sample_image_data)

            # The service must pass through the 512-dim list from the extractor.
            assert embedding is not None
            assert len(embedding) == 512
            assert isinstance(embedding, list)
            assert all(isinstance(x, (int, float)) for x in embedding)

    def test_extract_image_features(self, embedding_service, sample_image_data):
        """Test extracting features from images using the Vision API.

        Fix: the object-localization mock previously used MagicMock(name="Cat").
        The `name` kwarg configures the mock's repr name — it does NOT create a
        `.name` attribute — so the service would have seen a child mock instead
        of the string "Cat". The name is now assigned after construction.
        """
        mock_response = MagicMock()
        mock_response.label_annotations = [
            MagicMock(description="cat", score=0.95),
            MagicMock(description="animal", score=0.87),
            MagicMock(description="pet", score=0.82)
        ]
        localized_object = MagicMock(score=0.9)
        localized_object.name = "Cat"  # must be set post-construction (mock `name` gotcha)
        mock_response.object_localizations = [localized_object]
        mock_response.image_properties_annotation.dominant_colors.colors = [
            MagicMock(color=MagicMock(red=255, green=100, blue=50), score=0.8)
        ]

        embedding_service.client.annotate_image.return_value = mock_response

        features = embedding_service.extract_image_features(sample_image_data)

        # The feature dict must expose all three channels of the mock response.
        assert 'labels' in features
        assert 'objects' in features
        assert 'colors' in features
        assert len(features['labels']) == 3
        assert features['labels'][0]['description'] == "cat"
        assert features['labels'][0]['score'] == 0.95

    def test_generate_embedding_with_metadata(self, embedding_service, sample_image_data, sample_image_model):
        """Test generating embeddings together with extracted image metadata."""
        mock_embedding = np.random.rand(512).tolist()
        mock_features = {
            'labels': [{'description': 'cat', 'score': 0.95}],
            'objects': [{'name': 'Cat', 'score': 0.9}],
            'colors': [{'red': 255, 'green': 100, 'blue': 50, 'score': 0.8}]
        }

        with patch.object(embedding_service, '_extract_features') as mock_extract_features, \
             patch.object(embedding_service, 'extract_image_features') as mock_extract_metadata:

            mock_extract_features.return_value = mock_embedding
            mock_extract_metadata.return_value = mock_features

            result = embedding_service.generate_embedding_with_metadata(
                sample_image_data, sample_image_model
            )

            # Result bundles the vector, the Vision metadata, and the model tag.
            assert 'embedding' in result
            assert 'metadata' in result
            assert 'model' in result
            assert len(result['embedding']) == 512
            assert result['metadata']['labels'][0]['description'] == 'cat'
            assert result['model'] == 'clip'  # or whatever model is used

    def test_batch_generate_embeddings(self, embedding_service):
        """Test generating embeddings for multiple images in batch."""
        # Build three (stream, model) pairs to feed the batch API.
        image_batch = []
        for i in range(3):
            image_data = BytesIO(b'fake_image_data_' + str(i).encode())
            image_model = ImageModel(
                filename=f"image{i}.jpg",
                original_filename=f"image{i}.jpg",
                file_size=1024,
                content_type="image/jpeg",
                storage_path=f"images/image{i}.jpg",
                team_id=ObjectId(),
                uploader_id=ObjectId()
            )
            image_batch.append((image_data, image_model))

        mock_embedding = np.random.rand(512).tolist()
        with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate:
            mock_generate.return_value = {
                'embedding': mock_embedding,
                'metadata': {'labels': []},
                'model': 'clip'
            }

            results = embedding_service.batch_generate_embeddings(image_batch)

            # One result per input, each with the full result structure.
            assert len(results) == 3
            assert all('embedding' in result for result in results)
            assert all('metadata' in result for result in results)

    def test_embedding_model_consistency(self, embedding_service, sample_image_data):
        """Test that the same image produces consistent embeddings.

        NOTE(review): both calls return the same mocked vector, so this only
        verifies the plumbing, not real model determinism.
        """
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding1 = embedding_service.generate_embedding(sample_image_data)
            sample_image_data.seek(0)  # Reset stream position
            embedding2 = embedding_service.generate_embedding(sample_image_data)

            # Embeddings should be identical for the same image.
            assert embedding1 == embedding2

    def test_embedding_dimension_validation(self, embedding_service):
        """Test that embeddings are validated against the expected dimension."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # The canonical 512-dim vector passes validation.
            assert embedding_service.validate_embedding_dimensions(mock_embedding) is True

            # A vector of the wrong size must be rejected.
            wrong_embedding = np.random.rand(256).tolist()
            assert embedding_service.validate_embedding_dimensions(wrong_embedding) is False

    def test_handle_unsupported_image_format(self, embedding_service):
        """Test handling of unsupported image formats."""
        invalid_data = BytesIO(b'not_an_image')

        # Non-image bytes should raise a ValueError.
        with pytest.raises(ValueError):
            embedding_service.generate_embedding(invalid_data)

    def test_handle_corrupted_image(self, embedding_service):
        """Test handling of corrupted image data."""
        # A valid PNG signature followed by garbage.
        corrupted_data = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')

        # Corruption should surface as an exception, not silent garbage output.
        with pytest.raises(Exception):
            embedding_service.generate_embedding(corrupted_data)

    def test_vision_api_error_handling(self, embedding_service, sample_image_data):
        """Test handling of Vision API errors."""
        embedding_service.client.annotate_image.side_effect = Exception("Vision API error")

        # Upstream failures must propagate rather than be swallowed.
        with pytest.raises(Exception):
            embedding_service.extract_image_features(sample_image_data)

    def test_embedding_caching(self, embedding_service, sample_image_data):
        """Test caching of embeddings for the same image."""
        # Placeholder: caching is not implemented in the service yet.
        pass

    def test_embedding_quality_metrics(self, embedding_service, sample_image_data):
        """Test quality metrics for generated embeddings."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            embedding = embedding_service.generate_embedding(sample_image_data)

            # Quality is expected to be a normalized [0, 1] score.
            quality_score = embedding_service.calculate_embedding_quality(embedding)
            assert 0 <= quality_score <= 1

    def test_different_image_types(self, embedding_service):
        """Test embedding generation for different image content types."""
        image_types = [
            ('image/jpeg', b'fake_jpeg_data'),
            ('image/png', b'fake_png_data'),
            ('image/webp', b'fake_webp_data')
        ]

        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            for content_type, data in image_types:
                image_data = BytesIO(data)

                # Every supported type should yield a full-size vector.
                embedding = embedding_service.generate_embedding(image_data)
                assert len(embedding) == 512

    def test_large_image_handling(self, embedding_service):
        """Test handling of large images."""
        # Simulated 10MB payload.
        large_image_data = BytesIO(b'x' * (10 * 1024 * 1024))

        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = mock_embedding

            # Large inputs must still produce a normal embedding.
            embedding = embedding_service.generate_embedding(large_image_data)
            assert len(embedding) == 512

    def test_embedding_normalization(self, embedding_service, sample_image_data):
        """Test that embeddings are properly L2-normalized when requested."""
        raw_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            mock_extract.return_value = raw_embedding

            embedding = embedding_service.generate_embedding(sample_image_data, normalize=True)

            # An L2-normalized vector has unit norm.
            norm = np.linalg.norm(embedding)
            assert abs(norm - 1.0) < 0.001  # Allow small floating point errors
|
||||
|
||||
|
||||
class TestEmbeddingServiceIntegration:
    """Integration tests for embedding service with other components.

    Fix: these tests previously relied on the `embedding_service`,
    `sample_image_data` and `sample_image_model` fixtures defined inside
    TestEmbeddingService. pytest does not share fixtures defined on one test
    class with another class, so every test here failed at collection with
    "fixture not found". The required fixtures are now declared locally.
    """

    @pytest.fixture
    def mock_vision_client(self):
        """Mock Google Cloud Vision client that returns an empty annotation."""
        mock_client = MagicMock()
        mock_response = MagicMock()
        mock_response.image_properties_annotation.dominant_colors.colors = []
        mock_response.label_annotations = []
        mock_response.object_localizations = []
        mock_response.text_annotations = []
        mock_client.annotate_image.return_value = mock_response
        return mock_client

    @pytest.fixture
    def embedding_service(self, mock_vision_client):
        """Create embedding service with mocked dependencies."""
        with patch('src.services.embedding_service.vision') as mock_vision:
            mock_vision.ImageAnnotatorClient.return_value = mock_vision_client
            service = EmbeddingService()
            service.client = mock_vision_client
            return service

    @pytest.fixture
    def sample_image_data(self):
        """Create sample image data (a minimal valid 1x1-pixel PNG)."""
        image_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\tpHYs\x00\x00\x0b\x13\x00\x00\x0b\x13\x01\x00\x9a\x9c\x18\x00\x00\x00\nIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xdd\x8d\xb4\x1c\x00\x00\x00\x00IEND\xaeB`\x82'
        return BytesIO(image_data)

    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model with plausible upload metadata."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )

    def test_embedding_to_vector_store_integration(self, embedding_service, sample_image_data, sample_image_model):
        """Test integration with the vector store service."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, 'generate_embedding_with_metadata') as mock_generate, \
             patch('src.services.vector_store.VectorStoreService') as mock_vector_store:

            mock_generate.return_value = {
                'embedding': mock_embedding,
                'metadata': {'labels': [{'description': 'cat', 'score': 0.95}]},
                'model': 'clip'
            }

            mock_store = mock_vector_store.return_value
            mock_store.store_embedding.return_value = 'embedding_id_123'

            # Process image and store embedding
            result = embedding_service.process_and_store_image(
                sample_image_data, sample_image_model
            )

            # The store must be called exactly once and its id surfaced.
            assert result['embedding_id'] == 'embedding_id_123'
            mock_store.store_embedding.assert_called_once()

    def test_pubsub_trigger_integration(self, embedding_service):
        """Test integration with Pub/Sub message processing."""
        # Mock Pub/Sub message
        mock_message = {
            'image_id': str(ObjectId()),
            'storage_path': 'images/test.jpg',
            'team_id': str(ObjectId())
        }

        with patch.object(embedding_service, 'process_image_from_storage') as mock_process:
            mock_process.return_value = {'embedding_id': 'emb123'}

            # Process Pub/Sub message
            result = embedding_service.handle_pubsub_message(mock_message)

            # The handler must forward the message fields in this exact order.
            assert result['embedding_id'] == 'emb123'
            mock_process.assert_called_once_with(
                mock_message['storage_path'],
                mock_message['image_id'],
                mock_message['team_id']
            )

    def test_cloud_function_deployment(self, embedding_service):
        """Test Cloud Function deployment compatibility."""
        # Placeholder: would verify env-var loading / authentication in a
        # Cloud Function environment once such a harness exists.
        pass

    def test_error_recovery_and_retry(self, embedding_service, sample_image_data):
        """Test error recovery and retry mechanisms."""
        mock_embedding = np.random.rand(512).tolist()

        with patch.object(embedding_service, '_extract_features') as mock_extract:
            # First call fails, second succeeds
            mock_extract.side_effect = [Exception("Transient error"), mock_embedding]

            # Should retry and succeed
            embedding = embedding_service.generate_embedding_with_retry(
                sample_image_data, max_retries=2
            )

            assert len(embedding) == 512
            assert mock_extract.call_count == 2
|
||||
# ---------------------------------------------------------------------------
# Diff-viewer artifact (start of new file): tests/services/test_image_processor.py (489 lines)
# ---------------------------------------------------------------------------
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import patch, MagicMock
|
||||
from bson import ObjectId
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from src.services.image_processor import ImageProcessor
|
||||
from src.models.image import ImageModel
|
||||
|
||||
|
||||
class TestImageProcessor:
|
||||
"""Test image processing functionality"""
|
||||
|
||||
    @pytest.fixture
    def image_processor(self):
        """Create image processor instance.

        A fresh ImageProcessor is constructed per test, so no processor
        state can leak between test cases.
        """
        return ImageProcessor()
|
||||
|
||||
    @pytest.fixture
    def sample_image_data(self):
        """Create sample image data.

        Builds an in-memory 800x600 solid-red JPEG and rewinds the stream
        so consumers read it from the beginning.
        """
        # Create a simple test image using PIL
        img = Image.new('RGB', (800, 600), color='red')
        img_bytes = BytesIO()
        img.save(img_bytes, format='JPEG')
        img_bytes.seek(0)  # rewind so the first read sees the JPEG header
        return img_bytes
|
||||
|
||||
    @pytest.fixture
    def sample_png_image(self):
        """Create sample PNG image data.

        Uses RGBA with a semi-transparent red fill so format-conversion
        tests exercise the alpha channel.
        """
        img = Image.new('RGBA', (400, 300), color=(255, 0, 0, 128))
        img_bytes = BytesIO()
        img.save(img_bytes, format='PNG')
        img_bytes.seek(0)  # rewind so the first read sees the PNG header
        return img_bytes
|
||||
|
||||
    @pytest.fixture
    def sample_image_model(self):
        """Create a sample image model with plausible upload metadata."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId()
        )
|
||||
|
||||
    def test_extract_image_metadata(self, image_processor, sample_image_data):
        """Test extracting basic image metadata (dimensions, format, mode)."""
        # Extract metadata
        metadata = image_processor.extract_metadata(sample_image_data)

        # Verify metadata extraction
        assert 'width' in metadata
        assert 'height' in metadata
        assert 'format' in metadata
        assert 'mode' in metadata
        # Dimensions/format must match the 800x600 JPEG built by the fixture.
        assert metadata['width'] == 800
        assert metadata['height'] == 600
        assert metadata['format'] == 'JPEG'
|
||||
|
||||
    def test_extract_exif_data(self, image_processor):
        """Test extracting EXIF data from images."""
        # Create image with EXIF data (simulated)
        img = Image.new('RGB', (100, 100), color='blue')
        img_bytes = BytesIO()
        img.save(img_bytes, format='JPEG')
        img_bytes.seek(0)

        # Extract EXIF data
        exif_data = image_processor.extract_exif_data(img_bytes)

        # Verify EXIF extraction (may be empty for generated images)
        # NOTE(review): PIL-generated JPEGs carry no EXIF tags, so this only
        # checks the return type, not actual tag parsing.
        assert isinstance(exif_data, dict)
|
||||
|
||||
    def test_resize_image(self, image_processor, sample_image_data):
        """Test resizing images while maintaining aspect ratio."""
        # Resize image
        resized_data = image_processor.resize_image(
            sample_image_data,
            max_width=400,
            max_height=300
        )

        # Verify resized image
        assert resized_data is not None

        # Check new dimensions fit the requested bounding box.
        resized_img = Image.open(resized_data)
        assert resized_img.width <= 400
        assert resized_img.height <= 300

        # Aspect ratio should be maintained
        original_ratio = 800 / 600  # the fixture image is 800x600
        new_ratio = resized_img.width / resized_img.height
        assert abs(original_ratio - new_ratio) < 0.01
|
||||
|
||||
    def test_generate_thumbnail(self, image_processor, sample_image_data):
        """Test generating image thumbnails."""
        # Generate thumbnail
        thumbnail_data = image_processor.generate_thumbnail(
            sample_image_data,
            size=(150, 150)
        )

        # Verify thumbnail
        assert thumbnail_data is not None

        # The thumbnail must fit within the requested 150x150 bounding box.
        thumbnail_img = Image.open(thumbnail_data)
        assert thumbnail_img.width <= 150
        assert thumbnail_img.height <= 150
|
||||
|
||||
def test_optimize_image_quality(self, image_processor, sample_image_data):
    """Optimisation should not grow the file beyond a small tolerance."""
    size_before = len(sample_image_data.getvalue())

    optimized = image_processor.optimize_image(
        sample_image_data, quality=85, optimize=True
    )
    assert optimized is not None

    # Allow 10% tolerance: optimisation usually shrinks, never balloons.
    assert len(optimized.getvalue()) <= size_before * 1.1
def test_convert_image_format(self, image_processor, sample_png_image):
    """PNG input can be transcoded to JPEG."""
    converted = image_processor.convert_format(
        sample_png_image, target_format='JPEG'
    )
    assert converted is not None

    # The resulting bytes must decode as an actual JPEG.
    assert Image.open(converted).format == 'JPEG'
def test_detect_image_colors(self, image_processor, sample_image_data):
    """Dominant-colour detection returns bounded, well-formed entries."""
    palette = image_processor.detect_dominant_colors(
        sample_image_data, num_colors=5
    )

    assert isinstance(palette, list)
    assert len(palette) <= 5

    for entry in palette:
        # Every entry pairs an RGB triple with its coverage percentage.
        assert 'rgb' in entry
        assert 'percentage' in entry
        assert len(entry['rgb']) == 3
        assert 0 <= entry['percentage'] <= 100
def test_validate_image_format(self, image_processor, sample_image_data):
    """Validation accepts real images and rejects arbitrary bytes."""
    assert image_processor.validate_image_format(sample_image_data) is True

    # A stream of non-image bytes must be refused.
    garbage = BytesIO(b'not_an_image')
    assert image_processor.validate_image_format(garbage) is False
def test_calculate_image_hash(self, image_processor, sample_image_data):
    """Perceptual hashing is non-empty and deterministic."""
    first = image_processor.calculate_perceptual_hash(sample_image_data)

    assert first is not None
    assert isinstance(first, str)
    assert len(first) > 0

    # Re-hashing the identical bytes must reproduce the digest.
    sample_image_data.seek(0)
    second = image_processor.calculate_perceptual_hash(sample_image_data)
    assert second == first
def test_detect_image_orientation(self, image_processor, sample_image_data):
    """Orientation is one of the four right angles and can be auto-corrected."""
    angle = image_processor.detect_orientation(sample_image_data)
    assert angle in [0, 90, 180, 270]

    # Auto-correction should always hand back usable image data.
    corrected = image_processor.auto_correct_orientation(sample_image_data)
    assert corrected is not None
def test_extract_text_from_image(self, image_processor):
    """OCR delegates to pytesseract and returns its text verbatim."""
    canvas = BytesIO()
    Image.new('RGB', (200, 100), color='white').save(canvas, format='JPEG')
    canvas.seek(0)

    with patch('src.services.image_processor.pytesseract') as mock_ocr:
        mock_ocr.image_to_string.return_value = "Sample text"

        extracted = image_processor.extract_text(canvas)

        # The mocked OCR output must pass through untouched.
        assert extracted == "Sample text"
        mock_ocr.image_to_string.assert_called_once()
def test_batch_process_images(self, image_processor):
    """Batch processing yields one complete result per input image."""
    batch = []
    for i in range(3):
        stream = BytesIO()
        Image.new('RGB', (100, 100), color=(i * 80, 0, 0)).save(
            stream, format='JPEG'
        )
        stream.seek(0)
        batch.append(stream)

    outputs = image_processor.batch_process(
        batch, operations=['resize', 'thumbnail', 'metadata']
    )

    assert len(outputs) == 3
    for item in outputs:
        # Every requested operation must be reflected in the result.
        assert 'metadata' in item
        assert 'resized' in item
        assert 'thumbnail' in item
def test_image_quality_assessment(self, image_processor, sample_image_data):
    """Quality assessment reports the expected metric keys in range."""
    metrics = image_processor.assess_quality(sample_image_data)

    # All four metrics must be present.
    for key in ('sharpness', 'brightness', 'contrast', 'overall_score'):
        assert key in metrics

    # The aggregate score is a 0-100 scale.
    assert 0 <= metrics['overall_score'] <= 100
def test_watermark_addition(self, image_processor, sample_image_data):
    """A text watermark can be stamped without corrupting the JPEG."""
    stamped = image_processor.add_watermark(
        sample_image_data,
        watermark_text="SEREACT",
        position="bottom-right",
        opacity=0.5,
    )
    assert stamped is not None

    # The output must still decode as a valid JPEG.
    assert Image.open(stamped).format == 'JPEG'
def test_image_compression_levels(self, image_processor, sample_image_data):
    """Re-encoding below maximum quality never grows the file."""
    baseline = len(sample_image_data.getvalue())

    for level in [95, 85, 75, 60]:
        packed = image_processor.compress_image(
            sample_image_data, quality=level
        )
        packed_size = len(packed.getvalue())

        # Anything below near-lossless should compress at least as well.
        if level < 95:
            assert packed_size <= baseline

        # Rewind so the next pass reads the full source again.
        sample_image_data.seek(0)
def test_handle_corrupted_image(self, image_processor):
    """Truncated/garbled PNG bytes raise instead of silently succeeding."""
    broken = BytesIO(b'\x89PNG\r\n\x1a\n\x00\x00corrupted')

    with pytest.raises(Exception):
        image_processor.extract_metadata(broken)
def test_large_image_processing(self, image_processor):
    """4000x3000 images are read correctly and resize into 1080p bounds."""
    stream = BytesIO()
    Image.new('RGB', (4000, 3000), color='green').save(
        stream, format='JPEG', quality=95
    )
    stream.seek(0)

    info = image_processor.extract_metadata(stream)
    assert info['width'] == 4000
    assert info['height'] == 3000

    # Rewind and shrink the same source down to a 1920x1080 box.
    stream.seek(0)
    shrunk = image_processor.resize_image(
        stream, max_width=1920, max_height=1080
    )
    shrunk_img = Image.open(shrunk)
    assert shrunk_img.width <= 1920
    assert shrunk_img.height <= 1080
def test_progressive_jpeg_support(self, image_processor, sample_image_data):
    """Conversion to progressive JPEG keeps the file a readable JPEG."""
    progressive = image_processor.convert_to_progressive_jpeg(sample_image_data)
    assert progressive is not None

    # Progressive encoding must still decode as JPEG.
    assert Image.open(progressive).format == 'JPEG'
class TestImageProcessorIntegration:
    """Integration tests for image processor with other services."""

    def test_integration_with_storage_service(self, image_processor, sample_image_data):
        """Processed images are handed off to the storage backend."""
        with patch('src.services.storage.StorageService') as storage_cls:
            storage = storage_cls.return_value
            storage.upload_file.return_value = (
                'images/processed.jpg', 'image/jpeg', 1024, {}
            )

            outcome = image_processor.process_and_upload(
                sample_image_data,
                operations=['resize', 'optimize'],
                team_id=str(ObjectId()),
            )

            # The upload result should surface a storage path.
            assert 'storage_path' in outcome
            storage.upload_file.assert_called_once()

    def test_integration_with_embedding_service(self, image_processor, sample_image_data):
        """Embedding generation is driven from the processed image."""
        with patch('src.services.embedding_service.EmbeddingService') as embed_cls:
            embedder = embed_cls.return_value
            embedder.generate_embedding.return_value = [0.1] * 512

            outcome = image_processor.process_for_embedding(sample_image_data)

            assert 'processed_image' in outcome
            assert 'embedding' in outcome
            embedder.generate_embedding.assert_called_once()

    def test_pubsub_message_processing(self, image_processor):
        """A Pub/Sub payload is routed through storage-based processing."""
        payload = {
            'image_id': str(ObjectId()),
            'storage_path': 'images/raw/test.jpg',
            'operations': ['resize', 'thumbnail', 'optimize'],
        }

        with patch.object(image_processor, 'process_from_storage') as mock_process:
            mock_process.return_value = {
                'processed_path': 'images/processed/test.jpg',
                'thumbnail_path': 'images/thumbnails/test.jpg',
            }

            outcome = image_processor.handle_processing_message(payload)

            assert 'processed_path' in outcome
            mock_process.assert_called_once()

    def test_error_handling_and_retry(self, image_processor, sample_image_data):
        """A transient failure is retried and the second attempt succeeds."""
        with patch.object(image_processor, 'extract_metadata') as mock_extract:
            # First call raises; the retry returns valid metadata.
            mock_extract.side_effect = [
                Exception("Transient error"),
                {'width': 800, 'height': 600, 'format': 'JPEG'},
            ]

            result = image_processor.extract_metadata_with_retry(
                sample_image_data, max_retries=2
            )

            assert result['width'] == 800
            assert mock_extract.call_count == 2
class TestImageProcessorPerformance:
    """Performance tests for image processing."""

    def test_processing_speed_benchmarks(self, image_processor):
        """Metadata extraction stays within a generous time budget."""
        import time

        for width, height in [(100, 100), (500, 500), (1000, 1000)]:
            stream = BytesIO()
            Image.new('RGB', (width, height), color='blue').save(
                stream, format='JPEG'
            )
            stream.seek(0)

            started = time.time()
            info = image_processor.extract_metadata(stream)
            elapsed = time.time() - started

            # Threshold is deliberately loose; tune per environment.
            assert elapsed < 5.0
            assert info['width'] == width
            assert info['height'] == height

    def test_memory_usage_optimization(self, image_processor):
        """Placeholder: memory-usage checks need a profiling harness."""
        # Implementation depends on memory profiling tools.
        pass

    def test_concurrent_processing(self, image_processor):
        """Metadata extraction is safe under a small thread pool."""
        import concurrent.futures

        streams = []
        for i in range(5):
            buf = BytesIO()
            Image.new('RGB', (200, 200), color=(i * 50, 0, 0)).save(
                buf, format='JPEG'
            )
            buf.seek(0)
            streams.append(buf)

        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as pool:
            pending = [
                pool.submit(image_processor.extract_metadata, s)
                for s in streams
            ]
            outputs = [task.result() for task in pending]

        # Every image must come back with basic dimensions.
        assert len(outputs) == 5
        for info in outputs:
            assert 'width' in info
            assert 'height' in info
# --- New file in this commit: tests/services/test_vector_store.py (391 lines) ---
|
||||
import pytest
|
||||
import numpy as np
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from bson import ObjectId
|
||||
|
||||
from src.services.vector_store import VectorStoreService
|
||||
from src.models.image import ImageModel
|
||||
|
||||
|
||||
class TestVectorStoreService:
    """Test vector store operations for semantic search."""

    @pytest.fixture
    def mock_pinecone_index(self):
        """A Pinecone index stand-in with the operations the service uses."""
        index = MagicMock()
        index.upsert = MagicMock()
        index.query = MagicMock()
        index.delete = MagicMock()
        index.describe_index_stats = MagicMock()
        return index

    @pytest.fixture
    def vector_store_service(self, mock_pinecone_index):
        """A service instance wired to the mocked index."""
        with patch('src.services.vector_store.pinecone') as pinecone_mod:
            pinecone_mod.Index.return_value = mock_pinecone_index
            svc = VectorStoreService()
            svc.index = mock_pinecone_index
            return svc

    @pytest.fixture
    def sample_embedding(self):
        """A random 512-dimensional embedding vector."""
        return np.random.rand(512).tolist()

    @pytest.fixture
    def sample_image(self):
        """A representative image document."""
        return ImageModel(
            filename="test-image.jpg",
            original_filename="test_image.jpg",
            file_size=1024,
            content_type="image/jpeg",
            storage_path="images/test-image.jpg",
            team_id=ObjectId(),
            uploader_id=ObjectId(),
            description="A test image",
            tags=["test", "image"],
        )

    def test_store_embedding(self, vector_store_service, sample_embedding, sample_image):
        """Storing an embedding upserts a single well-formed vector."""
        new_id = vector_store_service.store_embedding(
            image_id=str(sample_image.id),
            embedding=sample_embedding,
            metadata={
                "filename": sample_image.filename,
                "team_id": str(sample_image.team_id),
                "tags": sample_image.tags,
                "description": sample_image.description,
            },
        )

        assert new_id is not None
        vector_store_service.index.upsert.assert_called_once()

        # Inspect exactly what was pushed to the index.
        pushed = vector_store_service.index.upsert.call_args[1]['vectors']
        assert len(pushed) == 1
        assert pushed[0]['id'] == new_id
        assert len(pushed[0]['values']) == len(sample_embedding)
        assert 'metadata' in pushed[0]

    def test_search_similar_images(self, vector_store_service, sample_embedding):
        """Similarity search returns scored matches with image metadata."""
        vector_store_service.index.query.return_value = {
            'matches': [
                {
                    'id': 'embedding1',
                    'score': 0.95,
                    'metadata': {
                        'image_id': str(ObjectId()),
                        'filename': 'similar1.jpg',
                        'team_id': str(ObjectId()),
                        'tags': ['cat', 'animal'],
                    },
                },
                {
                    'id': 'embedding2',
                    'score': 0.87,
                    'metadata': {
                        'image_id': str(ObjectId()),
                        'filename': 'similar2.jpg',
                        'team_id': str(ObjectId()),
                        'tags': ['dog', 'animal'],
                    },
                },
            ]
        }

        hits = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
            score_threshold=0.8,
        )

        vector_store_service.index.query.assert_called_once()

        assert len(hits) == 2
        assert hits[0]['score'] == 0.95
        assert hits[1]['score'] == 0.87
        assert all('image_id' in hit for hit in hits)

    def test_search_with_filters(self, vector_store_service, sample_embedding):
        """Metadata filters are forwarded to the index query."""
        team = str(ObjectId())

        vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=team,
            top_k=5,
            filters={"tags": {"$in": ["cat", "dog"]}},
        )

        # The team constraint must land in the query's filter argument.
        kwargs = vector_store_service.index.query.call_args[1]
        assert 'filter' in kwargs
        assert kwargs['filter']['team_id'] == team

    def test_delete_embedding(self, vector_store_service):
        """Deleting an embedding issues a single delete by id."""
        target = "test-embedding-123"

        outcome = vector_store_service.delete_embedding(target)

        vector_store_service.index.delete.assert_called_once_with(ids=[target])
        assert outcome is True

    def test_batch_store_embeddings(self, vector_store_service, sample_embedding):
        """Batch storage upserts all vectors in one call."""
        payload = [
            {
                'image_id': str(ObjectId()),
                'embedding': sample_embedding,
                'metadata': {
                    'filename': f'image{i}.jpg',
                    'team_id': str(ObjectId()),
                    'tags': [f'tag{i}'],
                },
            }
            for i in range(5)
        ]

        ids = vector_store_service.batch_store_embeddings(payload)

        assert len(ids) == 5
        vector_store_service.index.upsert.assert_called_once()

        # All five vectors should travel in one upsert.
        pushed = vector_store_service.index.upsert.call_args[1]['vectors']
        assert len(pushed) == 5

    def test_get_index_stats(self, vector_store_service):
        """Index statistics are fetched and passed through."""
        vector_store_service.index.describe_index_stats.return_value = {
            'total_vector_count': 1000,
            'dimension': 512,
            'index_fullness': 0.1,
        }

        stats = vector_store_service.get_index_stats()

        vector_store_service.index.describe_index_stats.assert_called_once()
        assert stats['total_vector_count'] == 1000
        assert stats['dimension'] == 512

    def test_search_with_score_threshold(self, vector_store_service, sample_embedding):
        """Matches scoring below the threshold are dropped."""
        vector_store_service.index.query.return_value = {
            'matches': [
                {'id': 'emb1', 'score': 0.95, 'metadata': {'image_id': '1'}},
                {'id': 'emb2', 'score': 0.75, 'metadata': {'image_id': '2'}},
                {'id': 'emb3', 'score': 0.65, 'metadata': {'image_id': '3'}},
                {'id': 'emb4', 'score': 0.45, 'metadata': {'image_id': '4'}},
            ]
        }

        surviving = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
            score_threshold=0.7,
        )

        assert len(surviving) == 2
        assert all(hit['score'] >= 0.7 for hit in surviving)

    def test_update_embedding_metadata(self, vector_store_service):
        """Metadata updates report success."""
        ok = vector_store_service.update_embedding_metadata(
            "test-embedding-123",
            {
                'tags': ['updated', 'tag'],
                'description': 'Updated description',
            },
        )

        # Deeper verification depends on the concrete implementation.
        assert ok is True

    def test_search_by_image_id(self, vector_store_service):
        """An embedding can be looked up by its source image id."""
        wanted = str(ObjectId())

        vector_store_service.index.query.return_value = {
            'matches': [
                {
                    'id': 'embedding1',
                    'score': 1.0,
                    'metadata': {
                        'image_id': wanted,
                        'filename': 'target.jpg',
                    },
                }
            ]
        }

        found = vector_store_service.get_embedding_by_image_id(wanted)

        assert found is not None
        assert found['metadata']['image_id'] == wanted

    def test_bulk_delete_embeddings(self, vector_store_service):
        """Bulk deletion forwards the whole id list in one call."""
        doomed = ['emb1', 'emb2', 'emb3']

        ok = vector_store_service.bulk_delete_embeddings(doomed)

        vector_store_service.index.delete.assert_called_once_with(ids=doomed)
        assert ok is True

    def test_search_pagination(self, vector_store_service, sample_embedding):
        """Placeholder: pagination semantics are store-specific."""
        # Implementation depends on how pagination is handled in the vector store.
        pass

    def test_vector_dimension_validation(self, vector_store_service):
        """Embeddings with the wrong dimensionality are rejected."""
        undersized = np.random.rand(256).tolist()  # index expects 512 dims

        with pytest.raises(ValueError):
            vector_store_service.store_embedding(
                image_id=str(ObjectId()),
                embedding=undersized,
                metadata={},
            )

    def test_connection_error_handling(self, vector_store_service):
        """Index connection failures surface to the caller."""
        vector_store_service.index.query.side_effect = Exception("Connection failed")

        with pytest.raises(Exception):
            vector_store_service.search_similar(
                query_embedding=[0.1] * 512,
                team_id=str(ObjectId()),
                top_k=10,
            )

    def test_empty_search_results(self, vector_store_service, sample_embedding):
        """No matches yields an empty list, not an error."""
        vector_store_service.index.query.return_value = {'matches': []}

        hits = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(ObjectId()),
            top_k=10,
        )

        assert hits == []
class TestVectorStoreIntegration:
    """Integration tests for vector store with other services."""

    def test_embedding_lifecycle(self, vector_store_service, sample_embedding, sample_image):
        """Store, find, and delete an embedding end to end."""
        stored_id = vector_store_service.store_embedding(
            image_id=str(sample_image.id),
            embedding=sample_embedding,
            metadata={'filename': sample_image.filename},
        )

        # The stored vector is the only (exact) match for its own query.
        vector_store_service.index.query.return_value = {
            'matches': [
                {
                    'id': stored_id,
                    'score': 1.0,
                    'metadata': {'image_id': str(sample_image.id)},
                }
            ]
        }

        hits = vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=str(sample_image.team_id),
            top_k=1,
        )

        assert len(hits) == 1
        assert hits[0]['id'] == stored_id

        # Finally, removal must succeed.
        assert vector_store_service.delete_embedding(stored_id) is True

    def test_team_isolation(self, vector_store_service, sample_embedding):
        """Searches are constrained to the requesting team."""
        team_a = str(ObjectId())
        team_b = str(ObjectId())

        vector_store_service.index.query.return_value = {
            'matches': [
                {
                    'id': 'emb1',
                    'score': 0.9,
                    'metadata': {'image_id': '1', 'team_id': team_a},
                },
                {
                    'id': 'emb2',
                    'score': 0.8,
                    'metadata': {'image_id': '2', 'team_id': team_b},
                },
            ]
        }

        vector_store_service.search_similar(
            query_embedding=sample_embedding,
            team_id=team_a,
            top_k=10,
        )

        # The team filter must have been pushed down into the query.
        kwargs = vector_store_service.index.query.call_args[1]
        assert 'filter' in kwargs
        assert kwargs['filter']['team_id'] == team_a
# (end of diff view)