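"""
Pub/Sub-triggered Cloud Function that generates image embeddings.

For each message, the function downloads the referenced image from Cloud
Storage, derives a 512-dimensional feature vector from Google Cloud Vision
annotations (objects, labels, text, faces, dominant colors), upserts the
vector into a Qdrant collection, and records processing status on the image's
document in the Firestore `images` collection.

Configuration comes from the environment: QDRANT_HOST, QDRANT_PORT,
QDRANT_API_KEY, and QDRANT_COLLECTION.
"""
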
import json
import logging
import base64
import hashlib
from datetime import datetime
from typing import Dict, Any, Optional
import functions_framework
from google.cloud import vision
from google.cloud import firestore
from google.cloud import storage
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams, PointStruct
import numpy as np
from PIL import Image
import io
import os
import uuid
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize clients
vision_client = vision.ImageAnnotatorClient()
firestore_client = firestore.Client()
storage_client = storage.Client()
# Initialize Qdrant
QDRANT_HOST = os.environ.get('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.environ.get('QDRANT_PORT', '6333'))
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')
QDRANT_COLLECTION = os.environ.get('QDRANT_COLLECTION', 'image_vectors')
try:
qdrant_client = QdrantClient(
host=QDRANT_HOST,
port=QDRANT_PORT,
api_key=QDRANT_API_KEY
)
# Ensure collection exists
try:
collections = qdrant_client.get_collections()
collection_names = [col.name for col in collections.collections]
if QDRANT_COLLECTION not in collection_names:
logger.info(f"Creating Qdrant collection: {QDRANT_COLLECTION}")
qdrant_client.create_collection(
collection_name=QDRANT_COLLECTION,
vectors_config=VectorParams(
size=512, # Fixed size for image embeddings
distance=Distance.COSINE
)
)
logger.info(f"Collection {QDRANT_COLLECTION} created successfully")
else:
logger.info(f"Collection {QDRANT_COLLECTION} already exists")
except Exception as e:
logger.error(f"Error ensuring Qdrant collection exists: {e}")
qdrant_client = None
except Exception as e:
logger.error(f"Failed to initialize Qdrant client: {e}")
qdrant_client = None
if not qdrant_client:
logger.warning("Qdrant not configured, embeddings will not be stored")
@functions_framework.cloud_event
def process_image_embedding(cloud_event):
"""
Cloud Function triggered by Pub/Sub to process image embeddings
"""
try:
# Decode the Pub/Sub message
message_data = base64.b64decode(cloud_event.data["message"]["data"]).decode('utf-8')
message = json.loads(message_data)
logger.info(f"Processing image embedding task: {message}")
# Extract task data
image_id = message.get('image_id')
storage_path = message.get('storage_path')
team_id = message.get('team_id')
retry_count = message.get('retry_count', 0)
if not all([image_id, storage_path, team_id]):
logger.error(f"Missing required fields in message: {message}")
return
# Update image status to processing
update_image_status(image_id, 'processing', retry_count)
# Process the image
success = process_image(image_id, storage_path, team_id, retry_count)
if success:
logger.info(f"Successfully processed image {image_id}")
update_image_status(image_id, 'success', retry_count)
else:
logger.error(f"Failed to process image {image_id}")
update_image_status(image_id, 'failed', retry_count, "Processing failed")
# Retry logic is handled by Pub/Sub retry policy
# The function will be retried automatically up to 3 times
except Exception as e:
logger.error(f"Error in process_image_embedding: {e}")
# Extract image_id if possible for status update
try:
message_data = base64.b64decode(cloud_event.data["message"]["data"]).decode('utf-8')
message = json.loads(message_data)
image_id = message.get('image_id')
retry_count = message.get('retry_count', 0)
if image_id:
update_image_status(image_id, 'failed', retry_count, str(e))
        except Exception:
            # Best-effort status update; never mask the original error
            pass
        raise
def process_image(image_id: str, storage_path: str, team_id: str, retry_count: int) -> bool:
"""
Process a single image to generate embeddings
Args:
image_id: The ID of the image to process
storage_path: The GCS path of the image
team_id: The team ID that owns the image
retry_count: Current retry count
Returns:
True if processing was successful, False otherwise
"""
try:
        # Download image from Cloud Storage.
        # storage_path is expected as "<bucket>/<object-path>"; tolerate an
        # optional "gs://" prefix.
        path = storage_path[len('gs://'):] if storage_path.startswith('gs://') else storage_path
        bucket_name, _, blob_path = path.partition('/')
        bucket = storage_client.bucket(bucket_name)
        blob = bucket.blob(blob_path)
if not blob.exists():
logger.error(f"Image not found in storage: {storage_path}")
return False
# Download image data
image_data = blob.download_as_bytes()
# Generate embeddings using Google Cloud Vision
embeddings = generate_image_embeddings(image_data)
if embeddings is None:
logger.error(f"Failed to generate embeddings for image {image_id}")
return False
# Store embeddings in Qdrant
if qdrant_client:
point_id = str(uuid.uuid4())
# Prepare metadata
metadata = {
'image_id': image_id,
'team_id': team_id,
'storage_path': storage_path,
'created_at': datetime.utcnow().isoformat(),
'model': 'google-vision-v1'
}
# Create point for Qdrant
point = PointStruct(
id=point_id,
vector=embeddings.tolist(),
payload=metadata
)
# Upsert to Qdrant
qdrant_client.upsert(
collection_name=QDRANT_COLLECTION,
points=[point]
)
logger.info(f"Stored embeddings for image {image_id} in Qdrant with point ID {point_id}")
# Update Firestore with embedding info
update_image_embedding_info(image_id, point_id, 'google-vision-v1')
return True
except Exception as e:
logger.error(f"Error processing image {image_id}: {e}")
return False
def generate_image_embeddings(image_data: bytes) -> Optional[np.ndarray]:
"""
Generate image embeddings using Google Cloud Vision API
Args:
image_data: Binary image data
Returns:
Numpy array of embeddings or None if failed
"""
try:
# Create Vision API image object
image = vision.Image(content=image_data)
# Use object localization to get feature vectors
# This provides rich semantic information about the image
response = vision_client.object_localization(image=image)
if response.error.message:
logger.error(f"Vision API error: {response.error.message}")
return None
        # Extract features from detected objects.
        # Note: the built-in hash() is randomized per Python process and would
        # differ across function instances; use a process-independent hash so
        # stored vectors stay comparable between runs.
        def stable_hash_feature(text: str) -> float:
            return (int(hashlib.sha1(text.encode('utf-8')).hexdigest(), 16) % 1000) / 1000.0

        features = []
        # Get object detection features
        for obj in response.localized_object_annotations:
            # Use object name, confidence, and bounding box geometry as features
            features.extend([
                stable_hash_feature(obj.name),  # Normalized hash of object name
                obj.score,  # Confidence score
                obj.bounding_poly.normalized_vertices[0].x,  # Top-left x
                obj.bounding_poly.normalized_vertices[0].y,  # Top-left y
                obj.bounding_poly.normalized_vertices[2].x - obj.bounding_poly.normalized_vertices[0].x,  # Width
                obj.bounding_poly.normalized_vertices[2].y - obj.bounding_poly.normalized_vertices[0].y,  # Height
            ])
# Also get label detection for additional semantic information
label_response = vision_client.label_detection(image=image)
        for label in label_response.label_annotations[:10]:  # Top 10 labels
            features.extend([
                stable_hash_feature(label.description),  # Normalized hash of label
                label.score  # Confidence score
            ])
# Get text detection for additional context
text_response = vision_client.text_detection(image=image)
        if text_response.text_annotations:
            # Add text features (the first annotation holds the full detected text)
            text_content = text_response.text_annotations[0].description
            text_hash = stable_hash_feature(text_content.lower())
            features.extend([text_hash, len(text_content) / 1000.0])  # Normalized text length
# Get face detection for additional features
face_response = vision_client.face_detection(image=image)
face_count = len(face_response.face_annotations)
features.append(min(face_count / 10.0, 1.0)) # Normalized face count
# Add image properties
try:
# Get image properties
properties_response = vision_client.image_properties(image=image)
if properties_response.image_properties_annotation:
# Add dominant colors as features
colors = properties_response.image_properties_annotation.dominant_colors.colors
for i, color in enumerate(colors[:5]): # Top 5 colors
features.extend([
color.color.red / 255.0,
color.color.green / 255.0,
color.color.blue / 255.0,
color.score
])
except Exception as e:
logger.warning(f"Could not extract image properties: {e}")
# Pad or truncate to fixed size (512 dimensions)
target_size = 512
if len(features) < target_size:
features.extend([0.0] * (target_size - len(features)))
else:
features = features[:target_size]
# Normalize the feature vector
features_array = np.array(features, dtype=np.float32)
norm = np.linalg.norm(features_array)
if norm > 0:
features_array = features_array / norm
return features_array
except Exception as e:
logger.error(f"Error generating embeddings: {e}")
return None
def update_image_status(image_id: str, status: str, retry_count: int, error_message: Optional[str] = None):
"""
Update the image embedding status in Firestore
Args:
image_id: The ID of the image
status: The new status (processing, success, failed)
retry_count: Current retry count
error_message: Error message if status is failed
"""
try:
doc_ref = firestore_client.collection('images').document(image_id)
update_data = {
'embedding_status': status,
'embedding_retry_count': retry_count,
'embedding_last_attempt': datetime.utcnow()
}
if error_message:
update_data['embedding_error'] = error_message
if status == 'success':
update_data['has_embedding'] = True
update_data['embedding_error'] = None # Clear any previous error
doc_ref.update(update_data)
logger.info(f"Updated image {image_id} status to {status}")
except Exception as e:
logger.error(f"Error updating image status: {e}")
def update_image_embedding_info(image_id: str, point_id: str, model: str):
"""
Update the image with embedding information
Args:
image_id: The ID of the image
point_id: The ID of the point in the Qdrant vector database
model: The model used to generate embeddings
"""
try:
doc_ref = firestore_client.collection('images').document(image_id)
update_data = {
'embedding_point_id': point_id,
'embedding_model': model,
'has_embedding': True
}
doc_ref.update(update_data)
logger.info(f"Updated image {image_id} with embedding info")
except Exception as e:
logger.error(f"Error updating image embedding info: {e}")