"""Cloud Function that derives image feature vectors via Cloud Vision and stores them in Qdrant.

The function is triggered by a Pub/Sub message describing an image in Cloud
Storage. It records progress on the matching Firestore ``images`` document.
"""

import base64
import hashlib
import io
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional

import functions_framework
import numpy as np
from google.cloud import firestore
from google.cloud import storage
from google.cloud import vision
from PIL import Image
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, PointStruct, VectorParams

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Length of every stored vector; shared by the Qdrant collection definition and
# the pad/truncate step so the two can never drift apart.
EMBEDDING_SIZE = 512
# Identifier recorded next to every stored vector.
EMBEDDING_MODEL = 'google-vision-v1'

# Initialize clients
vision_client = vision.ImageAnnotatorClient()

# Get Firestore configuration from environment variables
FIRESTORE_PROJECT_ID = os.environ.get('FIRESTORE_PROJECT_ID')
FIRESTORE_DATABASE_NAME = os.environ.get('FIRESTORE_DATABASE_NAME', '(default)')

# Initialize Firestore client with correct project and database
if FIRESTORE_PROJECT_ID:
    firestore_client = firestore.Client(
        project=FIRESTORE_PROJECT_ID,
        database=FIRESTORE_DATABASE_NAME,
    )
else:
    firestore_client = firestore.Client(database=FIRESTORE_DATABASE_NAME)

storage_client = storage.Client()

# Get bucket name from environment variable
GCS_BUCKET_NAME = os.environ.get('GCS_BUCKET_NAME', 'sereact-images')

# Qdrant configuration
QDRANT_HOST = os.environ.get('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.environ.get('QDRANT_PORT', '6333'))
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')
QDRANT_COLLECTION = os.environ.get('QDRANT_COLLECTION', 'image_vectors')
QDRANT_HTTPS = os.environ.get('QDRANT_HTTPS', 'false').lower() == 'true'


def _init_qdrant_client() -> Optional[QdrantClient]:
    """Create the Qdrant client and make sure the target collection exists.

    Returns:
        The ready client, or None when the client cannot be created or the
        collection cannot be listed/created (both cases are logged).
    """
    try:
        client = QdrantClient(
            host=QDRANT_HOST,
            port=QDRANT_PORT,
            api_key=QDRANT_API_KEY,
            https=QDRANT_HTTPS,
        )
    except Exception as e:
        logger.error(f"Failed to initialize Qdrant client: {e}")
        return None

    # Ensure collection exists
    try:
        existing = [col.name for col in client.get_collections().collections]
        if QDRANT_COLLECTION not in existing:
            logger.info(f"Creating Qdrant collection: {QDRANT_COLLECTION}")
            client.create_collection(
                collection_name=QDRANT_COLLECTION,
                vectors_config=VectorParams(
                    size=EMBEDDING_SIZE,
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Collection {QDRANT_COLLECTION} created successfully")
        else:
            logger.info(f"Collection {QDRANT_COLLECTION} already exists")
    except Exception as e:
        logger.error(f"Error ensuring Qdrant collection exists: {e}")
        return None

    return client


qdrant_client = _init_qdrant_client()

if not qdrant_client:
    logger.warning("Qdrant not configured, embeddings will not be stored")


def _stable_hash_feature(text: str) -> float:
    """Map *text* to a value in [0, 1) that is identical in every process.

    The built-in ``hash()`` is salted per interpreter process for ``str``, so
    it would give the same label a different feature value on each function
    instance. A cryptographic digest is used purely for its stability.
    """
    digest = hashlib.sha256(text.encode('utf-8')).digest()
    return (int.from_bytes(digest[:8], 'big') % 1000) / 1000.0


def _warn_on_vision_error(response: Any, feature_name: str) -> None:
    """Log a warning when a secondary Vision response carries an error message."""
    if response.error.message:
        logger.warning(
            "Vision API %s error: %s", feature_name, response.error.message
        )


@functions_framework.cloud_event
def process_image_embedding(cloud_event):
    """
    Cloud Function triggered by Pub/Sub to process image embeddings.

    Decodes the base64 JSON payload found at ``cloud_event.data["message"]["data"]``,
    reads ``image_id``, ``storage_path``, ``team_id`` and the optional
    ``retry_count`` from it, and runs ``process_image``. The Firestore status
    is set to ``processing`` first and then to ``success`` or ``failed``.

    Args:
        cloud_event: The event delivered by the functions framework.

    Raises:
        Exception: Any exception raised while decoding or processing is
            re-raised after the image (when known) has been marked ``failed``.
    """
    # Tracked outside the try body so the error handler can mark the image as
    # failed without decoding the message a second time.
    image_id = None
    retry_count = 0

    try:
        # Decode the Pub/Sub message
        message_data = base64.b64decode(
            cloud_event.data["message"]["data"]
        ).decode('utf-8')
        message = json.loads(message_data)

        logger.info(f"Processing image embedding task: {message}")

        # Extract task data
        image_id = message.get('image_id')
        storage_path = message.get('storage_path')
        team_id = message.get('team_id')
        retry_count = message.get('retry_count', 0)

        if not all([image_id, storage_path, team_id]):
            logger.error(f"Missing required fields in message: {message}")
            return

        # Update image status to processing
        update_image_status(image_id, 'processing', retry_count)

        # Process the image
        if process_image(image_id, storage_path, team_id, retry_count):
            logger.info(f"Successfully processed image {image_id}")
            update_image_status(image_id, 'success', retry_count)
        else:
            logger.error(f"Failed to process image {image_id}")
            update_image_status(image_id, 'failed', retry_count, "Processing failed")
            # NOTE(review): this path returns normally, so the invocation is
            # reported as successful and is presumably not redelivered; only
            # the re-raise below signals a failure to the trigger. Confirm
            # against the subscription's retry policy.

    except Exception as e:
        logger.exception("Error in process_image_embedding: %s", e)
        if image_id:
            update_image_status(image_id, 'failed', retry_count, str(e))
        # Bare raise keeps the original traceback intact.
        raise


def process_image(image_id: str, storage_path: str, team_id: str, retry_count: int) -> bool:
    """
    Process a single image to generate and store its embedding.

    Args:
        image_id: The ID of the image to process
        storage_path: Path of the image inside the ``GCS_BUCKET_NAME`` bucket
        team_id: The team ID that owns the image
        retry_count: Current retry count (accepted for interface
            compatibility; not used by this function)

    Returns:
        True if the vector was stored in Qdrant, False otherwise. Every
        exception is logged and reported as False.
    """
    try:
        # Without a vector store nothing can be persisted; fail before paying
        # for the download and the Vision calls, and so that the caller does
        # not mark the image as having an embedding.
        if not qdrant_client:
            logger.error(
                f"Qdrant client unavailable, cannot store embeddings for image {image_id}"
            )
            return False

        # storage_path is relative to the bucket named by GCS_BUCKET_NAME
        bucket = storage_client.bucket(GCS_BUCKET_NAME)
        blob = bucket.blob(storage_path)

        if not blob.exists():
            logger.error(f"Image not found in storage: {GCS_BUCKET_NAME}/{storage_path}")
            return False

        # Download image data
        image_data = blob.download_as_bytes()

        # Generate embeddings using Google Cloud Vision
        embeddings = generate_image_embeddings(image_data)
        if embeddings is None:
            logger.error(f"Failed to generate embeddings for image {image_id}")
            return False

        # Derive the point ID from the image ID so that a redelivered message
        # overwrites the existing point instead of adding a duplicate.
        point_id = str(uuid.uuid5(uuid.NAMESPACE_OID, str(image_id)))

        # Prepare metadata
        metadata = {
            'image_id': image_id,
            'team_id': team_id,
            'storage_path': storage_path,
            'created_at': datetime.now(timezone.utc).isoformat(),
            'model': EMBEDDING_MODEL,
        }

        # Create point for Qdrant
        point = PointStruct(
            id=point_id,
            vector=embeddings.tolist(),
            payload=metadata,
        )

        # Upsert to Qdrant
        qdrant_client.upsert(
            collection_name=QDRANT_COLLECTION,
            points=[point],
        )

        logger.info(f"Stored embeddings for image {image_id} in Qdrant with point ID {point_id}")

        # Update Firestore with embedding info
        update_image_embedding_info(image_id, point_id, EMBEDDING_MODEL)

        return True

    except Exception as e:
        logger.exception("Error processing image %s: %s", image_id, e)
        return False


def generate_image_embeddings(image_data: bytes) -> Optional[np.ndarray]:
    """
    Build a fixed-length feature vector from Google Cloud Vision annotations.

    Features are appended in this order: per detected object (name hash,
    score, box x, box y, box width, box height), up to 10 labels (description
    hash, score), text (hash of lower-cased text, length / 1000) when text is
    found, face count / 10 capped at 1.0, and up to 5 dominant colours (red,
    green, blue, score). The list is zero-padded or truncated to
    ``EMBEDDING_SIZE`` and scaled to unit L2 norm when the norm is non-zero.

    Args:
        image_data: Binary image data

    Returns:
        float32 numpy array of length ``EMBEDDING_SIZE``, or None when object
        localization reports an error or any exception is raised.
    """
    try:
        # Create Vision API image object
        image = vision.Image(content=image_data)

        # Object localization is the primary signal; its failure aborts.
        response = vision_client.object_localization(image=image)

        if response.error.message:
            logger.error(f"Vision API error: {response.error.message}")
            return None

        features = []

        # Object detection features
        for obj in response.localized_object_annotations:
            vertices = obj.bounding_poly.normalized_vertices
            if len(vertices) >= 3:
                # Vertex 0 and vertex 2 are used as opposite box corners.
                origin, opposite = vertices[0], vertices[2]
                box = [
                    origin.x,
                    origin.y,
                    opposite.x - origin.x,  # Width
                    opposite.y - origin.y,  # Height
                ]
            else:
                # Incomplete polygon: keep the per-object stride with zeros
                # rather than failing the whole image.
                box = [0.0, 0.0, 0.0, 0.0]

            features.append(_stable_hash_feature(obj.name))
            features.append(obj.score)
            features.extend(box)

        # Label detection for additional semantic information
        label_response = vision_client.label_detection(image=image)
        _warn_on_vision_error(label_response, 'label_detection')
        for label in label_response.label_annotations[:10]:  # Top 10 labels
            features.append(_stable_hash_feature(label.description))
            features.append(label.score)

        # Text detection for additional context
        text_response = vision_client.text_detection(image=image)
        _warn_on_vision_error(text_response, 'text_detection')
        if text_response.text_annotations:
            text_content = text_response.text_annotations[0].description
            features.append(_stable_hash_feature(text_content.lower()))
            features.append(len(text_content) / 1000.0)  # Scaled text length

        # Face detection
        face_response = vision_client.face_detection(image=image)
        _warn_on_vision_error(face_response, 'face_detection')
        face_count = len(face_response.face_annotations)
        features.append(min(face_count / 10.0, 1.0))  # Capped face count

        # Image properties are best-effort: a failure here only drops the
        # colour features.
        try:
            properties_response = vision_client.image_properties(image=image)
            if properties_response.image_properties_annotation:
                colors = properties_response.image_properties_annotation.dominant_colors.colors
                for color in colors[:5]:  # Top 5 colors
                    features.extend([
                        color.color.red / 255.0,
                        color.color.green / 255.0,
                        color.color.blue / 255.0,
                        color.score,
                    ])
        except Exception as e:
            logger.warning(f"Could not extract image properties: {e}")

        # Pad or truncate to the fixed vector size
        if len(features) < EMBEDDING_SIZE:
            features.extend([0.0] * (EMBEDDING_SIZE - len(features)))
        else:
            features = features[:EMBEDDING_SIZE]

        # Normalize the feature vector; an all-zero vector is left untouched
        features_array = np.array(features, dtype=np.float32)
        norm = np.linalg.norm(features_array)
        if norm > 0:
            features_array = features_array / norm

        return features_array

    except Exception as e:
        logger.exception("Error generating embeddings: %s", e)
        return None


def update_image_status(image_id: str, status: str, retry_count: int, error_message: Optional[str] = None):
    """
    Update the image embedding status in Firestore.

    Errors are logged and swallowed so that a status write can never mask the
    outcome of the processing itself.

    Args:
        image_id: The ID of the image
        status: The new status (processing, success, failed)
        retry_count: Current retry count
        error_message: Error message if status is failed
    """
    try:
        doc_ref = firestore_client.collection('images').document(image_id)

        update_data = {
            'embedding_status': status,
            'embedding_retry_count': retry_count,
            # Timezone-aware UTC; datetime.utcnow() returns a naive value.
            'embedding_last_attempt': datetime.now(timezone.utc),
        }

        if error_message:
            update_data['embedding_error'] = error_message

        if status == 'success':
            update_data['has_embedding'] = True
            update_data['embedding_error'] = None  # Clear any previous error

        doc_ref.update(update_data)
        logger.info(f"Updated image {image_id} status to {status}")

    except Exception as e:
        logger.error(f"Error updating image status: {e}")


def update_image_embedding_info(image_id: str, point_id: str, model: str):
    """
    Record on the image document which Qdrant point holds its embedding.

    Errors are logged and swallowed.

    Args:
        image_id: The ID of the image
        point_id: The ID of the point in the Qdrant vector database
        model: The model used to generate embeddings
    """
    try:
        doc_ref = firestore_client.collection('images').document(image_id)

        update_data = {
            'embedding_point_id': point_id,
            'embedding_model': model,
            'has_embedding': True,
        }

        doc_ref.update(update_data)
        logger.info(f"Updated image {image_id} with embedding info")

    except Exception as e:
        logger.error(f"Error updating image embedding info: {e}")