338 lines
12 KiB
Python
338 lines
12 KiB
Python
"""
|
|
Vector Database Service for handling image vectors using Qdrant.
|
|
"""
|
|
|
|
import os
|
|
import logging
|
|
from typing import List, Dict, Any, Optional, Tuple
|
|
import numpy as np
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models
|
|
from qdrant_client.http.models import Distance, VectorParams, PointStruct
|
|
import uuid
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class VectorDatabaseService:
    """Service for managing image vectors in a Qdrant vector database.

    Each image is stored as one Qdrant point whose payload carries the
    caller-supplied ``image_id`` plus optional metadata.  Points are looked
    up again by filtering on the ``image_id`` payload field.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        api_key: Optional[str] = None,
        collection_name: str = "image_vectors",
        vector_size: int = 512,
    ):
        """
        Initialize the vector database service.

        Creates the Qdrant client and makes sure the target collection
        exists (creating it when missing).

        Args:
            host: Qdrant server host. Falls back to the ``QDRANT_HOST``
                environment variable, then ``"localhost"``.
            port: Qdrant server port. Falls back to the ``QDRANT_PORT``
                environment variable, then ``6333``.
            api_key: API key for authentication (optional). Falls back to
                the ``QDRANT_API_KEY`` environment variable.
            collection_name: Name of the collection to use.
            vector_size: Dimensionality used when the collection has to be
                created. Defaults to 512, the value that was previously
                hard-coded.

        Raises:
            Exception: Whatever the client raises while listing or creating
                collections is logged and re-raised.
        """
        self.host = host or os.getenv("QDRANT_HOST", "localhost")
        self.port = port or int(os.getenv("QDRANT_PORT", "6333"))
        self.api_key = api_key or os.getenv("QDRANT_API_KEY")
        self.collection_name = collection_name
        self.vector_size = vector_size

        # Initialize Qdrant client
        self.client = QdrantClient(
            host=self.host,
            port=self.port,
            api_key=self.api_key,
        )

        # Ensure collection exists
        self._ensure_collection_exists()

    @staticmethod
    def _build_payload(
        image_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Build the point payload for an image without mutating *metadata*.

        The standard keys (``timestamp``, ``filename``, ``size``, ``format``)
        are always present and default to ``None``; every metadata entry is
        copied on top of them.  The explicit *image_id* argument always wins
        over an ``image_id`` key inside *metadata*, because lookups and
        deletes filter on that payload field.
        """
        payload: Dict[str, Any] = {
            "image_id": image_id,
            "timestamp": None,
            "filename": None,
            "size": None,
            "format": None,
        }
        payload.update(metadata or {})
        # Re-assert the identifier so metadata can never shadow it.
        payload["image_id"] = image_id
        return payload

    @staticmethod
    def _image_id_filter(image_id: str):
        """Return a Qdrant filter matching points whose payload ``image_id`` equals *image_id*."""
        return models.Filter(
            must=[
                models.FieldCondition(
                    key="image_id",
                    match=models.MatchValue(value=image_id),
                )
            ]
        )

    def _ensure_collection_exists(self):
        """Ensure the collection exists, create if it doesn't.

        A new collection is created with ``self.vector_size`` dimensions and
        cosine distance.  Errors are logged and re-raised.
        """
        try:
            existing = self.client.get_collections()
            collection_names = [col.name for col in existing.collections]

            if self.collection_name in collection_names:
                logger.info(f"Collection {self.collection_name} already exists")
                return

            logger.info(f"Creating collection: {self.collection_name}")
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(
                    size=self.vector_size,
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Collection {self.collection_name} created successfully")

        except Exception as e:
            logger.error(f"Error ensuring collection exists: {e}")
            raise

    def add_image_vector(
        self,
        image_id: str,
        vector: List[float],
        metadata: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Add an image vector to the database.

        A fresh random UUID is generated as the point ID on every call.

        Args:
            image_id: Unique identifier for the image
            vector: Image embedding vector
            metadata: Additional metadata for the image. Copied into the
                payload; an ``image_id`` key in it is ignored in favour of
                the *image_id* argument.

        Returns:
            Point ID in the vector database

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            point_id = str(uuid.uuid4())
            point = PointStruct(
                id=point_id,
                vector=vector,
                payload=self._build_payload(image_id, metadata),
            )

            self.client.upsert(
                collection_name=self.collection_name,
                points=[point],
            )

            logger.info(f"Added vector for image {image_id} with point ID {point_id}")
            return point_id

        except Exception as e:
            logger.error(f"Error adding image vector: {e}")
            raise

    def search_similar_images(
        self,
        query_vector: List[float],
        limit: int = 10,
        score_threshold: float = 0.7,
        filter_conditions: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Search for similar images based on vector similarity.

        Args:
            query_vector: Query vector to search for
            limit: Maximum number of results to return
            score_threshold: Minimum similarity score threshold
            filter_conditions: Additional filter conditions; each
                ``key: value`` pair becomes an exact-match condition and all
                of them must hold.

        Returns:
            List of dicts with ``point_id``, ``score``, ``image_id`` and
            ``metadata`` (the full payload) for every hit.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            search_filter = None
            if filter_conditions:
                search_filter = models.Filter(
                    must=[
                        models.FieldCondition(
                            key=key,
                            match=models.MatchValue(value=value),
                        )
                        for key, value in filter_conditions.items()
                    ]
                )

            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_vector,
                query_filter=search_filter,
                limit=limit,
                score_threshold=score_threshold,
            )

            results = []
            for hit in hits:
                # Guard against a missing payload instead of failing on .get().
                payload = hit.payload or {}
                results.append(
                    {
                        "point_id": hit.id,
                        "score": hit.score,
                        "image_id": payload.get("image_id"),
                        "metadata": payload,
                    }
                )

            logger.info(f"Found {len(results)} similar images")
            return results

        except Exception as e:
            logger.error(f"Error searching similar images: {e}")
            raise

    def get_image_vector(self, image_id: str) -> Optional[Dict[str, Any]]:
        """
        Get vector and metadata for a specific image.

        Args:
            image_id: Image identifier

        Returns:
            Dict with ``point_id``, ``vector``, ``image_id`` and ``metadata``
            for the first matching point, or None when nothing matches.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            # scroll() returns a tuple (points, next_page_offset)
            points, _next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=self._image_id_filter(image_id),
                limit=1,
                with_vectors=True,
            )

            if not points:
                return None

            point = points[0]
            payload = point.payload or {}
            return {
                "point_id": point.id,
                "vector": point.vector,
                "image_id": payload.get("image_id"),
                "metadata": payload,
            }

        except Exception as e:
            logger.error(f"Error getting image vector: {e}")
            raise

    def delete_image_vector(self, image_id: str) -> bool:
        """
        Delete the stored vector(s) for a specific image.

        Every point whose payload ``image_id`` matches is removed, so that
        repeated ``add_image_vector`` calls for the same image (each of which
        creates a new point ID) do not leave orphaned duplicates behind.

        Args:
            image_id: Image identifier

        Returns:
            True if at least one matching point existed and the delete was
            issued, False if no point matched.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            image_filter = self._image_id_filter(image_id)

            # Existence probe: one point is enough to decide the return value.
            points, _next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=image_filter,
                limit=1,
            )

            if not points:
                logger.warning(f"No vector found for image {image_id}")
                return False

            # Delete by filter rather than by the single probed ID so that
            # all points belonging to this image are removed.
            self.client.delete(
                collection_name=self.collection_name,
                points_selector=models.FilterSelector(filter=image_filter),
            )
            logger.info(f"Deleted vector for image {image_id}")
            return True

        except Exception as e:
            logger.error(f"Error deleting image vector: {e}")
            raise

    def get_collection_info(self) -> Dict[str, Any]:
        """
        Get information about the collection.

        This method never raises: on failure it logs the error and returns a
        minimal dict with ``status`` set to ``"error"``.

        Returns:
            Dict with ``name``, ``points_count`` and ``status``; plus
            ``vector_size`` and ``distance`` when the response exposes a
            single vector configuration, or ``error`` when the lookup failed.
        """
        try:
            collection_info = self.client.get_collection(self.collection_name)

            # Handle different response formats safely
            result = {
                "name": self.collection_name,
                "points_count": getattr(collection_info, "points_count", 0),
                "status": getattr(collection_info, "status", "unknown"),
            }

            # Walk config -> params -> vectors, stopping at the first
            # missing or empty level.
            config = getattr(collection_info, "config", None)
            params = getattr(config, "params", None) if config else None
            vectors_config = getattr(params, "vectors", None) if params else None

            if vectors_config:
                if hasattr(vectors_config, "size"):
                    result["vector_size"] = vectors_config.size
                if hasattr(vectors_config, "distance"):
                    result["distance"] = str(vectors_config.distance)

            return result

        except Exception as e:
            logger.error(f"Error getting collection info: {e}")
            # Return basic info if detailed info fails
            return {
                "name": self.collection_name,
                "points_count": 0,
                "status": "error",
                "error": str(e),
            }

    def health_check(self) -> bool:
        """
        Check if the vector database is healthy.

        The check consists of listing collections; the result is discarded.

        Returns:
            True if the call succeeds, False if it raises.
        """
        try:
            self.client.get_collections()
            return True
        except Exception as e:
            logger.error(f"Vector database health check failed: {e}")
            return False
|
|
|
|
|
|
# Utility functions for vector operations
|
|
def normalize_vector(vector: List[float]) -> List[float]:
    """Scale *vector* to unit Euclidean length.

    A vector whose norm is zero cannot be scaled and is handed back
    unchanged (the very same object that was passed in).
    """
    as_array = np.array(vector)
    length = np.linalg.norm(as_array)
    if length != 0:
        scaled = as_array / length
        return scaled.tolist()
    return vector
|
|
|
|
|
|
def cosine_similarity(vector1: List[float], vector2: List[float]) -> float:
    """Return the cosine of the angle between *vector1* and *vector2*.

    If either vector has zero length the similarity is reported as 0.0.
    """
    a, b = np.array(vector1), np.array(vector2)

    # The dot product is evaluated before the zero-length guard.
    numerator = np.dot(a, b)
    length_a = np.linalg.norm(a)
    length_b = np.linalg.norm(b)

    if length_a != 0 and length_b != 0:
        return numerator / (length_a * length_b)

    return 0.0