"""Cloud Function that generates vector embeddings for uploaded images.

Triggered by Pub/Sub, it downloads an image from Cloud Storage, embeds it with
the Vertex AI multimodal embedding model, upserts the vector into a Qdrant
collection, and records progress on the image's Firestore document
(collection ``images``, document ``<image_id>``).

Configuration (environment variables):
    GOOGLE_CLOUD_PROJECT / GCP_PROJECT  Vertex AI project.
    VERTEX_AI_LOCATION                  Vertex AI region (default ``us-central1``).
    FIRESTORE_PROJECT_ID                Firestore project (optional).
    FIRESTORE_DATABASE_NAME             Firestore database (default ``(default)``).
    GCS_BUCKET_NAME                     Bucket holding the images
                                        (default ``sereact-images``).
    QDRANT_HOST / QDRANT_PORT           Qdrant endpoint (default ``localhost:6333``).
    QDRANT_API_KEY                      Optional Qdrant API key.
    QDRANT_COLLECTION                   Target collection (default ``image_vectors``).
    QDRANT_HTTPS                        ``true`` to use HTTPS (default ``false``).
"""

import base64
import functools
import io  # noqa: F401  NOTE(review): currently unused in this module
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional  # noqa: F401  NOTE(review): Any/Dict unused here

import functions_framework
import numpy as np
import vertexai
from google.cloud import firestore
from google.cloud import storage
from PIL import Image  # noqa: F401  NOTE(review): currently unused in this module
from qdrant_client import QdrantClient
from qdrant_client.http import models  # noqa: F401  NOTE(review): currently unused here
from qdrant_client.http.models import Distance, VectorParams, PointStruct
from vertexai.vision_models import MultiModalEmbeddingModel, Image as VertexImage

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Vertex AI model used to embed images, and the label stored alongside each vector.
EMBEDDING_MODEL_NAME = "multimodalembedding@001"
EMBEDDING_MODEL_LABEL = "vertex-ai-multimodal"
# Vector size the Qdrant collection is created with (Vertex AI multimodal embedding size).
EMBEDDING_DIMENSION = 1408

# Initialize Vertex AI
PROJECT_ID = os.environ.get('GOOGLE_CLOUD_PROJECT') or os.environ.get('GCP_PROJECT')
LOCATION = os.environ.get('VERTEX_AI_LOCATION', 'us-central1')

if PROJECT_ID:
    vertexai.init(project=PROJECT_ID, location=LOCATION)
    logger.info(f"Initialized Vertex AI with project {PROJECT_ID} in location {LOCATION}")
else:
    logger.error("PROJECT_ID not found in environment variables")

# Get Firestore configuration from environment variables
FIRESTORE_PROJECT_ID = os.environ.get('FIRESTORE_PROJECT_ID')
FIRESTORE_DATABASE_NAME = os.environ.get('FIRESTORE_DATABASE_NAME', '(default)')

# Initialize Firestore client with the configured project (if any) and database
if FIRESTORE_PROJECT_ID:
    firestore_client = firestore.Client(project=FIRESTORE_PROJECT_ID, database=FIRESTORE_DATABASE_NAME)
else:
    firestore_client = firestore.Client(database=FIRESTORE_DATABASE_NAME)

storage_client = storage.Client()

# Get bucket name from environment variable
GCS_BUCKET_NAME = os.environ.get('GCS_BUCKET_NAME', 'sereact-images')

# Qdrant configuration
QDRANT_HOST = os.environ.get('QDRANT_HOST', 'localhost')
QDRANT_PORT = int(os.environ.get('QDRANT_PORT', '6333'))
QDRANT_API_KEY = os.environ.get('QDRANT_API_KEY')
QDRANT_COLLECTION = os.environ.get('QDRANT_COLLECTION', 'image_vectors')
QDRANT_HTTPS = os.environ.get('QDRANT_HTTPS', 'false').lower() == 'true'


def _init_qdrant_client() -> Optional[QdrantClient]:
    """Create the Qdrant client and make sure the target collection exists.

    Returns:
        The client, or None (after logging the error) if the client could not
        be constructed or the collection check/creation failed.
    """
    try:
        client = QdrantClient(
            host=QDRANT_HOST,
            port=QDRANT_PORT,
            api_key=QDRANT_API_KEY,
            https=QDRANT_HTTPS,
        )
    except Exception as e:
        logger.error(f"Failed to initialize Qdrant client: {e}")
        return None

    try:
        existing = {col.name for col in client.get_collections().collections}
        if QDRANT_COLLECTION in existing:
            logger.info(f"Collection {QDRANT_COLLECTION} already exists")
        else:
            logger.info(f"Creating Qdrant collection: {QDRANT_COLLECTION}")
            client.create_collection(
                collection_name=QDRANT_COLLECTION,
                vectors_config=VectorParams(
                    size=EMBEDDING_DIMENSION,
                    distance=Distance.COSINE,
                ),
            )
            logger.info(f"Collection {QDRANT_COLLECTION} created successfully")
    except Exception as e:
        logger.error(f"Error ensuring Qdrant collection exists: {e}")
        return None

    return client


qdrant_client = _init_qdrant_client()

if qdrant_client is None:
    logger.warning("Qdrant not configured, embeddings will not be stored")


def _utcnow() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def _point_id_for_image(image_id: str) -> str:
    """Return a deterministic Qdrant point ID (a UUID string) for ``image_id``.

    Because the ID depends only on the image ID, reprocessing the same image
    (e.g. a redelivered Pub/Sub message) overwrites its existing point via
    ``upsert`` instead of adding a duplicate vector.
    """
    return str(uuid.uuid5(uuid.NAMESPACE_URL, f"images/{image_id}"))


def _decode_pubsub_message(cloud_event) -> Any:
    """Decode the base64-encoded JSON payload of a Pub/Sub CloudEvent."""
    message_data = base64.b64decode(cloud_event.data["message"]["data"]).decode('utf-8')
    return json.loads(message_data)


@functions_framework.cloud_event
def process_image_embedding(cloud_event):
    """
    Cloud Function triggered by Pub/Sub to process image embeddings

    The message data is base64-encoded JSON containing ``image_id``,
    ``storage_path`` and ``team_id`` (all required) and an optional
    ``retry_count`` (default 0).

    Behaviour:
      * Payloads that are not a JSON object or lack a required field are
        logged and dropped (the function returns normally).
      * Otherwise the Firestore status goes ``processing`` and then
        ``success`` or ``failed``; a processing failure also returns normally.
      * Any unexpected exception (e.g. an undecodable payload) is logged,
        recorded as ``failed`` when the image ID is known, and re-raised.

    Args:
        cloud_event: Pub/Sub CloudEvent whose ``data["message"]["data"]``
            holds the encoded task.
    """
    message = None
    try:
        message = _decode_pubsub_message(cloud_event)
        logger.info(f"Processing image embedding task: {message}")

        # A JSON payload such as `null` or a list has no fields to read; drop it
        # rather than raising, which would make it a permanently retried message.
        if not isinstance(message, dict):
            logger.error(f"Message payload is not a JSON object: {message}")
            return

        # Extract task data
        image_id = message.get('image_id')
        storage_path = message.get('storage_path')
        team_id = message.get('team_id')
        retry_count = message.get('retry_count', 0)

        if not all([image_id, storage_path, team_id]):
            logger.error(f"Missing required fields in message: {message}")
            return

        update_image_status(image_id, 'processing', retry_count)

        if process_image(image_id, storage_path, team_id, retry_count):
            logger.info(f"Successfully processed image {image_id}")
            update_image_status(image_id, 'success', retry_count)
        else:
            logger.error(f"Failed to process image {image_id}")
            update_image_status(image_id, 'failed', retry_count, "Processing failed")
            # NOTE: returning normally acknowledges the event, so this failure is
            # NOT redelivered. Only an exception escaping this function is
            # retried, and only if retries are enabled for the trigger.
    except Exception as e:
        logger.exception(f"Error in process_image_embedding: {e}")
        # Record the failure when the payload decoded far enough to identify the image.
        if isinstance(message, dict) and message.get('image_id'):
            update_image_status(message['image_id'], 'failed', message.get('retry_count', 0), str(e))
        raise


def process_image(image_id: str, storage_path: str, team_id: str, retry_count: int) -> bool:
    """
    Process a single image to generate embeddings and store them in Qdrant

    Args:
        image_id: The ID of the image to process (Firestore document ID)
        storage_path: Path of the image inside GCS_BUCKET_NAME
            (e.g. "team_id/filename.jpg")
        team_id: The team ID that owns the image (stored in the point payload)
        retry_count: Current retry count (accepted for interface compatibility;
            not used here)

    Returns:
        True only if the embedding was generated AND upserted into Qdrant;
        False otherwise (including when Qdrant is not configured). Errors are
        logged, never raised.
    """
    # Without Qdrant nothing can be stored; reporting success here would mark
    # the image as having an embedding that does not exist. Check before doing
    # the download and the (billable) Vertex AI call.
    if qdrant_client is None:
        logger.error(f"Qdrant is not configured; cannot store embeddings for image {image_id}")
        return False

    try:
        bucket = storage_client.bucket(GCS_BUCKET_NAME)
        blob = bucket.blob(storage_path)

        if not blob.exists():
            logger.error(f"Image not found in storage: {GCS_BUCKET_NAME}/{storage_path}")
            return False

        image_data = blob.download_as_bytes()

        embeddings = generate_image_embeddings(image_data)
        if embeddings is None:
            logger.error(f"Failed to generate embeddings for image {image_id}")
            return False

        point_id = _point_id_for_image(image_id)
        metadata = {
            'image_id': image_id,
            'team_id': team_id,
            'storage_path': storage_path,
            # UTC, serialized without an offset to keep the payload format unchanged.
            'created_at': _utcnow().replace(tzinfo=None).isoformat(),
            'model': EMBEDDING_MODEL_LABEL,
        }

        qdrant_client.upsert(
            collection_name=QDRANT_COLLECTION,
            points=[
                PointStruct(
                    id=point_id,
                    vector=embeddings.tolist(),
                    payload=metadata,
                )
            ],
        )
        logger.info(f"Stored embeddings for image {image_id} in Qdrant with point ID {point_id}")

        update_image_embedding_info(image_id, point_id, EMBEDDING_MODEL_LABEL)
        return True

    except Exception as e:
        logger.exception(f"Error processing image {image_id}: {e}")
        return False


@functools.lru_cache(maxsize=1)
def _get_embedding_model() -> MultiModalEmbeddingModel:
    """Load the Vertex AI embedding model once per instance and reuse it.

    lru_cache does not cache exceptions, so a failed load is retried on the
    next call.
    """
    return MultiModalEmbeddingModel.from_pretrained(EMBEDDING_MODEL_NAME)


def generate_image_embeddings(image_data: bytes) -> Optional[np.ndarray]:
    """
    Generate image embeddings using the Vertex AI multimodal embedding model

    Args:
        image_data: Binary image data

    Returns:
        A float32 numpy array scaled to unit L2 norm (left unscaled if its norm
        is 0), or None if the model returned no image embedding or an error
        occurred (errors are logged, not raised).
    """
    try:
        vertex_image = VertexImage(image_data)
        embeddings = _get_embedding_model().get_embeddings(image=vertex_image)

        if embeddings is None or embeddings.image_embedding is None:
            logger.error("Failed to generate embeddings - no image embedding returned")
            return None

        embeddings_array = np.array(embeddings.image_embedding, dtype=np.float32)

        # Normalize to unit length; a zero vector is returned as-is.
        norm = np.linalg.norm(embeddings_array)
        if norm > 0:
            embeddings_array = embeddings_array / norm

        logger.info(f"Generated embeddings with shape: {embeddings_array.shape}")
        return embeddings_array

    except Exception as e:
        logger.exception(f"Error generating embeddings: {e}")
        return None


def update_image_status(image_id: str, status: str, retry_count: int, error_message: Optional[str] = None):
    """
    Update the image embedding status in Firestore

    Writes embedding_status, embedding_retry_count and embedding_last_attempt
    (UTC). A non-empty error_message is stored as embedding_error; status
    'success' sets has_embedding=True and clears embedding_error. Errors
    (including a missing document) are logged, not raised.

    Args:
        image_id: The ID of the image (document in the 'images' collection)
        status: The new status (processing, success, failed)
        retry_count: Current retry count
        error_message: Error message if status is failed
    """
    try:
        doc_ref = firestore_client.collection('images').document(image_id)

        update_data = {
            'embedding_status': status,
            'embedding_retry_count': retry_count,
            'embedding_last_attempt': _utcnow(),
        }

        if error_message:
            update_data['embedding_error'] = error_message

        if status == 'success':
            update_data['has_embedding'] = True
            update_data['embedding_error'] = None  # Clear any previous error

        doc_ref.update(update_data)
        logger.info(f"Updated image {image_id} status to {status}")

    except Exception as e:
        logger.exception(f"Error updating image status: {e}")


def update_image_embedding_info(image_id: str, point_id: str, model: str):
    """
    Record where the image's embedding is stored

    Sets embedding_point_id, embedding_model and has_embedding=True on the
    image's Firestore document. Errors are logged, not raised.

    Args:
        image_id: The ID of the image
        point_id: The ID of the point in the Qdrant vector database
        model: The model label used to generate embeddings
    """
    try:
        doc_ref = firestore_client.collection('images').document(image_id)
        doc_ref.update({
            'embedding_point_id': point_id,
            'embedding_model': model,
            'has_embedding': True,
        })
        logger.info(f"Updated image {image_id} with embedding info")

    except Exception as e:
        logger.exception(f"Error updating image embedding info: {e}")