2025-05-24 18:40:36 +02:00

338 lines
12 KiB
Python

"""
Vector Database Service for handling image vectors using Qdrant.
"""
import os
import logging
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams, PointStruct
import uuid
logger = logging.getLogger(__name__)
class VectorDatabaseService:
    """Service for managing image vectors in a Qdrant collection.

    Each stored point carries the embedding plus a payload whose ``image_id``
    key links it back to the caller's image identifier; lookups and deletes
    filter on that payload key.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        api_key: Optional[str] = None,
        collection_name: str = "image_vectors",
        vector_size: int = 512,
    ):
        """
        Initialize the vector database service and make sure the collection exists.

        Args:
            host: Qdrant server host; falls back to the ``QDRANT_HOST``
                environment variable, then ``"localhost"``.
            port: Qdrant server port; falls back to ``QDRANT_PORT``, then 6333.
            api_key: API key for authentication; falls back to
                ``QDRANT_API_KEY`` and may remain None.
            collection_name: Name of the collection to use.
            vector_size: Vector dimensionality used only when the collection
                has to be created (ignored if it already exists). Defaults to
                512, the previously hard-coded value.

        Raises:
            Exception: anything the Qdrant client raises while listing or
                creating collections (logged, then re-raised).
        """
        self.host = host or os.getenv("QDRANT_HOST", "localhost")
        self.port = port or int(os.getenv("QDRANT_PORT", "6333"))
        self.api_key = api_key or os.getenv("QDRANT_API_KEY")
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.client = QdrantClient(
            host=self.host,
            port=self.port,
            api_key=self.api_key,
        )
        # Create the collection eagerly so later calls find it in place.
        self._ensure_collection_exists()

    def _ensure_collection_exists(self) -> None:
        """Create the collection (cosine distance) if it is not present yet."""
        try:
            existing = {col.name for col in self.client.get_collections().collections}
            if self.collection_name in existing:
                logger.info("Collection %s already exists", self.collection_name)
                return
            logger.info("Creating collection: %s", self.collection_name)
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE,
                ),
            )
            logger.info("Collection %s created successfully", self.collection_name)
        except Exception as e:
            logger.error("Error ensuring collection exists: %s", e)
            raise

    @staticmethod
    def _build_filter(conditions: Optional[Dict[str, Any]]) -> Optional["models.Filter"]:
        """Turn ``{key: value}`` pairs into a filter where every pair must match exactly.

        Returns None for a missing or empty mapping, meaning "no filter".
        """
        if not conditions:
            return None
        return models.Filter(
            must=[
                models.FieldCondition(key=key, match=models.MatchValue(value=value))
                for key, value in conditions.items()
            ]
        )

    @staticmethod
    def _build_payload(image_id: str, metadata: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the payload stored alongside an image vector.

        The keys ``timestamp``, ``filename``, ``size`` and ``format`` are always
        present (None unless *metadata* supplies them) and every metadata entry
        is copied in. ``image_id`` is (re)assigned last so that an ``image_id``
        key inside *metadata* cannot detach the point from the id used for
        lookups and deletes.
        """
        payload: Dict[str, Any] = {
            "image_id": image_id,
            "timestamp": None,
            "filename": None,
            "size": None,
            "format": None,
        }
        payload.update(metadata or {})
        # Re-assigning an existing key keeps its position but restores the value.
        payload["image_id"] = image_id
        return payload

    def _first_point(self, image_id: str, with_vectors: bool = False):
        """Return one stored point whose payload ``image_id`` matches, or None."""
        points, _next_offset = self.client.scroll(
            collection_name=self.collection_name,
            scroll_filter=self._build_filter({"image_id": image_id}),
            limit=1,
            with_vectors=with_vectors,
        )
        return points[0] if points else None

    def add_image_vector(
        self,
        image_id: str,
        vector: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Add an image vector to the database.

        A new random point ID is generated on every call, so adding the same
        ``image_id`` twice stores two points.

        Args:
            image_id: Identifier for the image (stored in the payload).
            vector: Image embedding vector.
            metadata: Additional payload fields; cannot override ``image_id``.

        Returns:
            Point ID (UUID string) in the vector database.

        Raises:
            Exception: anything the Qdrant client raises (logged, re-raised).
        """
        try:
            point_id = str(uuid.uuid4())
            point = PointStruct(
                id=point_id,
                vector=vector,
                payload=self._build_payload(image_id, metadata),
            )
            self.client.upsert(collection_name=self.collection_name, points=[point])
            logger.info("Added vector for image %s with point ID %s", image_id, point_id)
            return point_id
        except Exception as e:
            logger.error("Error adding image vector: %s", e)
            raise

    def search_similar_images(
        self,
        query_vector: List[float],
        limit: int = 10,
        score_threshold: float = 0.7,
        filter_conditions: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Search for similar images based on vector similarity.

        Args:
            query_vector: Query vector to search for.
            limit: Maximum number of results to return.
            score_threshold: Minimum score, passed through to Qdrant.
            filter_conditions: Payload ``key -> value`` pairs that must all
                match exactly.

        Returns:
            One dict per hit with ``point_id``, ``score``, ``image_id`` and the
            full payload under ``metadata``.

        Raises:
            Exception: anything the Qdrant client raises (logged, re-raised).
        """
        try:
            search_result = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=self._build_filter(filter_conditions),
                limit=limit,
                score_threshold=score_threshold,
            )
            results = [
                {
                    "point_id": hit.id,
                    "score": hit.score,
                    "image_id": (hit.payload or {}).get("image_id"),
                    "metadata": hit.payload or {},
                }
                for hit in search_result
            ]
            logger.info("Found %s similar images", len(results))
            return results
        except Exception as e:
            logger.error("Error searching similar images: %s", e)
            raise

    def get_image_vector(self, image_id: str) -> Optional[Dict[str, Any]]:
        """
        Get vector and metadata for a specific image.

        If several points share the ``image_id``, only the first one returned
        by Qdrant's scroll is reported.

        Args:
            image_id: Image identifier.

        Returns:
            Dict with ``point_id``, ``vector``, ``image_id`` and ``metadata``,
            or None if no point matches.

        Raises:
            Exception: anything the Qdrant client raises (logged, re-raised).
        """
        try:
            point = self._first_point(image_id, with_vectors=True)
            if point is None:
                return None
            payload = point.payload or {}
            return {
                "point_id": point.id,
                "vector": point.vector,
                "image_id": payload.get("image_id"),
                "metadata": payload,
            }
        except Exception as e:
            logger.error("Error getting image vector: %s", e)
            raise

    def delete_image_vector(self, image_id: str) -> bool:
        """
        Delete every stored vector for a specific image.

        Args:
            image_id: Image identifier.

        Returns:
            True if at least one matching point existed and a delete was
            issued, False if nothing matched.

        Raises:
            Exception: anything the Qdrant client raises (logged, re-raised).
        """
        try:
            if self._first_point(image_id) is None:
                logger.warning("No vector found for image %s", image_id)
                return False
            # Delete by filter, not by the single point found above:
            # add_image_vector mints a fresh point ID per call, so one image_id
            # may own several points and deleting only one would orphan the rest.
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=models.FilterSelector(
                    filter=self._build_filter({"image_id": image_id})
                ),
            )
            logger.info("Deleted vector for image %s", image_id)
            return True
        except Exception as e:
            logger.error("Error deleting image vector: %s", e)
            raise

    def get_collection_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        Never raises: on any error a minimal dict with ``status`` ``"error"``
        and the error text is returned instead.

        Returns:
            Dict with ``name``, ``points_count`` (always an int), ``status`` and,
            when the client exposes them, ``vector_size`` and ``distance``.
        """
        try:
            collection_info = self.client.get_collection(self.collection_name)
            result = {
                "name": self.collection_name,
                # Guard against a None value as well as a missing attribute so
                # callers always receive an int.
                "points_count": getattr(collection_info, "points_count", None) or 0,
                "status": getattr(collection_info, "status", "unknown"),
            }
            # Walk config -> params -> vectors defensively; getattr on None
            # yields the default, so a missing level simply ends the walk.
            config = getattr(collection_info, "config", None)
            params = getattr(config, "params", None)
            vectors_config = getattr(params, "vectors", None)
            if vectors_config:
                if hasattr(vectors_config, "size"):
                    result["vector_size"] = vectors_config.size
                if hasattr(vectors_config, "distance"):
                    result["distance"] = str(vectors_config.distance)
            return result
        except Exception as e:
            logger.error("Error getting collection info: %s", e)
            return {
                "name": self.collection_name,
                "points_count": 0,
                "status": "error",
                "error": str(e),
            }

    def health_check(self) -> bool:
        """
        Check if the vector database is reachable.

        Returns:
            True if listing collections succeeds, False (after logging) otherwise.
        """
        try:
            self.client.get_collections()
        except Exception as e:
            logger.error("Vector database health check failed: %s", e)
            return False
        return True
# Utility functions for vector operations
def normalize_vector(vector: List[float]) -> List[float]:
    """Return *vector* scaled to unit (L2) length as a new list of floats.

    A zero-norm vector (including an empty one) cannot be normalised. It is
    returned unscaled but still converted to a fresh list of floats. The
    result is never the caller's own object, so mutating it cannot alter the
    input, and the return type is a ``list`` even for tuple or ndarray input.
    """
    vector_array = np.asarray(vector, dtype=float)
    norm = np.linalg.norm(vector_array)
    if norm == 0:
        return vector_array.tolist()
    return (vector_array / norm).tolist()
def cosine_similarity(vector1: List[float], vector2: List[float]) -> float:
    """Return the cosine similarity of two equal-length vectors.

    Returns 0.0 when either vector has zero norm, where similarity is
    undefined. The result is a plain Python ``float`` clipped to [-1.0, 1.0].
    Without the clip, floating-point rounding can push it slightly outside
    that range (identical vectors like ``[1, 1, 1]`` would give
    1.0000000000000002). NaN inputs still propagate as NaN.

    Raises:
        ValueError: if the vectors have different lengths (raised by ``np.dot``).
    """
    v1 = np.asarray(vector1, dtype=float)
    v2 = np.asarray(vector2, dtype=float)
    # Computed before the zero-norm check so a length mismatch always raises.
    dot_product = np.dot(v1, v2)
    norm_v1 = np.linalg.norm(v1)
    norm_v2 = np.linalg.norm(v2)
    if norm_v1 == 0 or norm_v2 == 0:
        return 0.0
    # np.clip (unlike min/max) keeps NaN as NaN instead of mapping it to a bound.
    return float(np.clip(dot_product / (norm_v1 * norm_v2), -1.0, 1.0))